Turning a byte pair encoded string back to its surface form
The huggingface tokenizers library can be used to train many varieties of sub-word tokenizers. In short, a sub-word tokenizer is a tokenizer that learns to split strings into frequently occurring pieces. Ideally, a sub-word tokenizer is exhaustive, which means that it can split any string, even one containing sequences it has never seen before, into sub-word tokens. A truly exhaustive sub-word tokenizer is useful because it will never encounter an <UNK>
symbol, i.e., a thing it doesn’t know what to do with. Reaching this state is difficult when tokenizing on the character level, however, as there are tens of thousands of unique Unicode characters, and it is undesirable to give all of them separate tokens in the vocabulary.
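For a concrete picture, here is a minimal sketch of training such a tokenizer with the tokenizers Python API; corpus.txt is just a placeholder for your own training data:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

# Train a small BPE tokenizer on a local text file ("corpus.txt" is a placeholder path).
tokenizer = Tokenizer(BPE(unk_token="<UNK>"))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(vocab_size=5_000, special_tokens=["<UNK>"])
tokenizer.train(["corpus.txt"], trainer)
With a character-level base vocabulary like this, any character not seen during training ends up as <UNK>, which is exactly the situation byte-level pre-tokenization avoids.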
Byte-level pre-tokenization
One way to overcome this issue is by using byte-level pre-tokenization, which decomposes each Unicode character into a sequence of one to four bytes. As there are only 256 unique bytes, the base vocabulary of the sub-word tokenizer becomes a lot smaller. At the same time, you are guaranteed to never encounter an unknown token, as every character has a unique byte decomposition. The downside, of course, is that your sequences also become a lot longer, since many characters decompose into multiple bytes.
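For example, an ASCII character stays a single byte under UTF-8, while most Japanese characters decompose into three:
# UTF-8 byte decompositions: one byte for ASCII, three for most Japanese characters.
print(list("a".encode("utf-8")))   # [97]
print(list("お".encode("utf-8")))  # [227, 129, 138]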
Implementation and problem statement
Pre-tokenization on the byte level often leads to incomprehensible tokens. For example, the string お問い合わせください (which means “Please contact us”, according to Google Translate) is converted to the unreadable ãģĬåķıãģĦåIJĪãĤıãģĽãģıãģłãģķãģĦ. One naive solution would be to map all these characters back to their bytes, as follows:
string = "ãģĬåķıãģĦåIJĪãĤıãģĽãģıãģłãģķãģĦ"
bytes([ord(char) for char in string]).decode("utf-8")  # ValueError: bytes must be in range(0, 256)
But this raises a ValueError, because not all of these characters actually correspond to single bytes: some of them have ordinals above 255.
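A quick check makes this concrete:
string = "ãģĬåķıãģĦåIJĪãĤıãģĽãģıãģłãģķãģĦ"
# Characters such as "ģ" (U+0123) and "Ĭ" (U+012A) have ordinals above 255,
# so bytes([...]) refuses them with a ValueError.
print({char: ord(char) for char in string if ord(char) > 0xFF})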
Solution and explanation
This happens because the byte-level pre-tokenizer in tokenizers remaps some bytes to characters with ordinals of 256 and above. This is likely done because the raw byte range contains spaces, tabs, and various control characters, which would make for awkward, unprintable token strings.
We therefore first need to map the characters back to their correct byte values, which you can do with the following function:
from itertools import chain
from typing import Dict


def create_char_map() -> Dict[str, int]:
    """Map the pre-tokenizer's characters back to their original byte values."""
    # Ranges of byte values that are printable and therefore kept as-is.
    pairs = [("!", "~"), ("\xA1", "\xAC"), ("\xAE", "\xFF")]
    regular = set(chain(*[range(ord(s), ord(e) + 1) for s, e in pairs]))
    mapping = {}
    i = 0
    for ordinal in range(256):
        if ordinal in regular:
            # Printable bytes are mapped to themselves.
            mapping[chr(ordinal)] = ordinal
        else:
            # Non-printable bytes were shifted to code points >= 256 by the
            # pre-tokenizer; map them back to the original byte value.
            mapping[chr(256 + i)] = ordinal
            i += 1
    return mapping
This is based on the corresponding function in the Rust source of tokenizers (see here). In short, the function defines the ranges of byte values that are printable and map to themselves; every other byte value was shifted by the pre-tokenizer to a character outside the byte range, and gets mapped back to its original value. Remapping the characters of a token with this map before converting them to bytes gives the desired result:
mapping = create_char_map()

string = "ãģĬåķıãģĦåIJĪãĤıãģĽãģıãģłãģķãģĦ"
# Map each character back to its byte value, then decode the bytes as UTF-8.
remapped = bytes([mapping[char] for char in string])
decoded_string = remapped.decode("utf-8")

expected = "お問い合わせください"
assert expected == decoded_string
As you can see, the assert passes, and we have correctly decoded the entire string. Nice!
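As a final sanity check, we can also inspect a few individual entries of the map; the values below follow directly from the ranges defined in create_char_map. Notably, the Ġ character that shows up in so many byte-level BPE vocabularies is nothing more than a remapped space:
mapping = create_char_map()
assert mapping["ã"] == 0xE3      # printable bytes map to themselves
assert mapping["ģ"] == 0x81      # "ģ" (U+0123) stands in for the non-printable byte 0x81
assert mapping["Ġ"] == ord(" ")  # the familiar "Ġ" is simply a remapped space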