nanoGPT + Rust :: Part 1

Mixing Python Notebooks + Rust

I recently followed Andrej Karpathy’s video "Let’s build GPT: from scratch, in code, spelled out", implementing a simplified GPT in Python. The video is a fantastic learning resource, and I recommend everyone go through it themselves to understand how transformers work.

Out of curiosity, I set up a hybrid Rust + PyTorch development environment and started re-implementing nanoGPT in Rust using the PyTorch Rust bindings.

This post is the first of a series that will walk through this process. In this post, we’ll cover integrating Rust into Python and Notebooks. We’ll use Rust to optimize a (simple) Python text pre-processing function.

Starting point: nanogpt.py

First, let’s get our initial Python environment set up. Starting from a fresh conda environment:

conda create --name nanogpt-rs
conda activate nanogpt-rs
conda install ipykernel pytorch matplotlib autopep8 maturin

We’ll then generate our mixed Python/Rust project using Maturin:

maturin new --mixed --bindings pyo3 nanogpt-rs

This will give us a starter project complete with a Rust library, a Python module wrapping that library, and the configuration to build and publish this package. From here, I copied my Python implementation and a simple Python Notebook to this folder. Feel free to clone the project starting at this point:

git clone --branch starter https://github.com/kcking/nanogpt-rs

Note that this implementation only contains features from the YouTube tutorial. Quite a few improvements have been added upstream since.

Make sure to select your new environment in VSCode for running Python Notebook cells. The first cell in nanogpt.ipynb should plot the count of each character in the TinyShakespeare dataset. Even when writing a model in Rust, I find it useful to have a Python Notebook open for experimentation and data analysis.

Rusty Data Processing

Processing data is one of the biggest wins for using Rust in an LLM environment. PyTorch Tensor operations are already optimized, backed by C++, and able to run on GPUs. But data processing code is generally bespoke Python, which can be slow and error-prone.

Let’s set up a Python module backed by Rust that we can use to pre-process our TinyShakespeare dataset.

Using Maturin

The maturin starter project is immediately ready to be called from Python. Whenever we make a change to the Rust code, we just run maturin develop to build and install the new module in the current environment.

The nanogpt_rs module can now be imported into .py/.ipynb files:

import nanogpt_rs
# included in the default maturin template
print(nanogpt_rs.sum_as_string(1, 2))

NOTE: You will need to restart your ipynb kernel in VSCode to pick up any changes to local Rust modules.

We can also automatically re-run maturin develop using cargo-watch:

cargo install cargo-watch
cargo watch -s 'maturin develop' -w src

Preprocessing TinyShakespeare

The first step in training our LLM is defining the input space, or vocabulary, of the data. Since this is a character-based LLM, our vocabulary is just every unique character of the entire dataset. The implementation is almost as simple as Python’s set(text).

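use std::collections::HashSet;

// Register this in the #[pymodule] function, next to the template's sum_as_string.
#[pyfunction]
fn unique_characters(s: &str) -> HashSet<char> {
    HashSet::from_iter(s.chars())
}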
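To sanity-check the logic without going through Python at all, the same function can be written without the PyO3 attribute and exercised from plain Rust. This is only a sketch: `unique_characters_plain` and `sorted_vocab` are hypothetical names that are not part of the project, and the sorted variant simply mirrors what Python's `sorted(set(text))` would give.

```rust
use std::collections::HashSet;

/// Same logic as the `#[pyfunction]` version, minus the PyO3 attribute,
/// so it can be called from an ordinary Rust test. (Hypothetical name.)
fn unique_characters_plain(s: &str) -> HashSet<char> {
    // `chars()` yields Unicode scalar values, not bytes.
    s.chars().collect()
}

/// Hypothetical helper: the vocabulary in sorted order, which is handy
/// when each character later needs a stable integer index.
fn sorted_vocab(s: &str) -> Vec<char> {
    let mut vocab: Vec<char> = unique_characters_plain(s).into_iter().collect();
    vocab.sort_unstable();
    vocab
}
```

Keeping the core logic in a plain function like this and wrapping it in a thin `#[pyfunction]` is one way to keep the Rust side testable with `cargo test`.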

Now we can compare the two implementations in our Notebook using the timeit module:

import nanogpt
import nanogpt_rs
import timeit

# Test correctness
assert set(nanogpt.chars) == nanogpt_rs.unique_characters(nanogpt.text)

iters = 100
py_time = timeit.timeit('set(nanogpt.text)', setup="import nanogpt", number=iters)
rs_time = timeit.timeit('nanogpt_rs.unique_characters(nanogpt.text)', setup="import nanogpt, nanogpt_rs", number=iters)

print(f"python: {py_time/iters:.4f}s/iter, rust: {rs_time/iters:.4f}s/iter")

which outputs

python: 0.0079s/iter, rust: 0.2327s/iter

Whoa, wait a second… the Rust implementation is roughly 30x slower! It turns out maturin develop compiles Rust in debug mode by default. We can pass the --release flag to maturin develop to compile with optimizations instead.

cargo watch -s 'maturin develop --release' -w src

This nets us a ~30x speedup over the debug build: python: 0.0069s/iter, rust: 0.0082s/iter. Rust is now within striking distance, but still a bit slower.
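Since it is easy to benchmark a debug build by accident, one possible guard is a tiny helper that reports how the library was compiled. The sketch below is an assumption rather than something from the project: `build_profile` is a hypothetical name, and it relies on Cargo's defaults, where `debug_assertions` is on in the dev profile and off in the release profile (a custom profile could change that).

```rust
/// Hypothetical helper: reports which Cargo profile the code was most
/// likely built with, based on the `debug_assertions` cfg flag.
fn build_profile() -> &'static str {
    if cfg!(debug_assertions) {
        // Default for `cargo build` and `maturin develop`.
        "debug"
    } else {
        // Default for builds made with `--release`.
        "release"
    }
}
```

Exposed to Python with `#[pyfunction]`, a helper like this could be printed or asserted on at the top of a benchmark cell.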

ahash

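Rust’s default HashSet implementation is based on Google’s SwissTable design, but by default it hashes keys with SipHash, which is chosen for resistance to hash-flooding attacks rather than raw speed. That is likely why it comes out slower than Python’s set here. The ahash crate provides a faster hash function that uses hardware AES instructions where available. Its HashSet can be used as a drop-in replacement for std::collections::HashSet.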

cargo add ahash

// use std::collections::HashSet;
use ahash::HashSet;

This quick change brings our time to python: 0.0069s/iter, rust: 0.0029s/iter. Rust is now ~2.4x faster!
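The swap is this small because the hash function is just a type parameter of the standard HashSet. As a rough illustration of that mechanism, here is a std-only sketch that plugs in FNV-1a, a simple non-cryptographic hash. This is not what ahash does internally, and `Fnv1a`, `FnvHashSet`, and `unique_characters_fnv` are hypothetical names used only for this example.

```rust
use std::collections::HashSet;
use std::hash::{BuildHasherDefault, Hasher};

/// 64-bit FNV-1a. Used here only to show how a hasher is swapped in.
struct Fnv1a(u64);

impl Default for Fnv1a {
    fn default() -> Self {
        Fnv1a(0xcbf2_9ce4_8422_2325) // FNV offset basis
    }
}

impl Hasher for Fnv1a {
    fn write(&mut self, bytes: &[u8]) {
        for &b in bytes {
            // XOR the byte in, then multiply by the FNV prime.
            self.0 ^= u64::from(b);
            self.0 = self.0.wrapping_mul(0x0000_0100_0000_01b3);
        }
    }

    fn finish(&self) -> u64 {
        self.0
    }
}

/// Still a std `HashSet`; only the second type parameter changes.
type FnvHashSet<T> = HashSet<T, BuildHasherDefault<Fnv1a>>;

fn unique_characters_fnv(s: &str) -> FnvHashSet<char> {
    // `collect` works because `BuildHasherDefault` implements `Default`.
    s.chars().collect()
}
```

Whether a simple hash like this beats the default on a given dataset is something to measure rather than assume.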

Custom Hash Function

If we want to push performance even further, we can write our own custom hash function. Our hash inputs are just single characters, and there are only 65 distinct ones in the dataset. We really only care about the Set in HashSet. The best hash function for this input may just be the identity function. We can implement it using Rust’s Hasher trait:

use std::hash::{BuildHasher, Hasher};

#[derive(Debug, Clone, Default)]
struct CharHasher {
    c: char,
}

impl Hasher for CharHasher {
    fn write(&mut self, bytes: &[u8]) {
        //  `char`s are fixed-width 4 bytes in Rust.
        if bytes.len() != 4 {
            return;
        }
        //  Gracefully fall-back to `0`
        let i = u32::from_ne_bytes(bytes.try_into().unwrap_or([0; 4]));
        //  Gracefully fall-back to 'a'
        self.c = char::from_u32(i).unwrap_or('a');
    }

    fn finish(&self) -> u64 {
        self.c as _
    }
}

impl BuildHasher for CharHasher {
    type Hasher = Self;

    fn build_hasher(&self) -> Self::Hasher {
        Self::default()
    }
}

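We can then specify our custom Hasher in the function signature. Note that this needs std::collections::HashSet again, since ahash::HashSet is a type alias with the hasher already filled in: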

#[pyfunction]
fn unique_characters(s: &str) -> HashSet<char, CharHasher> {
    HashSet::from_iter(s.chars())
}

NOTE: This hash function only works for chars and relies on implementation details of the standard library. Specifically, the Hash implementation for char writes the character as a u32, which reaches our write method as 4 native-endian bytes. It could be made more generic with a little extra logic like tracking how many bytes have been written so far.
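As a different take on the caveat in the note, rather than tracking the byte count, a hasher can override `write_u32` directly and keep a generic byte-folding fallback in `write`. The sketch below assumes, as the note does, that the standard library hashes a char through `write_u32`, which is the case at the time of writing but not a documented guarantee. `IdentityHasher`, `IdentitySet`, and `unique_characters_identity` are hypothetical names.

```rust
use std::collections::HashSet;
use std::hash::{BuildHasherDefault, Hasher};

/// Identity-style hasher that does not depend on receiving exactly 4 bytes.
#[derive(Default)]
struct IdentityHasher(u64);

impl Hasher for IdentityHasher {
    // Generic fallback: fold however many bytes arrive into the state.
    fn write(&mut self, bytes: &[u8]) {
        for &b in bytes {
            self.0 = (self.0 << 8) | u64::from(b);
        }
    }

    // Fast path: a u32 write (what `char` is assumed to use) becomes
    // the hash value as-is, with no byte-slice round trip.
    fn write_u32(&mut self, i: u32) {
        self.0 = u64::from(i);
    }

    fn finish(&self) -> u64 {
        self.0
    }
}

/// `BuildHasherDefault` removes the need for a manual `BuildHasher` impl.
type IdentitySet<T> = HashSet<T, BuildHasherDefault<IdentityHasher>>;

fn unique_characters_identity(s: &str) -> IdentitySet<char> {
    s.chars().collect()
}
```

Because the fallback `write` still produces a usable (if crude) hash, this variant stays correct for other key types, although it has only been reasoned about here, not benchmarked.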

With our custom hash function, the new timings are python: 0.0069s/iter, rust: 0.0019s/iter – Rust is now more than 3.5x faster!

We could potentially push further and use SmallVec (a Vec that stores a small number of elements inline before spilling to the heap), avoiding a HashSet altogether. I personally liked that the Hasher solution kept all of the ergonomics of HashSet, but feel free to try it! With larger datasets, we can also start to consider streaming using async I/O and parallel processing using rayon.
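For a sense of what the HashSet-free approach might look like, here is a minimal sketch that keeps the distinct characters in a flat list and checks membership with a linear scan. A plain `Vec` stands in for SmallVec so the example stays std-only, and `unique_characters_vec` is a hypothetical name. It has not been benchmarked against the versions above.

```rust
/// Collects distinct characters in first-seen order using a linear scan.
/// A plain `Vec` stands in for `SmallVec` to keep the sketch std-only.
fn unique_characters_vec(s: &str) -> Vec<char> {
    let mut seen: Vec<char> = Vec::new();
    for c in s.chars() {
        // O(vocabulary size) per character; cheap only while the
        // vocabulary stays tiny, as with the 65 characters here.
        if !seen.contains(&c) {
            seen.push(c);
        }
    }
    seen
}
```

The trade-off is that the result is a list rather than a set, so the Python side would need to wrap it in `set(...)` to keep the same interface.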

Adding Python Types

Rust gives us a great type system to work with, but as currently configured, our Python code cannot take advantage of these types. Ideally pyo3 would generate Python type hints based on our Rust code, and this has been a long-standing request.

According to the pyo3 guide, the easiest way to add Python types is to manually add them in a .pyi file.

# python/nanogpt_rs/nanogpt_rs.pyi
def unique_characters(s: str) -> set[str]: ...

The editor’s type checker will now flag an error if we try to pass anything besides a str to unique_characters ✨.

I had to reload my VSCode window for it to start picking up the new typings.

Wrapping Up

That’s it for Part 1 of this series! You now have the tools to go forth and replace critical pieces of Python with high-performance Rust. Hopefully you learned something along the way! All of the code for this post can be found in this GitHub Repo.

Stay tuned for my next post on writing PyTorch Modules (Neural Network Layers) in Rust!