Disclaimer: this is not final. If you find mistakes, please report them.

Part IX: Revisiting tANS/FSE

This post is part of a series on entropy coding and writing our own Zstandard library in Rust.

Part I · II · III · IV · V · VI · VII · VIII · IX · X

It’s been a bit more than a year since I started this project. At first, I thought it’d be fairly straightforward, and I’d get to a decent prototype reasonably fast. But then life happened and this was shelved along with many other projects.

Since I’m unearthing this project, I also wanted to have a fresh look at the theory behind tANS/FSE (see Part IV for an introduction). We described this algorithm as a convolutional coder that improves on Huffman coding by keeping state and “amortising” (not always outputting something for every input). Hence the name FSE makes sense: this is entropy coding and decoding by means of a finite state machine. But we have so far avoided entirely the “ANS” part, and that’s what I want to discuss in this post.

It will also show how one can have a fully working ANS compressor and decompressor if we forget for a moment about Zstandard’s gunk, idiosyncrasies, and optimisations. This makes this post mostly self-contained: a welcome thing if you don’t want to read through all eight previous pieces!

By the end of this post, we’ll have written, in Python and in Rust, a compressor and decompressor (a codec) that is reasonably efficient and entropy-optimal, using only basic arithmetic.

Is that enough to keep you interested?

Representing integers

Say I have some number \(x\): how should I write it down? The standard way to answer this question for humans is to write the decimal digits of \(x\), say 12345. The standard computer way to write the same number is to use binary digits instead: 11000000111001. Let’s be explicit about the radix being used:

\[x = 12345_{10} = 11000000111001_2.\]

Now we could be cheeky and write one digit in base 2, the next in base 3, the next in base 4, etc., so that the number written “4110” in that system represents

\[\begin{align*} 4110_{(5432)} & = 4 \cdot (4 \cdot 3 \cdot 2) + 1 \cdot (3\cdot 2) + 1 \cdot (2) + 0 \\ & = 96 + 6 + 2 \\ & = 104. \end{align*}\]

This is an example of a mixed radix numeral system. You’d think such systems are less practical than the usual fixed radix systems we use every day. But how do we measure time? Milliseconds (base 1000), seconds (base 60), minutes (base 60), hours (base 24), days… If you were to translate UTC time into milliseconds or back, you’d be doing the kind of computations I’m talking about.
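To make this concrete, here is a small sketch of that time arithmetic: evaluating an (hours, minutes, seconds, milliseconds) tuple as a single number is exactly a mixed radix computation with bases 24, 60, 60 and 1000.

def time_to_ms(hours, minutes, seconds, ms):
    # Mixed radix evaluation with bases (24, 60, 60, 1000)
    return ((hours * 60 + minutes) * 60 + seconds) * 1000 + ms

print(time_to_ms(13, 37, 42, 512))
# > 49062512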

In general, let \(B = (b_1, b_2, \dotsc)\) be a list of integers, each at least 2: we can write any integer \(x\) in “\(B\)-ary”, and given a “\(B\)-ary” list of digits, we can reconstruct \(x\). All good. What does it have to do with compression?

def from_mixed_radix(digits, bases):
    if len(digits) == 0:
        return 0
    return digits[0] + bases[0] * from_mixed_radix(digits[1:], bases[1:])

# Digits and bases are given right-to-left
print(from_mixed_radix([0, 1, 1, 4], [2, 3, 4, 5]))
# > 104
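
The other direction, writing \(x\) in “\(B\)-ary”, is just repeated division and remainder; here’s a small sketch mirroring the function above.

def to_mixed_radix(x, bases):
    digits = []
    for b in bases:
        digits.append(x % b)
        x //= b
    return digits

# Digits and bases are still given right-to-left
print(to_mixed_radix(104, [2, 3, 4, 5]))
# > [0, 1, 1, 4]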

Imagine you are given a sequence of small integers \(a_1, a_2, \dotsc, a_k\), and you know in advance that each \(a_i\) is uniformly distributed between \(0\) and some value \(b_i - 1\). If you were to associate a codeword to each \(a_i\) — the way Huffman coding works — you would be wasting a lot of space! For instance, if \(a_1 \in \{0, \dotsc, 9\}\) then you’d need a 4-bit codeword, when slightly more than 3 bits should (in theory) suffice.

With mixed radix representation we can simply consider the number \(x\) whose digits in base \(B = (b_1, \dotsc, b_k)\) are the \(a_1, \dotsc, a_k\) — and this is information-theoretically optimal, assuming that the \(a_i\) are uniformly distributed in their range.
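
To see the gain, here is a quick illustrative sketch (with a fixed base 10, which is just a special case of mixed radix): pack one hundred uniform decimal digits into a single integer and count the bits.

from random import randrange

# One hundred digits, each uniform over {0, ..., 9}
digits = [randrange(10) for _ in range(100)]

# With one 4-bit codeword per digit, we would spend 400 bits...
print(4 * len(digits))
# > 400

# ...while the packed integer needs at most ceil(100 * log2(10)) = 333 bits
packed = from_mixed_radix(digits, [10] * len(digits))
print(packed.bit_length())
# > 333, or a little less depending on the draw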

So there we have it, optimal compression. That sounds nice. Provided the assumption holds. Which of course it doesn’t for real-world data, does it?

Sampling and storing digits

In the real world, a digit \(a_i\) may follow an arbitrary distribution \(p_i\) over \(\{0, \dotsc, b_i - 1\}\). How would we go about producing such numbers?

The traditional way is as follows: notice that \(\sum_{u=0}^{b_i - 1} p_i(u) = 1\); therefore we can take the interval \([0, 1]\) and partition it in regions

  • \(R_0\) between \(0\) and \(p_i(0)\)
  • \(R_1\) between \(p_i(0)\) and \(p_i(0) + p_i(1)\)
  • \(R_2\) between \(p_i(0) + p_i(1)\) and \(p_i(0) + p_i(1) + p_i(2)\)
  • etc.
  • \(R_{b_i-1}\) between \(p_i(0) + \cdots + p_i(b_i - 2)\) and \(1\).

The length of each region \(R_u\) is by design exactly \(p_i(u)\).

The bottom line is this: if we take a random value \(z\) uniformly in the interval \([0, 1]\) then it falls in the region \(R_u\) with probability \(p_i(u)\). We can do this for each digit.

Now, from an implementation standpoint, it is often more tractable to scale the probabilities by \(2^k\) for some \(k\) and floor them to integers. There is a loss of precision in doing so, but let’s sweep that under the rug for the moment.

from random import random
from math import floor

k = 16  # Probability scaling
b = 4   # Radix

# Prob of each digit, sums to 1
p = [0.1, 0.3, 0.4, 0.2] 

# Rescale and make sure that the scaled probas sum to 1 << k
scaled_p = [floor(p_u * (1 << k)) for p_u in p[:-1]]
scaled_p.append((1 << k) - sum(scaled_p))

# Compute regions
regions = []
start = 0
for u in range(b):
    regions.append(range(start, start + scaled_p[u]))
    start += scaled_p[u]

def get_region(z):
    for u in range(b - 1):
        if z in regions[u]:
            return u
    return b - 1

def get_digit():
    # Sample z uniformly
    z = floor(random() * (1 << k))

    # Find in which region it fell
    return get_region(z)

# Test it out
N = 100000
digits = [get_digit() for _ in range(N)]

observed_p = [len([x for x in digits if x == u]) / N for u in range(b)]
print(observed_p)

# > [0.10084, 0.30154, 0.39751, 0.20011]

Notice however that any value \(z\) that falls in a region \(R_u\) will produce the same digit, \(u\). So, instead of writing a number in terms of its digits, we may write it in terms of values \(z\). Why would we do that?

from random import choice

digits = [0, 1, 1, 1, 2, 2, 2, 2, 3, 3]

def z_from_digit(u):
    return choice(regions[u])

encoded = [z_from_digit(u) for u in digits]

decoded = [get_region(z_i) for z_i in encoded]

print(encoded)
# > [111, 16600, 24224, 12677, 49041, 43789, 35209, 50876, 63764, 65319]

print(decoded)
# > [0, 1, 1, 1, 2, 2, 2, 2, 3, 3]

Ok, ok, that sounds silly: we replaced a digit that we could store on 2 bits by a 16-bit integer…

But now this integer is uniformly distributed! So we can use the mixed radix representation from earlier. We’ll use the fact that if \(z\) is uniformly distributed mod \(2^k\), then it is uniformly distributed (approximately) modulo \(p(u)\) for each \(u\).

Let’s work it out together: say we have the sequence of digits \(2, 1, 0, 2\) in base 4, each being distributed according to \(p\) (assumed rescaled so that the probabilities sum to \(2^k\)):

  • The state \(s\) is initialised to 0
  • The first digit is 2: let \(z = p(0) + p(1)\) and \(s \gets z\). This was easy.
  • The next digit is 1. This time, we’ll set \(z \gets p(0) + (s \bmod p(1))\). Then
\[s \gets 2^k \left\lfloor \frac{s}{p(1)} \right\rfloor + z\]
  • Similarly for all the subsequent digits \(u\), we’ll compute
\[\begin{align*} z &\gets (s \bmod p(u)) + p(0) + p(1) + \cdots + p(u-1) \\ s &\gets 2^k \left\lfloor \frac{s}{p(u)} \right\rfloor + z \end{align*}\]

The end result of this encoding is \(s\). Decoding will work in reverse:

def decode_one(s):
    z = s & ((1 << k) - 1)    # k lowest bits
    s >>= k
    for u, p_u in enumerate(p):
        if z >= p_u:
            z -= p_u
        else:
            break
    s = s * p[u] + z
    return u, s

This function takes a state and returns the last digit encoded together with an updated state that has this digit removed. We can write the corresponding encode_one function and test them out:

def encode_one(s, u):
    z = sum(p[0:u]) + s % p[u] 
    s = s // p[u]
    s = (s << k) | z
    return s

def encode(seq):
    s = 0
    for u in seq:
        s = encode_one(s, u)
    return s

def decode(s):
    seq = []
    while s != 0:
        u, s = decode_one(s)
        seq.append(u)
    return seq

test_seq = [1, 1, 2, 0, 2, 2, 1, 2, 1]
print(encode(test_seq))

# > 70065640

print(decode(70065640))

# > [1, 2, 1, 2, 2, 0, 2, 1, 1]

As you see, the initial sequence of digits comes back reversed, LIFO-style; nothing we can’t deal with. Furthermore, instead of having a collection of nine 16-bit values as a result, we have a single 27-bit integer; that’s somewhat better than using \(z\) directly. In fact, look at what happens when we consider even longer sequences:

test_seq = [get_digit() for _ in range(10000)]

s = encode(test_seq)

In this case, \(s\) is an 18,500-bit number. I’ll grant you it’s huge (compared to everyday numbers), but what did we expect? Storing ten thousand symbols on two bits each would require… 20,000 bits. We’ve beaten Huffman!

The Shannon entropy for this particular distribution is 1.8464… bits/symbol: we’re pretty close! In fact, the reason we’re not even closer is that we approximated our probabilities. Better approximation (i.e., a larger value of \(k\)) would get us even closer, but at the price of having to handle larger numbers.
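
Here is a quick way to check those numbers (a small sketch, reusing s and test_seq from just above):

from math import log2

# Shannon entropy of the distribution we sampled from
print(-sum(q * log2(q) for q in [0.1, 0.3, 0.4, 0.2]))
# > 1.8464...

# Bits actually spent per symbol by our encoding
print(s.bit_length() / len(test_seq))
# > roughly 1.85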

This is called the bits-back trick, a technique first described in a 1990 paper by Wallace. As you saw, combined with mixed radix encoding, this already brings us to the limit of what can be achieved in theory. But… it is not efficient.

Focusing on the head

What’s the issue? It has to do with the state. It quickly becomes a large integer, which is updated entirely every time we encode a symbol. The same happens when decoding. But really, that’s not necessary.

Indeed, as with fixed radix integers, we should be able (to some extent) to read the most significant digits without knowing the least significant ones. Thus, while decoding at least, we would expect to mostly care about only a portion of the state.

Let’s flesh this idea out. Our state will be split between a head and a tail, and we’ll try to work on the head as much as possible, using the tail as storage.

k = 16                  # Probability scaling "precision" factor 
mask = (1 << k) - 1     # The number 111111...11

# Our new state
state = {'head': 0, 'tail': []}

def decode_one(s):
    # This part is unchanged, except that we now operate only on the
    # 'head' part of the state
    z = s['head'] & mask
    s['head'] >>= k
    for u, p_u in enumerate(p):
        if z >= p_u:
            z -= p_u
        else:
            break
    s['head'] = s['head'] * p[u] + z

    # This part is new: if the head gets too small, we load some
    # content from the tail
    if s['head'] >> k == 0 and len(s['tail']) != 0:
        s['head'] = (s['head'] << k) | s['tail'].pop()

    return u, s

def encode_one(u, s):
    # This part is new, we check ahead of time whether the head
    # would grow too large: if so, we offload some of it to the tail
    if (s['head'] >> k) >= p[u]:
        s['tail'].append(s['head'] & mask)
        s['head'] >>= k

    # This part is unchanged, except that we now operate only on the
    # 'head' part of the state
    z = sum(p[0:u]) + s['head'] % p[u] 
    s['head'] = s['head'] // p[u]
    s['head'] = (s['head'] << k) | z

    return s
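
Before moving on, here is a quick round-trip check of this streaming version. It’s only a sketch: it assumes p now holds the scaled integer probabilities (summing to 1 << k), for instance the ones from our running example.

# Assumed: scaled integer probabilities summing to 1 << k
p = [6553, 19660, 26214, 13109]

msg = [2, 1, 0, 0, 1, 1, 1, 2, 2, 1, 3, 0]

s = {'head': 0, 'tail': []}
for u in reversed(msg):          # encode in reverse...
    s = encode_one(u, s)

decoded = [decode_one(s)[0] for _ in msg]   # ...so that decoding is in order
print(decoded == msg)
# > True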

If necessary, we can serialize head and tail into a single array:

def flush(s):
    data = s['tail'][:]
    head = s['head']
    while head != 0:
        data.append(head & mask)
        head >>= k
    return data

This mixed radix representation, where we use a tail as storage, is essentially what is referred to in the literature as a “streaming asymmetric numeral system”, or (s)ANS for short. The tANS we keep talking about is a different implementation of the same idea using tables, which makes it particularly efficient. But let’s stick to the simple version here, and call it ANS.

Minor arithmetic improvements

Now let’s have a deeper look at the basic operations in our code. Binary operations (shift, and, or) are efficient, but we also do arithmetic (addition, subtraction, division, modulo), which is still costly. We can do a few things about it:

  • Instead of computing sum(p[0:u]) every time we need it, we can precompute this value for all u into an array sum_p for instance (see the sketch right after this list);
  • Divisions and remainders may be sped up using Barrett’s algorithm, as we’ll discuss below
  • Multiplication by p[u] may be sped up using a variety of techniques (Booth recoding, ternary recoding, double-base recoding). However, for such small numbers as the ones considered here, it seems that this optimisation wouldn’t be worth it.
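
For the first point, here is a small sketch of the precomputation; this sum_p array is the one used in the Barrett version of the encoder further below.

# Cumulative probabilities: sum_p[u] == sum(p[0:u])
sum_p = []
acc = 0
for p_u in p:
    sum_p.append(acc)
    acc += p_u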

To describe Barrett’s algorithm, let’s say we have a number \(n\) known in advance, and a number \(a\) given to us. We’d like to know the quotient and remainder of the division of \(a\) by \(n\). Barrett’s algorithm gives us the answer while avoiding the costly division by \(n\).

The trick is to precompute, once and for all, the integer

\[m = \lfloor 2^\kappa / n \rfloor\]

where \(\kappa\) is chosen large enough. The number \(m\) is an approximation of \(2^\kappa / n\), and therefore \(q \gets \lfloor m a / 2^\kappa \rfloor\) is an approximation of the quotient we’re looking for; similarly for the remainder \(r \gets a - q n\).

This approximation turns out to be really good: for well chosen \(\kappa\) we either hit the right answer, or we are just one off. The whole algorithm looks like this:

kappa = 16
n = 1234

m = (1 << kappa) // n

def qr(a, m, n):
    q = (m * a) >> kappa
    r = a - q * n
    if r >= n:
        r -= n
        q += 1
    return q, r

print(qr(5678, m, n))
# > (4, 742)

We can apply this idea to make our encoding faster: compute the m value for each p[u] and replace the lines

z = sum(p[0:u]) + s['head'] % p[u] 
s['head'] = s['head'] // p[u]

in our encode_one function by

q, r = qr(s['head'], m[u], p[u])
z = sum_p[u] + r
s['head'] = q
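
Putting it together, the encoder step might look like this. This is only a sketch: it reuses qr and sum_p from above, precomputes one Barrett constant m[u] per symbol, and bumps kappa so that it is comfortably larger than the bit size of the head (which lets the single correction in qr suffice).

kappa = 48   # large enough for a head below 2^32
m = [(1 << kappa) // p_u for p_u in p]

def encode_one(u, s):
    # Renormalise: offload the low bits of the head if it would grow too large
    if (s['head'] >> k) >= p[u]:
        s['tail'].append(s['head'] & mask)
        s['head'] >>= k

    # Same arithmetic as before, with division and modulo done via Barrett
    q, r = qr(s['head'], m[u], p[u])
    s['head'] = (q << k) | (sum_p[u] + r)
    return s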

And now… 🦀 RIIR 🦀

Since we have something that works, let’s make it worth our time. Unlike Python, Rust will require more care in some of our operations. Let’s start with Barrett reduction first: here’s a direct translation of the Python code:

struct BInt {
    value: u128,
    prec: u128,
    kappa: usize,
}

impl BInt {
    pub const fn new(value: u64, kappa: usize) -> Self {
        let value = value as u128;
        let prec = (1 << kappa) / value;
        Self { value, prec, kappa }
    }

    pub const fn qr(&self, a: u64) -> (u64, u64) {
        let a = a as u128;
        let mut q = (self.prec * a) >> self.kappa;
        let mut r = a - q * self.value;

        while r >= self.value {
            r -= self.value;
            q += 1;
        }

        (q as u64, r as u64)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_barrett() {
        let b1234 = BInt::new(1234, 8);
        let (q, r) = b1234.qr(5678);

        assert_eq!(q, 4);
        assert_eq!(r, 742);
    }

}

Notice a few changes however:

  • The most obvious is that we need to specify integer sizes. We’ll only be working with non-negative numbers, so unsigned is fine, and we’ll use values of \(k\) up to 32 (16 in the snippets below, bumped to 32 for the final test) so that the head never exceeds 64 bits. So what’s the deal with u128? Simply put, multiplying two u64 together may overflow. By using u128 we avoid having to deal with that issue. The results are always u64 though (and no larger than the inputs).
  • In qr, we replaced the corrective subtraction if by a while. Didn’t I say that for carefully-chosen \(\kappa\) we would need at most one subtraction? Yes, yes I did. But since I don’t want to think about determining the optimal \(\kappa\) value I’ll just accept that sometimes I do more than one subtraction.
  • We marked these functions const. This notion does not exist in Python. Functions that are const can be called at compile time, and there’s no good reason to waste this opportunity.

On to the implementation of our state, which will be a minimal struct:

pub struct State {
    head: u64,
    tail: Vec<u64>,
}

impl State {
    pub const fn new(head: u64, tail: Vec<u64>) -> Self {
        Self { head, tail }
    }

}

Then it’s time for the heart of the matter, our codec. Let’s start with its fields

pub struct AnsCodec {
    state: State,       // Current state
    sum_p: Vec<u64>,    // Cumulative probas (precomputed)
    bar_p: Vec<BInt>,   // Barrett-ready probas (precomputed)
    probas: Vec<u64>,   // Probas for each symbol
}

The parts that need precomputation will be handled at the codec’s creation, in the most obvious possible way:

const PARAM_K: usize = 16;
const MASK: u64 = (1 << PARAM_K) - 1;

impl AnsCodec {
    pub fn new(probas: Vec<u64>, initial_state: State, kappa: usize) -> Self {
        let mut bar_p = Vec::with_capacity(probas.len());
        let mut sum_p = Vec::with_capacity(probas.len());
        let mut counter = 0;

        // Precompute for Barrett reduction and cumulative proba
        for &p_u in &probas {
            bar_p.push(BInt::new(p_u, kappa));
            sum_p.push(counter);
            counter += p_u;
        }

        Self {
            state: initial_state,
            sum_p,
            bar_p,
            probas,
        }
    }
}

As a quick note, we don’t perform any checks here: we assume that the entries of probas sum to 1 << PARAM_K and that kappa is small enough for the Barrett products to fit in a u128. We could add such checks, and maybe later we will, but for now let’s focus on the important stuff. Encoding is a straightforward, brainless port of the Python version:

impl AnsCodec {
    // ...

    pub fn encode_one(&mut self, u: usize) {
        if (self.state.head >> PARAM_K) >= self.probas[u] {
            self.state.tail.push(self.state.head & MASK);
            self.state.head >>= PARAM_K;
        }

        let (q, r) = self.bar_p[u].qr(self.state.head);
        let z = self.sum_p[u] + r;
        self.state.head = q;
        self.state.head = (self.state.head << PARAM_K) | z;
    }    
}

To port decoding we must focus temporarily on how Python handles for loops with breaks, which is different from Rust’s philosophy. Take for instance the following code

for i in range(100):
    if f(i) == 0:
        break
print(i)

Two things can happen here: either there exists a value i in the interval such that f(i) is zero, and then what we display is the first such value, or there is no such value and what we display is the last value of the interval, namely 99. In any case, the loop variable i is used outside of the for loop’s scope. How should we write that in Rust?

A “dirty” way is to define i out of the loop, and do things manually:

let mut i = 0;
while i < 100 {
    if f(i) == 0 {
        break;
    }
    i += 1;
}
println!("{}", i);

This will work, and it doesn’t look like such a pain to write. However, in many ways, the semantics of what we’re trying to do are lost; furthermore, this does not play well with iterators. So here’s a more contrived way to do the same:

let i = (0..n)
    .find(|i| f(*i) == 0)
    .or_else(|| Some(n - 1))
    .unwrap();

println!("{}", i);

Interestingly, while both snippets yield the same result in the end, rustc produces slightly different assembly for them. I am not sure why, or whether one version is better than the other at a low level. Alas… we can’t have side effects in the loop with the “clean” version.

So we have to go with the “dirty” one, at least for now. (Ok, I could go with something in between, but for the sake of readability and of being a straight port from Python, let’s stick to that. For now.)

impl AnsCodec {
    // ...
    pub fn decode_one(&mut self) -> usize {
        let mut z = self.state.head & MASK;
        self.state.head >>= PARAM_K;

        let mut u = 0;
        while u < self.probas.len() {
            if z >= self.probas[u] {
                z -= self.probas[u];
            } else {
                break;
            }
            u += 1;
        }

        self.state.head = self.state.head * self.probas[u] + z;

        if (self.state.head >> PARAM_K == 0) && !self.state.tail.is_empty() {
            self.state.head = (self.state.head << PARAM_K) | self.state.tail.pop().unwrap();
        }

        u
    }
}

See? It’s not that bad. We have to unwrap() at some point, but since we just checked that the tail wasn’t empty, this will never panic. Let’s wrap these methods into something easier to wield, starting with encoding:

impl AnsCodec {
    // ...

    pub fn encode(&mut self, input: &mut Vec<usize>) {
        // We encode in reverse, so that decoding is in the right order
        while !input.is_empty() {
            self.encode_one(input.pop().unwrap());
        }
    }
}

Neat, isn’t it? What about decode? There’s one thing to think about first: how do we know that we’re done decoding? The answer is: we can’t know just by looking at the state. (That’s something we discussed before in this series, but if you haven’t read the other posts — and who could blame you — tANS/FSE can output symbols even in the absence of input!) The easy workaround is to provide the expected decoded length and stop when we reach it:

impl AnsCodec {
    // ...

    pub fn decode(&mut self, n: usize) -> Vec<usize> {
        (0..n).map(|_| self.decode_one()).collect()
    }
}

And… that’s it! Let’s test this all.

#[cfg(test)]
mod tests {
    // ..

    #[test]
    fn test_codec() {
        // Empty state
        let s = State::new(0, vec![]);

        // Probabilities 0.1, 0.3, 0.4, 0.2, scaled so they sum to 1 << PARAM_K
        let probas = vec![6553, 19660, 26214, 13109];

        // Some sequence we want to encode
        let mut msg = vec![2, 1, 0, 0, 1, 1, 1, 2, 2, 1, 0, 0, 0];

        // Initialize codec
        let kappa = 64;
        let mut codec = AnsCodec::new(probas, s, kappa);

        let left = msg.clone();
        let n = msg.len();

        // Encode
        codec.encode(&mut msg);

        // Decode
        let right = codec.decode(n);

        assert_eq!(left, right);
    }
}

For good measure, let’s add a flush and the corresponding load method to the state, so that we can serialize and deserialize more easily.

impl State {
    pub const fn new(head: u64, tail: Vec<u64>) -> Self {
        Self { head, tail }
    }

    pub fn flush(mut self) -> Vec<u64> {
        let mut head = self.head;
        while head != 0 {
            self.tail.push(head & MASK);
            head >>= PARAM_K;
        }
        self.tail
    }

    pub fn load(&mut self) {
        if (self.head >> PARAM_K == 0) && !self.tail.is_empty() {
            self.head = (self.head << PARAM_K) | self.tail.pop().unwrap();
        }
    }
}

A larger test

Our codec is finished. Let’s put it through a somewhat more interesting test: compressing text. We’ll pick the enwik8 dataset, which is built from Wikipedia contents, and try to encode it. The dataset actually contains a lot of special characters (it contains XML, Unicode, etc.) but we’ll treat it as a sequence of bytes. As a prior distribution, we’ll use the actual byte distribution, thanks to the following short Python script. We’ll also bump our value of \(k\) from 16 to 32.

# -*- coding: utf-8 -*-

from collections import defaultdict
from tqdm import tqdm

# Read file
with open("enwik8", "rb") as f:
    text = f.read()


# Count frequencies and symbols
total = 0
d = defaultdict(int)
for x in tqdm(text):
    d[x] += 1

    total += 1

# Rescale probabilities
PARAM_K = 32
symbols = []
probas = []
for symbol, freq in d.items():
    probas.append((freq * (1 << PARAM_K))//total)
    symbols.append(symbol)

probas.append((1 << PARAM_K) - sum(probas))

print(', '.join(map(hex, symbols)))
print(', '.join(map(hex, probas)))

The reason we need to collect which symbols appear too is that we don’t want any 0 probability in the array (try and see why). But since we do that, we’ll have to remap the input bytes. Does it sound complicated? Here’s a silly example: say that only symbols 3, 7, 9 appear with nonzero probability: we’ll map them to 0, 1, 2 respectively before passing them to our codec. Then, upon decoding, we’ll map 0, 1, 2 back to 3, 7, 9.
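
In code, that remapping is nothing more than two small lookup tables. Here’s a sketch with the silly example above:

symbols = [3, 7, 9]                               # symbols that actually appear
to_index = {s: i for i, s in enumerate(symbols)}  # 3 -> 0, 7 -> 1, 9 -> 2

print([to_index[b] for b in [9, 3, 3, 7]])   # what we feed to the codec
# > [2, 0, 0, 1]

print([symbols[i] for i in [2, 0, 0, 1]])    # what we recover after decoding
# > [9, 3, 3, 7]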

When we do that, we compress (if we compile for release) in less than a second and the result is 32 * 15882834 + 64 bits. This is 63.5% of the original size! Based on the statistics our Python script gave us, this is extremely close to the file’s Shannon entropy. This is good! And 15 ns per byte is not too bad either.
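
As a quick sanity check of that percentage (enwik8 is exactly 100,000,000 bytes):

compressed_bits = 32 * 15882834 + 64
original_bits = 100_000_000 * 8

print(compressed_bits / original_bits)
# > 0.6353...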

“Final” words

There we have it. A complete codec in less than 250 lines of Rust (including comments and tests; less than 130 LoC without tests).

Hopefully, this bare bones codec sheds some light on how the idea of a number system relates to data compression, and gives you another way to understand Huffman and tANS/FSE coders.

By the way, don’t be fooled by Shannon: if you run tar with gzip compression on enwik8 (which takes 4.5 seconds on my computer) the resulting file will be 36.5% of the original size. Entropy coding can’t beat entropy, but dictionary methods can (and that’s why we’ve been using them all this time!)

Did you like this? Do you want to support this kind of stuff? Please share around. You can also buy me a coffee.