Disclaimer: this is not final. If you find mistakes, please report them.

Part IV: tANS/FSE coding

This post is part of a series on entropy coding and writing our own Zstandard library in Rust.

Part I · II · III · IV · V · VI · VII · VIII · IX · X

Welcome back to the series. In the previous episode we covered two of the three compression techniques used in Zstandard: Lempel-Ziv and Huffman coding. The third ingredient is what makes all the difference, and it deserves a lengthy discussion today: tANS encoding (called FSE in Zstandard). This is really the main “new thing” that makes Zstandard different from boring old zip.

So, yea, some more compression theory for you today I guess? Please keep your hands inside the vehicle at all times.

Two limitations of Huffman coding

Recall from last time that in Huffman coding we replace every input symbol by a sequence of bits, according to a correspondence table (look at Part III for details on how we build this table). The main takeaway is that we map common symbols to short sequences, and sequences can always be distinguished from one another (there is no decoding ambiguity). This is all very classical.

Huffman coding faces two fundamental limitations however:

  • Every input symbol is dealt with independently. While this may not sound like a bad thing in itself, it means that Huffman coding misses opportunities to save space in some easy cases: in English for instance, the letter q is almost always followed by the letter u. But Huffman coding has no way to capture such contextual hints, and would process the inputs uq or qu letter by letter, yielding an output that has the same length in both cases.
  • The smallest output is 1 bit. If a symbol appears with overwhelming probability in the input stream, it will be assigned the shortest possible sequence: 1. But this is actually a waste of space because, from an information-theoretic perspective, we would only need say 0.001 bit. (Of course, there is no such thing as 0.001 bit: this should be understood to mean on average that’s how much space we use for this symbol).

Generally speaking, Huffman coding maps (on average) each input symbol to an integral number of bits. This is optimal if and only if the frequencies at which input symbols occur are powers of two: 1/2, 1/4, 1/8, etc. On real-world data such a remarkable alignment is very unlikely to occur: Huffman coding accumulates a small inefficiency that builds up as we process the input.
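
To put a number on that inefficiency, here’s a quick back-of-the-envelope comparison (the 90/10 split is purely illustrative, not taken from any real data):

# How much information does a very common symbol actually carry,
# versus what Huffman coding has to spend on it?
from math import log2

p = 0.9  # probability of the most common symbol (illustrative)

ideal_bits = -log2(p)  # information content: about 0.152 bits
huffman_bits = 1       # Huffman can never go below one bit per symbol

print(f"ideal: {ideal_bits:.3f} bits, Huffman: at least {huffman_bits} bit")

That gap of roughly 0.85 bit, paid on nine symbols out of ten, is exactly the inefficiency that builds up.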

We could extend the Huffman algorithm to work on pairs of symbols (rather than individual symbols), which would somewhat alleviate these phenomena, but it would cause the table to become prohibitively large, and the issues would remain: they would just be less salient. In particular Huffman coding still wouldn’t deal efficiently with longer sequences that are completely predictable.

Huffman coding revisited

The above limitations were noticed fairly early in the history of data compression, but it was widely believed that no approach could avoid them. This belief was shaken starting around 2007, when Duda published a paper on so-called “asymmetric numeral systems”. The paper was densely written, some would say heavy and clumsy, and was chiefly a theoretical observation about how we could tweak the way bits are read in general. The application to data compression was mentioned, but the paper failed to garner attention back then, in no small part because of its (perceived or actual) difficulty.

It took the combined efforts of Collet, Bloom, and others to try and make sense of it all, and — much more importantly — turn this highly concentrated, hard-skinned kernel of an idea into a fully-grown algorithm, together with several implementations. Around 2013, the first efficient implementation of tANS (“tabled asymmetric numeral system”) compression was released by Collet, who called it FSE – finite state entropy. It’s been tweaked ever since and shown to somewhat outperform other flavours of ANS, at least when aggressively optimised.

Duda, Collet, and Bloom all wrote extensively to try and convey how ANS works, why it works, why we should care and how we should implement it. I’ve read them all, and I’ll give links below, but in a fit of egomaniac pride, I am going to approach things from a different angle. It’s a new take on what tANS coding really is, one that I personally find somewhat more to my taste. Because it’s not the way tANS is usually discussed, if you don’t understand my explanation you can go and read the “classical” explanations of Duda, Collet, and Bloom (and vice versa).

My claim is that tANS is (just) a convolutional encoder, and so is Huffman coding. From an error-correction perspective these codes are literally the worst you can ever hope to build: a one-bit mistake anywhere will likely ruin any hope of recovering anything of value. Nevertheless, all the tools and algorithms coming from error-correction theory will be useful here, and if you have some background in the field you’ll recognise right away most of the tricks that had to be reinvented by people trying to make sense of tANS.

Note: it may sound like I’m blaming authors for their pedagogy here – I’m not. ANS is a complex topic and their presentation is suited to the various proofs and results that they show about it. Hindsight and focus on a single variant of ANS both allow me to be somewhat more free in how to approach things, and if it weren’t for these authors’ careful exploration of ANS, their distillation and selection, and their expository work, I wouldn’t even have something to say. This blog post is not a replacement for theirs, it’s a companion.

To get started let’s set up an example, which we can easily work out with pen and paper. Assume we want to compactly represent an input made of two possible symbols: “red” and “blue” (R and B). If each colour occurs with equal probability, we can simply use 0 and 1, respectively, to represent each symbol. But if R is much more likely than B, using Huffman coding we’d choose instead something like 1 and 01.

Huffman coding, as discussed above, simply takes an input symbol and outputs the corresponding sequence. We can represent this using the following diagram:


Fig. 1: Huffman coding, revisited as a convolutional encoder.


Here’s how to read this:

  • Starting from state \(A\)
  • If we receive a red input symbol, follow the red arrow (called a transition) and output the sequence it carries (which is always 1)
  • If we receive a blue input symbol instead, follow the blue transition and output the sequence it carries (which is always 10)
  • Following a transition brings us to a new state (which, in this case, is the same state \(A\) we started from.)

Note that we output the Huffman codes in reverse (10 instead of 01); the reason for this will become clear during decoding.

We could have “looped” the arrows from \(A\) back to itself, as there is only one “state”, but we’ll refrain from doing that just now. In fact, let’s lay out more copies of \(A\) to handle several input symbols one after the other:


Fig. 2: Huffman coding, revisited as a convolutional encoder.


Also, the state plays no role here, so you may be wondering why we bring it up at all. Bear with me for a moment. Let’s see how this encoder plays out on an example input: RRRBRBBRRRRR (as expected, red symbols happen more often than blue symbols). The input is 12 symbols long.

Passing through our Huffman encoder starting from the leftmost state \(A\) yields: 1, 1, 1, 10, 1, 10, 10, 1, 1, 1, 1, 1. Altogether this gives 111101101011111, which is a total of 15 bits. The sequence of states we pass through is, unsurprisingly, \(A, A, \dotsc, A\).
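
If you want to play along at home, here’s a minimal sketch of this single-state encoder (the function and names are mine, not something from the zrs code base):

# Single-state "Huffman as encoder" sketch: every symbol maps to a fixed
# bit sequence, written with the Huffman codes reversed (see above).
CODES = {"R": "1", "B": "10"}

def huffman_encode(symbols):
    return "".join(CODES[s] for s in symbols)

assert huffman_encode("RRRBRBBRRRRR") == "111101101011111"  # 15 bits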

How do we decode 111101101011111? We work backwards: starting from the final state (right-most \(A\)), we reverse the arrows of the transitions and proceed as follows:

  • We read the output, starting from the end
  • We consume one bit: 111101101011111
  • There is a transition with this sequence: it is a red transition going from final \(A\) to preceding \(A\). It is the only transition we can take.
  • Therefore we output R and follow the arrow to the new state.

Fig. 3: Huffman decoding, the reversed arrows.


Continuing that way we obtain R, R, R, R, R and then we reach a zero bit (this zero: 111101101011111) that doesn’t correspond to any valid transition. So we consume another bit: 111101101011111. The sequence 10 matches the blue transition, and we output B. One (or two) bits at a time, we recover the original input — albeit in reverse order: RRRRRBBRBRRR.

Note: one way to avoid the headache is to have the encoder also read its input from end to beginning. This way the decoder’s output is in the normal order.
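
Here’s the matching decoder sketch, reading the bit string from the end as described (again, just an illustration, not zrs code):

# Decode by reading from the end: a trailing 1 is a red transition, and a
# trailing 0 can only be the end of the (reversed) blue code "10".
def huffman_decode(bits):
    out = []
    i = len(bits)
    while i > 0:
        if bits[i - 1] == "1":
            out.append("R")
            i -= 1
        else:
            assert bits[i - 2:i] == "10"
            out.append("B")
            i -= 2
    return "".join(out)

assert huffman_decode("111101101011111") == "RRRRRBBRBRRR"  # reversed input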

If you followed to this point, you are doing great. Yes, this is a somewhat contrived way to Huffman encode/decode. But this presentation will make it painless to now handle tANS.

Trellis encoding

In Huffman coding we had a single state (booooring) \(A\) and transitions. What if we had multiple states instead? We’ll call them \(A, B, C, D\). Transitions will now be able to transport us from one state to another. Let’s first draw our trellis, the graph that represents transitions, and then we’ll see why this can be interesting.


Fig. 4: A more interesting encoder.


In the above diagram, some transitions carry no output sequence; they are represented as dotted arrows. Note that the blue arrows always carry a 2-bit sequence, the red arrows always carry a 1-bit sequence, and the dotted arrows carry a 0-bit sequence. As we did for Huffman coding, we should think of the graph as being repeated:


Fig. 5: A more interesting encoder.


Okay so let’s see that in action, with the same input as before: RRRBRBBRRRRR. Starting from left-most state \(A\),

  • Input is R we follow the red dotted transition, output nothing, and get to new state \(C\);
  • Next input is R we follow the red transition, output 0, and get to new state \(A\);
  • Next input is R we output nothing and get to \(C\);
  • Next input is B we follow the blue transition, output 01, and get to new state \(B\);
  • Next input is R, output nothing, get to \(D\);
  • Next input is B, output 11, get to \(B\);
  • Next input is B, output 10, get to \(B\);
  • Next input is R, output nothing, get to \(D\);
  • Next input is R, output 1, get to \(A\);
  • Next input is R, output nothing, get to \(C\);
  • Next input is R, output 0, get to \(A\);
  • Next input is R, output nothing and get to \(C\).

All in all, our encoder produced the following bit sequence: 001111010 (total: 9 bits), and went through states \(ACACBDBBDACAC\). Hopefully you could follow along easily.
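
Since the figures don’t translate well to text, here is the same trellis written out as a transition table, together with a minimal encoder sketch. One caveat: the blue transition out of \(A\) never appears in our example; 00 is the only 2-bit sequence left, so I’m assuming that’s the one it carries.

# Trellis encoder sketch. Each entry: (current state, input symbol) ->
# (bits to output, next state), reconstructed from the walkthrough above.
TRANSITIONS = {
    ("A", "R"): ("",   "C"),
    ("A", "B"): ("00", "B"),   # assumed: never used in the example
    ("B", "R"): ("",   "D"),
    ("B", "B"): ("10", "B"),
    ("C", "R"): ("0",  "A"),
    ("C", "B"): ("01", "B"),
    ("D", "R"): ("1",  "A"),
    ("D", "B"): ("11", "B"),
}

def trellis_encode(symbols, state="A"):
    bits, states = "", [state]
    for s in symbols:
        out, state = TRANSITIONS[(state, s)]
        bits += out
        states.append(state)
    return bits, states

bits, states = trellis_encode("RRRBRBBRRRRR")
assert bits == "001111010"                  # 9 bits
assert "".join(states) == "ACACBDBBDACAC"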

To decode, we’ll follow the exact same approach as for our earlier Huffman example, except now it will be easier. Start by reversing the transitions:


Fig. 6: The decoder trellis.


Then starting from final state \(C\) (remember, that’s where we landed at the end of encoding) and reading the encoded sequence 001111010 backwards,

  • There is only one transition possible from \(C\) so we take it. It is a red transition so we output R, and get to state \(A\).
  • There are only two transitions possible from \(A\): we need to consume one bit to tell which one to take. 001111010: we output R and go to state \(C\).
  • No choice: output R and go to state \(A\).
  • Take one bit 001111010: output R and go to state \(D\).
  • No choice: output R and go to state \(B\).
  • There are four transitions possible from \(B\), we need to consume two bits to tell which one to take. 001111010: we output B and go to state \(B\).
  • Take two bits 001111010: we output B and go to state \(D\).
  • No choice: output R and go to state \(B\).
  • Take two bits 001111010: we output B and go to state \(C\).
  • No choice: output R and go to state \(A\)
  • Take one bit 001111010: output R and go to state \(C\).
  • No choice: output R and go to state \(A\).

There’s no more input so we’re finished (phew!). What did we output? RRRRRBBRBRRR, in other words the original message, but in reverse (as expected). What states did we visit? \(CACADBBDBCACA\), in other words the same states as during encoding, but in reverse.
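
For completeness, here is the matching decoder sketch. It reuses the TRANSITIONS table from the encoder sketch above, and it relies on the fact that, in this trellis, the number of bits to read is entirely determined by the current state.

# Trellis decoder sketch: reverse the transitions, start from the final
# state and walk backwards through the bit string.
REVERSED = {}  # state -> {bits read: (symbol, previous state)}
for (prev, symbol), (out, nxt) in TRANSITIONS.items():
    REVERSED.setdefault(nxt, {})[out] = (symbol, prev)

# How many bits we must read at each state (same for all arrows into it)
BITS_NEEDED = {state: len(next(iter(rev))) for state, rev in REVERSED.items()}

def trellis_decode(bits, final_state):
    state, i, out = final_state, len(bits), []
    # Stopping when the input runs out is good enough for this example;
    # the real format is more careful about termination.
    while i > 0 or BITS_NEEDED[state] == 0:
        n = BITS_NEEDED[state]
        symbol, state = REVERSED[state][bits[i - n:i] if n else ""]
        out.append(symbol)
        i -= n
    return "".join(out)

assert trellis_decode("001111010", "C") == "RRRRRBBRBRRR"  # reversed input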

As we said above for Huffman coding, if we make the encoder and decoder both read their input from the end to the beginning, then upon decoding everything is in the right order.

Did you notice that our compressed output is quite a bit shorter than the Huffman code we obtained previously? 9 bits instead of 15. Indeed, our encoder spends a decent amount of time outputting nothing at all.

And that’s the key to this “having multiple states” deal: states tell the decoder what to expect next: \(D\) means “next symbol is red but the one after is blue!”, or \(C\) means “two next symbols are red, but beware for the third”, etc.

Thus symbols B always get turned into 2-bit sequences, but symbols R may be implicitly assumed (= 0-bit sequence) a large portion of the time.

Because the encoder is allowed to sometimes output nothing, the average number of bits used to encode a common symbol becomes smaller than one: in our example, the three B symbols cost 2 bits each, but the nine R symbols cost only 3 bits in total, i.e. a third of a bit per R on average. This is the reason why we can outperform Huffman coding.

But wait! In Huffman coding we had to store the decoding table (or something equivalent to it). Here we need to store the graph in Fig. 4, which looks more complicated, and therefore takes some more memory!

This is true but even then, for long enough messages our trellis encoder will always outperform the Huffman encoder!

Give me the tables!

Now you should be telling yourself:

This is great! I just need a way to figure out this graph!

And you’d be right.

Luckily for us, it’s given to us when decompressing a file. We just need to recover it, and then run our decoding algorithm. As with Huffman tables, we only really need the weight associated with each symbol, from which we can build the table ourselves.

In fact, Zstandard defines predefined distributions, which can be used instead of actual frequency counts. This is less precise, but we don’t have to store them in every file, so that’s nice. Here are the distributions for LZ sequence elements (literal_len, match_len, offset):

// File: zrs/zarc/src/fse.rs

const DEFAULT_LITLEN_WEIGHT: [i16; 36] = [
    4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, //
    2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, //
    -1, -1, -1, -1,
];

const DEFAULT_MATCHLEN_WEIGHT: [i16; 53] = [
    1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, //
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, //
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, //
    -1, -1, -1, -1, -1,
];

const DEFAULT_OFFSET_WEIGHT: [i16; 29] = [
    1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, //
    1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
];

Note that some symbols have a -1 weight: this is a special value. We could introduce an enum here, but it feels like this would be overkill. We can always change it later if it bites us in the arse.

So when we’re using predefined weights, we merely have to recompute the decoding table from this information. Do we though? The Zstandard specs give us the complete decoding table for this case, so why bother?

Two reasons. First reason is, we’ll need this anyway, because we won’t always have predefined tables. Second reason is that these tables are huge and I’m not bothered to re-type them nicely in Rust, especially in light of the first reason.

So let’s get crackin’! Here are the rules of table building:

  • The table has a power-of-two number of entries. It tells us all we need to output a symbol, advance in the bitstream, and go to the next state.
  • Every entry in the table contains three fields: symbol, nb_bits, and baseline

    • dtable[state].symbol represents the literal to be output
    • dtable[state].nb_bits tells us how many bits to consume from the input.
    • the bits we just read from the input form an integer value offset, and dtable[state].baseline + offset tells us which state to jump to, to continue decoding (see the sketch right after this list)
  • We start with weight -1 symbols: they are put in the table, starting from the end;
  • Define a variable table_position initialised to 0
  • For each symbol (with weight 0 and above), put it in as many cells as its weight:

    • If index table_position is not occupied, put symbol in it
    • Update table_position: (this is meant to scatter symbols around)

      table_position += (table_size >> 1) + (table_size >> 3) + 3;
      table_position &= table_size - 1;
      
    • Repeat until we have placed as many copies as the symbol’s weight
  • For each symbol, we then need to determine the baseline and nb_bits. But let’s first try getting the symbols right.
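
To make these three fields concrete, here is roughly what one decoding step will look like once we have the table. The read_bits helper is a placeholder for whatever bitstream reader we end up writing; this is a sketch, not actual zrs code.

# One FSE decoding step, assuming table[state] is a (symbol, nb_bits, baseline)
# tuple like the ones we build below, and read_bits(n) returns the next n bits
# of the stream as an integer (a placeholder helper).
def fse_decode_step(table, state, read_bits):
    symbol, nb_bits, baseline = table[state]
    offset = read_bits(nb_bits)        # the bits we just consumed, as an integer
    return symbol, baseline + offset   # emit the symbol, jump to the next state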

Let’s start with the smallest one: offset codes. It has an accuracy of 5 bits, i.e., 32 states. And let’s prototype in Python first, to make sure we get things right:

# File: zrs/fse.py

DEFAULT_OFFSET_WEIGHT = [
    1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
]

# Allocate table
accuracy = 5
table_size = 1 << accuracy
table = [0] * table_size

# Deal with -1 symbols
last_pos = -1
for (symbol, weight) in enumerate(DEFAULT_OFFSET_WEIGHT):
    if weight == -1:
        table[last_pos] = symbol
        last_pos -= 1

# Deal with the rest
table_position = 0
for (symbol, weight) in enumerate(DEFAULT_OFFSET_WEIGHT):
    if weight < 1:
        continue

    # Attribute as many cells as weight to this symbol
    repetitions = 0
    while repetitions < weight:
        if table[table_position] == 0:
            table[table_position] = symbol
            repetitions += 1

        # Update position
        table_position += (table_size >> 1) + (table_size >> 3) + 3
        table_position &= table_size - 1

print(table)
# [0, 6, 9, 15, 21, 3, 7, 12, 18, 23, 5, 8, 14, 20, 2, 7, 11, 17, 22, 4, 8, 13, 19, 1, 6, 10, 16, 28, 27, 26, 25, 24]

This is exactly the expected table! Testing this with the two other tables works like a charm too (isn’t it nice when everything works first time?), provided we care to set their accuracy to the correct value of 6.

Then we need to compute the baseline and nb_bits. Let’s start with nb_bits: it indicates how many bits we must read at this state, i.e. enough bits to represent an integer of size table_size / weight. Rounding errors are dealt with by adding one bit to the first few states (in the order they appear in the table). This requires changing our logic a little bit, but nothing too bad:

# File: zrs/fse.py

# ...

# Deal with -1 symbols
last_pos = -1
for (symbol, weight) in enumerate(DEFAULT_OFFSET_WEIGHT):
    if weight == -1:
        #                          New!
        table[last_pos] = (symbol, accuracy)
        last_pos -= 1

# Deal with the rest
table_position = 0
for (symbol, weight) in enumerate(DEFAULT_OFFSET_WEIGHT):
    if weight < 1:
        continue

    # Attribute as many cells as weight to this symbol
    repetitions = 0

    # We now collect positions
    positions = []
    while repetitions < weight:
        if table[table_position] == 0:
            positions.append(table_position)
            repetitions += 1

        # Update position
        table_position += (table_size >> 1) + (table_size >> 3) + 3
        table_position &= table_size - 1

    # Compute nb_bits
    sorted_positions = sorted(positions)

    # Compute how many bits we need
    weight_log = (weight-1).bit_length()
    next_pow2 = 1 << weight_log
    count_double = next_pow2 - weight
    base_bits = accuracy - weight_log

    # Set nb_bits, adding 1 bit to account for rounding error
    for (i, position) in enumerate(sorted_positions):
        nb_bits = base_bits
        if i < count_double:
            nb_bits += 1
        table[position] = (symbol, nb_bits)


# Pretty print
print("state\tsym\tnb_bits")
for (i, u) in enumerate(table):
    print("%s\t%s" % (i, '\t'.join(map(lambda x: "%s" % x, u))))

Aaaaaand… success! I won’t lie, it will take a moment to read these innocent “compute how many bits we need” lines, but they do the work.

Finally, baseline. The standard is very obscure as to how exactly this is attributed, so here’s my understanding (which may not reflect what Zstandard really does?):

  • Start at the state in position sorted_positions[count_double], i.e. the first state (in sorted order) that did not get an extra bit;
  • Give it the current baseline (which starts at 0), then increase the running baseline by 1 << nb_bits for that state;
  • Move on to the next state for this symbol (if necessary, we wrap back to the first state in the list and continue);
  • Give it the current baseline, increase the running baseline again;
  • etc., until all of this symbol’s states have received their baseline.

Also, symbols with -1 weight have a baseline always equal to 0.
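
As a sanity check of my reading of these rules, here is the arithmetic worked out for a weight-3 symbol in a 64-state table (accuracy 6). The numbers below follow from the rules above, not from the reference tables.

# Worked example: weight 3, accuracy 6 (a 64-state table).
accuracy, weight = 6, 3
weight_log = (weight - 1).bit_length()      # 2
count_double = (1 << weight_log) - weight   # 1: one state gets an extra bit
base_bits = accuracy - weight_log           # 4

# In sorted-position order, nb_bits = [5, 4, 4].
# Baseline assignment starts at index count_double (= 1) and wraps around:
#   index 1 -> baseline 0   (4 bits, covers states  0..15)
#   index 2 -> baseline 16  (4 bits, covers states 16..31)
#   index 0 -> baseline 32  (5 bits, covers states 32..63)
# Together the three cells cover all 64 states exactly once.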

# File: zrs/fse.py

# ...

# Deal with -1 symbols
last_pos = -1
for (symbol, weight) in enumerate(DEFAULT_OFFSET_WEIGHT):
    if weight == -1:
        #                                    New!
        table[last_pos] = (symbol, accuracy, 0)
        last_pos -= 1

# Deal with the rest
table_position = 0
for (symbol, weight) in enumerate(DEFAULT_OFFSET_WEIGHT):
    # ...

    # New!
    # Compute baseline
    baseline = 0
    position = count_double
    for k in range(weight):
        actual_position = (position + k) % weight
        (symbol, nb_bits) = table[sorted_positions[actual_position]]
        table[sorted_positions[actual_position]] = (symbol, nb_bits, baseline)
        baseline += 1 << nb_bits

# Pretty print
#                           New!
print("state\tsym\tnb_bits\tbaseline")
for (i, u) in enumerate(table):
    print("%s\t%s" % (i, '\t'.join(map(lambda x: "%s" % x, u))))

Wow isn’t this getting messy! But hey, when we run this, it matches exactly the reference tables. Not just this one, but all of them.

The above Python script is far from optimal, it’s just there to check our understanding of how this table is built. It is going to be a pain to clean it up and write the code in Rust instead. We could just take that Python and have it output the complete predefined decoding table in Rust format, as long-ass arrays that we just include! and… We wouldn’t. Would we?

Wouldn’t you like to know

There’s a natural question that we need to ask ourselves regarding the way we built our decoding table: what’s the deal with keeping similar symbols far away (the table_position computation)?

It’s very natural: when symbols are far from one another, we need more bits to describe the offset. We want to use, on average, the minimal possible number of bits for our state transitions. If we packed states closely together, then transitions between these states would be fine, they would have small offsets; but transitions to other states would have to make long jumps, and that means large offsets.

The clever trick with spreading states all over is that they are neither close nor far, which means that we don’t need so many bits to encode the offsets, which in the end means better compression.

Next time we’ll turn this monstrosity of a table-building algorithm into Rust code, which will allow us to start decompressing data. Isn’t that nice? Go to Part V!

Did you like this? Do you want to support this kind of stuff? Please share around. You can also buy me a coffee.