Tips for Fast Tokenizing
Note: This is a draft; I'm not done with this post yet.
I've been working on a lexer for my programming language. For reference, the excellent Rust crate Logos generates a lexer that works at ~536.8 MB/s on my computer. My lexer works at ~1359.1 MB/s.
I want to share all the tricks I use to make lexing fast.
Avoid Indirect Jumps
My initial version of the code had a big match statement. It looked something like:
let mut position = 0;
while not_done() {
let b: u8 = bytes[position];
match b {
b' ' | b'\t' | b'\r' | b'\n' => { /* ... */ }
b'a'..=b'z' | b'A'..=b'Z' | b'_' => { /* ... */ }
b'0'..=b'9' => { /* ... */ }
b'(' => { /* ... */ }
b')' => { /* ... */ }
b'[' => { /* ... */ }
b']' => { /* ... */ }
b'{' => { /* ... */ }
b'}' => { /* ... */ }
b';' => { /* ... */ }
b',' => { /* ... */ }
b'+' => { /* ... */ }
b'-' => { /* ... */ }
b'%' => { /* ... */ }
b'/' => { /* ... */ }
/* ... */
unknown => { /* emit error */ }
}
}
Rust compiles match statements like these into a jump table, which is a table where each entry is the address of the corresponding match arm's code. The CPU reads the table and jumps where it says to go.
Your CPU hates this kind of jump table. This type of jump is a branch, because it can change the control flow depending on data, and it is an indirect branch, because the target of the branch is dynamic.
Computers want to read ahead in the machine code and start working on future instructions early (so-called out-of-order execution). In order to read ahead, the CPU has to know where the branch jumps to before the destination is actually decided, and it does this by guessing. Once execution catches up, it can validate its guess and continue or realize the mistake and reset. This process is called "branch prediction".
As you can imagine, a correctly predicted branch is essentially free, but an incorrectly predicted branch requires a handful of cycles to reset, and prevents the CPU from reading very far ahead. Unpredictable branches are the main bottleneck in my lexer. Any control flow that is affected by the specific source string is inherently unpredictable -- the branch predictor has to guess what's in the string without ever looking at the string's contents.
Jump threading
One way you can mitigate branch prediction problems is with a trick called jump threading. The idea of jump threading is to replace the single shared dispatch code with multiple copies, one at the end of each handler. Each copy is a separate branch with its own prediction history, which exposes more information to your CPU's branch predictor and makes it better at guessing. (TODO explain jump threading better)
Jump threading is known to be very effective, but I don't explore it here because Rust is still working out how to express the idea to the compiler. There's an experimental feature flag that jankily solves the problem, and another experimental feature that's more general and gives you the tools to solve it. In this post, I'll stick to things that are enabled by default in nightly Rust, but if you're in a language other than Rust, or more willing to play with experimental features than I am, I strongly recommend checking out jump threading.
Use if statements instead of switch/match for dispatch
You can remove the indirect branch, and instead use a sequence of direct branches (if statements):
let mut position = 0;
while not_done() {
let b: u8 = bytes[position];
if matches!(b, b' ' | b'\t' | b'\r' | b'\n') { /* ... */ }
else if matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'_') { /* ... */ }
else if matches!(b, b'0'..=b'9') { /* ... */ }
else if b == b'(' { /* ... */ }
else if b == b')' { /* ... */ }
else if b == b'[' { /* ... */ }
else if b == b']' { /* ... */ }
else if b == b'{' { /* ... */ }
else if b == b'}' { /* ... */ }
else if b == b';' { /* ... */ }
else if b == b',' { /* ... */ }
else if b == b'+' { /* ... */ }
else if b == b'-' { /* ... */ }
else if b == b'%' { /* ... */ }
else if b == b'/' { /* ... */ }
/* ... */
else { /* emit error */ }
}
Sort by frequency
This is the part I'm least sure about, because the theory doesn't make sense to me. Experimentally, though, the best way I've found to lay out the branches is one single big if-else chain, as shown above, with the checks sorted from most frequent to least frequent. I looked into shapes other than a single if-else chain. For example, I tried building Huffman decision trees based on the probability of each type of token, but this made everything worse.
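One way to choose the order is to measure it on a representative sample of source code. The following is a minimal sketch, not code from my lexer: the `ByteClass` enum and the `classify` and `rank_token_starts` helpers are hypothetical names, and the rule for where a token starts is only a rough approximation of what a real lexer does.

```rust
/// Hypothetical coarse classes; a real lexer would have one per branch of the if chain.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ByteClass {
    Whitespace,
    Word,
    Digit,
    Punctuation,
    Other,
}

/// Listed in declaration order, so `class as usize` indexes this array.
const CLASSES: [ByteClass; 5] = [
    ByteClass::Whitespace,
    ByteClass::Word,
    ByteClass::Digit,
    ByteClass::Punctuation,
    ByteClass::Other,
];

pub fn classify(b: u8) -> ByteClass {
    if matches!(b, b' ' | b'\t' | b'\r' | b'\n') {
        ByteClass::Whitespace
    } else if matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'_') {
        ByteClass::Word
    } else if b.is_ascii_digit() {
        ByteClass::Digit
    } else if b.is_ascii_punctuation() {
        ByteClass::Punctuation
    } else {
        ByteClass::Other
    }
}

/// Counts how many tokens start with each class, most frequent first.
/// Approximation: a token starts wherever the class changes, and every
/// punctuation byte counts as its own token.
pub fn rank_token_starts(sample: &[u8]) -> Vec<(ByteClass, usize)> {
    let mut counts = [0usize; 5];
    let mut prev: Option<ByteClass> = None;
    for &b in sample {
        let mut class = classify(b);
        // A digit directly after a word byte continues the identifier.
        if class == ByteClass::Digit && prev == Some(ByteClass::Word) {
            class = ByteClass::Word;
        }
        if class == ByteClass::Punctuation || prev != Some(class) {
            counts[class as usize] += 1;
        }
        prev = Some(class);
    }
    let mut ranked: Vec<(ByteClass, usize)> = CLASSES.iter().copied().zip(counts).collect();
    // Stable sort, descending by count.
    ranked.sort_by(|a, b| b.1.cmp(&a.1));
    ranked
}
```

The dispatch branch runs once per token rather than once per byte, which is why the sketch counts token starts instead of raw bytes. Treat the ranking as a starting point and confirm it with a benchmark.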
Use 256-bit masks
These branches (and likely more in your code to skip to the next token) want to tell if your byte is contained in some fixed set. For example, consider this condition:
if matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_') {
// ...
}
This generates branchy code:
mov eax, ebx ; let a: u8 = b
and al, -33 ; a &= ~0x20 // upcase
add al, -65 ; a -= 'A' // convert 'A'->0, 'B'->1, etc
cmp al, 26 ; if a < 26
jb .MATCHED ; { jump to MATCHED }
cmp bl, 48 ; if b < '0'
jb .SKIP ; { jump to SKIP }
cmp bl, 58 ; bool c = b < '9'+1
setb cl ;
cmp bl, 95 ; bool d = b == '_'
sete dl ;
or dl, cl ; if !(c | d)
je .SKIP ; { jump to SKIP }
; fallthrough to MATCHED
Note how there are 3 separate branches generated from this one if statement.
However, there's a way to do this with less branching: You can construct a table of 256 bits -- one bit in the table for each possible input byte -- and then test the appropriate bit.
In Rust, there's no u256, but you can use an array of four u64s.
const TABLE: [u64; 4] = [287948901175001088, 576460745995190270, 0, 0];
if ((TABLE[b as usize / 64] >> (b % 64)) & 1) == 1 {
// ...
}
This version doesn't have extra branches (just the one from the if, of course):
mov eax, ebx ; u8 a = b
shr al, 6 ; a >>= 6 // divide by 64 with bit tricks
lea rcx, [rip + .TABLE] ; c = &TABLE
mov rax, qword ptr [rcx + 8*rax] ; u64 a = c[a]
bt rax, bl ; ((a >> (b % 64)) & 1) == 1 // wow, what a useful instruction!
jae .SKIP ; if not, skip the body
This way, the branch predictor doesn't need to separately predict if the byte is [a-zA-Z] and then predict if it's < '0', and then predict if it is [0-9_] to determine how control flow occurs. Instead, the branch predictor only predicts the final result.
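The two magic numbers are easy to get wrong, so it helps to check them against the readable condition. Here is a minimal sketch that rebuilds the table from the `matches!` pattern and compares it with the constants above; `in_table` and `build_table` are hypothetical helper names used only for this check.

```rust
/// The magic constants from the text.
pub const TABLE: [u64; 4] = [287948901175001088, 576460745995190270, 0, 0];

/// Tests bit `b` of a 256-bit table stored as four u64 words.
pub fn in_table(table: &[u64; 4], b: u8) -> bool {
    ((table[b as usize / 64] >> (b % 64)) & 1) == 1
}

/// Rebuilds the table by setting one bit per byte accepted by the readable pattern.
pub fn build_table() -> [u64; 4] {
    let mut table = [0u64; 4];
    for b in 0..=255u8 {
        if matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_') {
            table[b as usize / 64] |= 1u64 << (b % 64);
        }
    }
    table
}
```

In hexadecimal the first word is 0x03FF_0000_0000_0000 (the ten digit bits, 48 through 57) and the second is 0x07FF_FFFE_87FF_FFFE (the letters and the underscore, offset by 64), which is a little easier to eyeball than the decimal form.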
The magic number constants aren't fun to maintain, so I use Rust's const expression stuff to factor the magic constant table back into readable code. This may need nightly features if you want to do something similar:
struct BitSet {
table: [u64; 4],
}
impl BitSet {
pub const fn new(members: &[u8]) -> BitSet {
let mut table = [0; 4];
let mut i = 0;
while i < members.len() {
let index = members[i] as usize;
table[index / 64] |= 1 << (index % 64);
i += 1;
}
BitSet { table }
}
pub const fn contains(&self, index: u8) -> bool {
((self.table[index as usize / 64] >> (index % 64)) & 1) == 1
}
}
const LETTERS_AND_UNDERSCORE_AND_DIGITS: BitSet =
BitSet::new(b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_");
if LETTERS_AND_UNDERSCORE_AND_DIGITS.contains(b) {
// ...
}
Merge similar branches
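Even with these changes, profiling still shows branches causing problems for me. One fix is to merge branches whose bodies are nearly identical, so that the predictor has fewer branches to guess.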
For my language, the tokens ;, ,, :, ., (, ), [, ], {, and } are great candidates to merge together because they all have essentially the same code.
The word essentially is doing some work -- the code isn't exactly the same, it differs on which token is being emitted (Token::Semicolon, Token::Comma, etc). So, to merge the code, you simply need to add a lookup table to cover the rest:
const TOKENS: [Token; 256] = const {
// fill with dummy tokens
let mut tokens = [Token::Eof; 256];
// store the right token for each byte that matters
tokens[b';' as usize] = Token::Semicolon;
tokens[b',' as usize] = Token::Comma;
tokens[b':' as usize] = Token::Colon;
tokens[b'.' as usize] = Token::Dot;
...
tokens
};
let mut position = 0;
while not_done() {
let b: u8 = bytes[position];
if const { BitSet::new(b";,:.()[]{}") }.contains(b) {
let tk = TOKENS[b as usize];
// ...
}
else if LETTERS_AND_UNDERSCORE.contains(b) { /* ... */ }
else if WHITESPACE.contains(b) { /* ... */ }
...
}
It turns out these tokens, once combined, are so common that they go first on my if chain.
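For reference, here is a self-contained sketch of the merged branch that compiles on its own. It assumes a reduced `Token` enum with only six punctuation tokens, and `punctuation_tokens`, `punctuation_mask`, and `lex_punctuation` are hypothetical names for illustration. Every non-punctuation byte is simply skipped.

```rust
/// Reduced, hypothetical token set for this sketch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Token {
    Eof,
    Semicolon,
    Comma,
    Colon,
    Dot,
    OpenParen,
    CloseParen,
}

/// Byte-indexed lookup table; bytes that are not punctuation keep the dummy token.
const fn punctuation_tokens() -> [Token; 256] {
    let mut tokens = [Token::Eof; 256];
    tokens[b';' as usize] = Token::Semicolon;
    tokens[b',' as usize] = Token::Comma;
    tokens[b':' as usize] = Token::Colon;
    tokens[b'.' as usize] = Token::Dot;
    tokens[b'(' as usize] = Token::OpenParen;
    tokens[b')' as usize] = Token::CloseParen;
    tokens
}

/// 256-bit membership mask for the same six bytes.
const fn punctuation_mask() -> [u64; 4] {
    let members: &[u8] = b";,:.()";
    let mut table = [0u64; 4];
    let mut i = 0;
    while i < members.len() {
        let index = members[i] as usize;
        table[index / 64] |= 1u64 << (index % 64);
        i += 1;
    }
    table
}

const TOKENS: [Token; 256] = punctuation_tokens();
const PUNCTUATION: [u64; 4] = punctuation_mask();

/// One branch decides "is this punctuation?"; the table lookup picks the token.
pub fn lex_punctuation(bytes: &[u8]) -> Vec<Token> {
    let mut out = Vec::new();
    for &b in bytes {
        if (PUNCTUATION[b as usize / 64] >> (b % 64)) & 1 == 1 {
            out.push(TOKENS[b as usize]);
        }
        // Every other byte is skipped in this sketch.
    }
    out
}
```

The point of the shape is that the choice between `;` and `,` becomes a data dependency (a table load) instead of a control dependency (a branch), so the predictor is never asked to guess it.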
Use transmutation to map bytes
Sometimes, you can change your Token definition so that getting the tk from b is trivial.
#[derive(Default, Debug, Eq, PartialEq, Hash, Clone, Copy)]
#[repr(u8)]
pub enum Token {
...
...
Semicolon = b';',
Comma = b',',
Dot = b'.',
Colon = b':',
OpenParen = b'(',
CloseParen = b')',
...
}
Now that Token's been modified, you can cast the byte directly into a Token:
if const { BitSet::new(b";,:.()[]{}") }.contains(b) {
let tk: Token = unsafe { std::mem::transmute(b) };
// ...
}
I don't like to use unsafe in my code, but sometimes, you don't even need to call the mem::transmute!
In a release build, LLVM can sense that you're close to optimal code and will optimize on its own, without you explicitly transmuting:
if const { BitSet::new(b";,:.()[]{}") }.contains(b) {
let tk = match b {
b';' => Token::Semicolon,
b',' => Token::Comma,
b':' => Token::Colon,
b'.' => Token::Dot,
...
_ => unreachable!(),
};
// ...
}
LLVM gets really close to perfect codegen with the above code, but doesn't quite take it home. LLVM sees how to fold the match arms, but it can't statically confirm that the unreachable arm is unreachable. It essentially simplifies the above code to let tk = if matches!(b, b';' | b',' | ...) { std::mem::transmute(b) } else { unreachable!() };. To eliminate that final if statement, you have to either put every possible byte into the match (i.e., map b'\0' to whatever your 0 token is, b'\x01' to your 1 token, etc., so that after LLVM makes its transformation, the matches! check is trivial), or use unsafe { std::hint::unreachable_unchecked() } to make it go away entirely.
I don't like this approach. I refuse to use unsafe in my code, the workaround feels very messy, and the result is sensitive to the exact definition of Token. But you may want to consider a strategy like this, depending on your situation.
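If you do go this way but want to stay in safe Rust, one option is to keep the `#[repr(u8)]` layout, convert through a `match` that returns `Option`, and pin the layout assumption down with a test. The sketch below uses a reduced, hypothetical `Token` enum; `from_byte` and `PUNCTUATION` are names invented for illustration. Whether LLVM folds this `match` as well as the transmute is something to confirm in your own release build.

```rust
/// Reduced, hypothetical token set whose discriminants equal the source bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Token {
    Semicolon = b';',
    Comma = b',',
    Dot = b'.',
    Colon = b':',
    OpenParen = b'(',
    CloseParen = b')',
}

impl Token {
    /// Every variant, so a test can walk the whole enum.
    pub const PUNCTUATION: [Token; 6] = [
        Token::Semicolon,
        Token::Comma,
        Token::Dot,
        Token::Colon,
        Token::OpenParen,
        Token::CloseParen,
    ];

    /// Safe byte-to-token conversion: unknown bytes give `None` instead of
    /// undefined behavior.
    pub fn from_byte(b: u8) -> Option<Token> {
        match b {
            b';' => Some(Token::Semicolon),
            b',' => Some(Token::Comma),
            b'.' => Some(Token::Dot),
            b':' => Some(Token::Colon),
            b'(' => Some(Token::OpenParen),
            b')' => Some(Token::CloseParen),
            _ => None,
        }
    }
}
```

A round-trip check such as `Token::from_byte(tk as u8) == Some(tk)` for every variant catches the case where someone later changes a discriminant without updating the match.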
Use Perfect Hash Tables and Keywords
One tricky problem is distinguishing between identifiers and keywords. Matching each keyword requires a lot of string comparison, which is slow, and forces the branch predictor to guess, character by character, how close your identifier is to each keyword.
let mut position = 0;
while not_done() {
let b: u8 = bytes[position];
// ...
if LETTERS_AND_UNDERSCORE.contains(b) {
let start_position = position;
position += 1;
while LETTERS_AND_UNDERSCORE_AND_DIGITS.contains(bytes[position]) {
position += 1;
}
let word = &bytes[start_position..position];
let tk = match word {
b"break" => Token::Break,
b"class" => Token::Class,
b"const" => Token::Const,
b"continue" => Token::Continue,
b"defer" => Token::Defer,
...
_ => Token::Identifier,
};
your_code_to_emit_a_token(tk, start_position..position);
}
This match statement is not good. Instead of string comparison, which goes character by character, you want to process multiple bytes at once. Specifically, a u64 is 8 bytes, and can be processed all at once!
My specific set of keywords only has words that are 8 bytes or fewer (no interface, constexpr, etc.), so u64 is good enough for me, but you may need to use u128.
Getting a word into a u64 cleanly is tricky:
// step 0, scan the word normally
let start_position = position;
position += 1;
while LETTERS_AND_UNDERSCORE_AND_DIGITS.contains(bytes[position]) {
position += 1;
}
let len = position - start_position;
// preconditions to consider
if len > 8 { panic!("there was an identifier that can't fit into a u64"); }
if start_position + 8 > bytes.len() { panic!("read off the end"); }
// step 1: take an 8-byte slice starting at the word
let word_slice = &bytes[start_position..start_position+8];
// step 2: convert to array
let word_array = word_slice.try_into().unwrap();
// step 3: array to u64, using u64::from_le_bytes
let word_u64 = u64::from_le_bytes(word_array);
// step 4: remove the extra bytes
// a u64 can only hold 64 bits, and left and right shift fill with 0's
// so you shake the number back and forth, to keep only `len * 8` bits left
let word = (word_u64 << (64 - len * 8)) >> (64 - len * 8);
That's a lot of steps, and some tight preconditions. The first condition is easy: If a word is more than 8 bytes long, it's not a keyword. To defuse the second condition, you have to ensure that you won't ever have an identifier within 8 bytes of the end of the string. This can be achieved via padding:
bytes.extend_from_slice(b"\0\0\0\0\0\0\0\0");
{
// your entire lexer code
}
bytes.truncate(bytes.len() - 8);
Now that the edge cases are handled, these steps successfully convert the word into a u64. Now, you just need to check if the word is a keyword. But how do you do this quickly? A HashMap of words?
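Before getting to the lookup, here is the load-and-mask step and the padding folded into two small helpers, as a minimal sketch. `load_word` and `pad` are hypothetical names, and unlike the snippets above this version returns `None` instead of panicking when a precondition fails.

```rust
/// Loads `len` bytes starting at `start` into the low bytes of a u64
/// (little-endian) and zeroes the rest.
/// Returns `None` if `len` is 0 or more than 8, or if fewer than 8 bytes are
/// readable at `start` (which padding the input is meant to prevent).
pub fn load_word(bytes: &[u8], start: usize, len: usize) -> Option<u64> {
    if len == 0 || len > 8 {
        return None;
    }
    // Always read a full 8-byte window; the tail may belong to later tokens.
    let window: [u8; 8] = bytes.get(start..start + 8)?.try_into().ok()?;
    let raw = u64::from_le_bytes(window);
    // Shift left then right to clear everything above the low `len * 8` bits.
    let shift = 64 - len * 8;
    Some((raw << shift) >> shift)
}

/// Copies the source and appends 8 zero bytes so that a full window is always
/// readable, even for a word at the very end of the input.
pub fn pad(source: &[u8]) -> Vec<u8> {
    let mut bytes = source.to_vec();
    bytes.extend_from_slice(&[0; 8]);
    bytes
}
```

With the padded input `x = if`, `load_word(&bytes, 4, 2)` yields the same value as `u64::from_le_bytes(*b"if\0\0\0\0\0\0")`, which is the form the keyword table stores.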
You can do better. Since you know the specific keyword list, you can build your own specialized hash map with a "perfect" hash function.
A "perfect" hash is one that has no collisions for the set of values you care about. In this case, the task is to invent a hash function for your specific set of keywords.
My hash function is a wrapping multiplication by a tuneable parameter, followed by a right shift to make the number smaller than the number of buckets in my keyword map.
const fn hash<const N: usize>(parameter: u64, word: u64) -> usize {
(word.wrapping_mul(parameter) >> (64 - N.checked_ilog2().unwrap())) as usize
}
But how is this perfect? Well, it depends on which parameter you use. For example, with my specific set of keywords, I can use the parameter 14645584635661411151 and N=32. Note how there are no collisions between the keywords -- each word gets its own index.
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"break\0\0\0")), 13);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"class\0\0\0")), 1);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"const\0\0\0")), 19);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"continue")), 5);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"defer\0\0\0")), 23);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"else\0\0\0\0")), 3);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"enum\0\0\0\0")), 25);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"fn\0\0\0\0\0\0")), 9);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"for\0\0\0\0\0")), 12);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"if\0\0\0\0\0\0")), 22);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"inline\0\0")), 14);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"let\0\0\0\0\0")), 26);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"loop\0\0\0\0")), 2);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"match\0\0\0")), 15);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"mut\0\0\0\0\0")), 18);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"return\0\0")), 24);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"struct\0\0")), 4);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"while\0\0\0")), 11);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"void\0\0\0\0")), 0);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"true\0\0\0\0")), 20);
assert_eq!(hash::<32>(14645584635661411151, u64::from_le_bytes(*b"false\0\0\0")), 7);
Then, to do a full lookup, you just need to hash your word and check if it matches the word in the corresponding bucket.
pub const fn lookup_word(&self, word: u64) -> T {
let (candidate, value) = self.table[hash::<N>(self.parameter, word)];
if word == candidate {
value
} else {
self.default
}
}
You can see this is a method, which goes on a struct that represents a perfect hash map:
struct KeywordHashMap<T: Copy, const N: usize> {
table: [(u64, T); N],
parameter: u64,
default: T,
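}
Putting it all together, the identifier branch of the lexer scans the word, packs it into a u64, and looks it up in a KeywordHashMap built at compile time: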
let start_position = position;
position += 1;
while LETTERS_AND_UNDERSCORE_AND_DIGITS.contains(bytes[position]) {
position += 1;
}
let len = position - start_position;
if start_position + 8 > bytes.len() { panic!("read off the end"); }
let tk = if len > 8 {
Token::Identifier
} else {
let word = (u64::from_le_bytes(bytes[start_position..start_position+8].try_into().unwrap()) << (64 - len * 8)) >> (64 - len * 8);
const {
KeywordHashMap::<Token, 32>::new(
&[
("break", Token::Break),
("class", Token::Class),
("const", Token::Const),
("continue", Token::Continue),
("defer", Token::Defer),
("else", Token::Else),
("enum", Token::Enum),
("fn", Token::Fn),
("for", Token::For),
("if", Token::If),
("inline", Token::Inline),
("let", Token::Let),
("loop", Token::Loop),
("match", Token::Match),
("mut", Token::Mut),
("return", Token::Return),
("struct", Token::Struct),
("while", Token::While),
("void", Token::Void),
("true", Token::True),
("false", Token::False),
],
Token::Identifier,
)
}.lookup_word(word)
};
The only remaining question is the implementation of KeywordHashMap::new, which has to choose a parameter to make the hash "perfect".
impl<T: Copy, const N: usize> KeywordHashMap<T, N> {
pub const fn lookup_word(&self, word: u64) -> T {
let (candidate, value) = self.table[hash::<N>(self.parameter, word)];
if word == candidate {
value
} else {
self.default
}
}
/// Constructing a KeywordHashMap at compile time by trying random parameters until one works
/// and by random I mean pseudorandom, generated with LCG
pub const fn new(keyword_list: &[(&str, T)], default: T) -> KeywordHashMap<T, N> {
let mut parameter = 1_u64;
loop {
if let Some(table) = KeywordHashMap::try_new(keyword_list, parameter, default) {
return table;
}
parameter = parameter.wrapping_mul(16364136223846793005).wrapping_add(1);
}
}
/// A parameter makes the hash "perfect" if it chooses a different bucket for each input
const fn try_new(keyword_list: &[(&str, T)], parameter: u64, default: T) -> Option<KeywordHashMap<T, N>> {
let mut i = 0;
let mut table = [(0, default); N];
let mut seen = [false; N];
while i < keyword_list.len() {
let (keyword, token) = keyword_list[i];
let bucket = hash::<N>(parameter, str_to_u64(keyword));
if seen[bucket] {
return None;
}
table[bucket] = (str_to_u64(keyword), token);
seen[bucket] = true;
i += 1;
}
Some(KeywordHashMap {
table,
parameter,
default,
})
}
}
/// I can't use `try_into()` in const right now, so I use this const-compatible slow way to convert to u64
const fn str_to_u64(f: &str) -> u64 {
let mut buf = [0; 8];
let mut i = 0;
assert!(f.len() <= 8, "string is too long to fit in register");
while i < f.len() {
buf[i] = f.as_bytes()[i];
i += 1;
}
u64::from_le_bytes(buf)
}
TODO, conclusion to perfect hash maps
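A quick way to sanity-check the idea outside of const evaluation is to run the same parameter search at runtime. This is a minimal sketch under several assumptions: the bucket count is fixed at 16 rather than generic, the search is capped at `max_tries` so it cannot loop forever, and the LCG constants are the well-known ones from Knuth's MMIX rather than the multiplier used above. The names `is_perfect`, `find_parameter`, and `word_to_u64` are hypothetical.

```rust
/// Multiply-and-shift hash as in the text, fixed to 16 buckets (top 4 bits).
fn hash(parameter: u64, word: u64) -> usize {
    (word.wrapping_mul(parameter) >> 60) as usize
}

/// Packs a keyword of at most 8 bytes into a u64, little-endian, zero-padded.
/// Panics if the keyword is longer than 8 bytes.
fn word_to_u64(word: &str) -> u64 {
    let mut buf = [0u8; 8];
    buf[..word.len()].copy_from_slice(word.as_bytes());
    u64::from_le_bytes(buf)
}

/// Returns true if `parameter` sends every keyword to its own bucket.
pub fn is_perfect(parameter: u64, keywords: &[&str]) -> bool {
    let mut seen = [false; 16];
    for keyword in keywords {
        let bucket = hash(parameter, word_to_u64(keyword));
        if seen[bucket] {
            return false;
        }
        seen[bucket] = true;
    }
    true
}

/// Tries pseudorandom parameters until one is collision-free, giving up
/// after `max_tries` candidates.
pub fn find_parameter(keywords: &[&str], max_tries: u32) -> Option<u64> {
    let mut parameter = 1u64;
    for _ in 0..max_tries {
        if is_perfect(parameter, keywords) {
            return Some(parameter);
        }
        // LCG step with the MMIX constants (an assumption of this sketch).
        parameter = parameter
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
    }
    None
}
```

The starting parameter 1 always fails for short keywords: multiplying by 1 leaves the top bits zero, so every word shorter than 8 bytes lands in bucket 0. That makes it a handy negative test case.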
Defeat bound checks
Of course, it's possible to remove all the bounds checks using unsafe Rust, but I'm personally sticking to safe Rust. This means I need to give LLVM as much context as possible, so that it can decide on its own to remove bounds checks that it can see are unreachable.
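As a general illustration of the idea (not the specific tricks from my lexer, which are still to be written up), one common safe-Rust pattern is to re-slice once and then iterate, so that the only bounds check is the one on the re-slice. The helper names `is_word_byte` and `scan_word` are hypothetical, and whether LLVM actually drops a given check is something to confirm in the generated assembly.

```rust
fn is_word_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_'
}

/// Returns the position just past the word that starts at `start`.
/// Panics if `start` is greater than `bytes.len()`.
pub fn scan_word(bytes: &[u8], start: usize) -> usize {
    // The re-slice is checked once; the slice iterator below does not index,
    // so it has no per-byte bounds check.
    let rest = &bytes[start..];
    let len = rest
        .iter()
        .position(|&b| !is_word_byte(b))
        .unwrap_or(rest.len());
    start + len
}
```

Compare this with a `while is_word_byte(bytes[position])` loop, where each iteration indexes the slice and the compiler has to prove on its own that `position` stays in range.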
TODO show the bound check for reading the input
Defeat bound checks, part 2
TODO show the trick to remove the Vec resizing from the hot loop
Scanning with SIMD
TODO explain why using SIMD for parsing is really really frustrating.
TODO find a spot to introduce &[u8] vs &str
TODO show final result