nanoGPT + Rust :: Part 1
Mixing Python Notebooks + Rust
I recently followed Andrej Karpathy’s Let’s build GPT: from scratch, in code, spelled out, implementing a simplified GPT in Python. The video is a fantastic learning resource. I recommend everyone go through it themselves to understand how transformers work.
Out of curiosity, I set up a hybrid Rust + PyTorch development environment and started re-implementing nanoGPT in Rust using the PyTorch Rust bindings.
This post is the first of a series that will walk through this process. In this post, we’ll cover integrating Rust into Python and Notebooks. We’ll use Rust to optimize a (simple) Python text pre-processing function.
Starting point: nanogpt.py
First, let’s get our initial Python environment set up. Starting from a fresh conda environment:
conda create --name nanogpt-rs
conda activate nanogpt-rs
conda install ipykernel pytorch matplotlib autopep8 maturin
We’ll then generate our mixed Python/Rust project using Maturin:
maturin new --mixed --bindings pyo3 nanogpt-rs
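On recent maturin versions, the generated layout looks roughly like this (exact files vary by template version):
nanogpt-rs/
├── Cargo.toml
├── pyproject.toml
├── python/
│   └── nanogpt_rs/
│       └── __init__.py
└── src/
    └── lib.rs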
This will give us a starter project complete with a Rust library, a Python module wrapping that library, and the configuration to build and publish this package. From here, I copied my Python implementation and a simple Python Notebook to this folder. Feel free to clone the project starting at this point:
git clone --branch starter https://github.com/kcking/nanogpt-rs
Note that this implementation only contains features from the YouTube tutorial. Quite a few improvements have been added upstream since.
Make sure to select your new environment in VSCode for running Python Notebook cells. The first cell in nanogpt.ipynb should plot the count of each character in the TinyShakespeare dataset. Even when writing a model in Rust, I find it useful to have a Python Notebook open for experimentation and data analysis.
Rusty Data Processing
Processing data is one of the biggest wins for using Rust in an LLM environment. PyTorch Tensor operations are already optimized, backed by C++, and able to run on GPUs. But data processing code is generally bespoke Python, which can be slow and error-prone.
Let’s set up a Python module backed by Rust that we can use to pre-process our TinyShakespeare dataset.
Using Maturin
The maturin starter project is immediately ready to be called from Python. Whenever we make a change to the Rust code, we just run maturin develop to build and install the new module in the current environment.
The nanogpt_rs module can now be imported into .py/.ipynb files:
import nanogpt_rs
# included in the default maturin template
print(nanogpt_rs.sum_as_string(1, 2))
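For reference, the Rust side of sum_as_string in the generated src/lib.rs looks roughly like this (copied from the pyo3/maturin template; details vary by version):
use pyo3::prelude::*;

/// Formats the sum of two numbers as a string.
#[pyfunction]
fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
    Ok((a + b).to_string())
}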
NOTE: You will need to restart your ipynb kernel in VSCode to pick up any changes to local Rust modules.
We can also automatically re-run maturin develop using cargo-watch:
cargo install cargo-watch
cargo watch -s 'maturin develop' -w src
Preprocessing TinyShakespeare
The first step in training our LLM is defining the input space, or vocabulary, of the data. Since this is a character-based LLM, our vocabulary is just every unique character in the entire dataset. The implementation is almost as simple as Python’s set(text).
use std::collections::HashSet;

#[pyfunction]
fn unique_characters(s: &str) -> HashSet<char> {
    HashSet::from_iter(s.chars())
}
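One detail worth calling out: a #[pyfunction] is only visible from Python once it is registered in the #[pymodule] initializer the maturin template generated in src/lib.rs. A minimal sketch, using the older &PyModule signature (newer pyo3 releases use Bound<'_, PyModule> instead):
#[pymodule]
fn nanogpt_rs(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
    // Register our new pre-processing function alongside the template's.
    m.add_function(wrap_pyfunction!(unique_characters, m)?)?;
    Ok(())
}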
Now we can compare the two implementations in our Notebook using the timeit module:
import nanogpt
import nanogpt_rs
import timeit
# Test correctness
assert set(nanogpt.chars) == nanogpt_rs.unique_characters(nanogpt.text)
iters = 100
py_time = timeit.timeit('set(nanogpt.text)', setup="import nanogpt", number=iters)
rs_time = timeit.timeit('nanogpt_rs.unique_characters(nanogpt.text)', setup="import nanogpt, nanogpt_rs", number=iters)
print(f"python: {py_time/iters:.4f}s/iter, rust: {rs_time/iters:.4f}s/iter")
which outputs
python: 0.0079s/iter, rust: 0.2327s/iter
Whoa, wait a second… the Rust implementation is ~30x slower! It turns out maturin develop compiles Rust in debug mode by default. We can add the --release flag to maturin to compile in release mode.
cargo watch -s 'maturin develop --release' -w src
This nets us a ~30x speedup: python: 0.0069s/iter, rust: 0.0082s/iter. Rust is now within striking distance, but still a bit slower.
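Before reaching for a different hash implementation, we can also squeeze a little more out of release builds by tuning the release profile in Cargo.toml. This is optional and the gains are workload-dependent; a minimal sketch:
# Cargo.toml
[profile.release]
lto = true          # whole-program link-time optimization
codegen-units = 1   # slower compiles, better optimization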
ahash
Rust’s default HashSet is based on Google’s SwissTable design, so the table itself is fast; the bottleneck is the standard library’s default hasher, the DoS-resistant SipHash, which trades raw speed for safety. The ahash crate instead uses accelerated AES operations for hashing. We can use it as a drop-in replacement for std::collections::HashSet.
cargo add ahash
// use std::collections::HashSet
use ahash::HashSet;
This quick change brings our time to python: 0.0069s/iter, rust: 0.0029s/iter. Rust is now ~2.5x faster!
Custom Hash Function
If we want to push performance even further, we can write our own custom hash function. Our hash inputs are just single characters, and there are only 65 of them in the dataset. We really only care about the Set in HashSet. The best hash function for this input may just be the identity function. We can implement it using Rust’s Hasher trait:
use std::hash::{BuildHasher, Hasher};

#[derive(Debug, Clone, Default)]
struct CharHasher {
    c: char,
}

impl Hasher for CharHasher {
    fn write(&mut self, bytes: &[u8]) {
        // `char`s are fixed-width 4 bytes in Rust.
        if bytes.len() != 4 {
            return;
        }
        // Gracefully fall back to `0`
        let i = u32::from_ne_bytes(bytes.try_into().unwrap_or([0; 4]));
        // Gracefully fall back to 'a'
        self.c = char::from_u32(i).unwrap_or('a');
    }

    fn finish(&self) -> u64 {
        // The "hash" is just the character's code point: the identity function.
        self.c as u64
    }
}

impl BuildHasher for CharHasher {
    type Hasher = Self;

    fn build_hasher(&self) -> Self::Hasher {
        Self::default()
    }
}
We can then specify our custom Hasher in the function signature:
#[pyfunction]
fn unique_characters(s: &str) -> HashSet<char, CharHasher> {
    HashSet::from_iter(s.chars())
}
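To sanity-check the hasher, a quick unit test (a minimal sketch; the test name is mine) confirms that hashing a char yields exactly its code point. Note that running cargo test against a pyo3 extension-module crate can require making the extension-module feature optional, per the pyo3 FAQ.
#[cfg(test)]
mod tests {
    use super::*;
    use std::hash::Hash;

    #[test]
    fn char_hasher_is_identity() {
        // `Hash` for `char` calls `write_u32`, which forwards the four
        // native-endian bytes to our `write` implementation.
        let mut hasher = CharHasher::default().build_hasher();
        'z'.hash(&mut hasher);
        assert_eq!(hasher.finish(), 'z' as u64);
    }
}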
NOTE: This hash function only works for chars and relies on the implementation details of std::collections::HashMap. Specifically, the Hash implementation for char writes its bytes in native-endian order. It could be made more generic with a little extra logic, like tracking how many bytes have been written so far.
With our custom hash function, the new timings are python: 0.0069s/iter, rust: 0.0019s/iter. Rust is now more than 3.5x faster!
We could potentially push further and use SmallVec (a stack-allocated Vec), avoiding a HashSet altogether. I personally like that the Hasher solution keeps all of the ergonomics of HashSet, but feel free to try it! With larger datasets, we can also start to consider streaming using async I/O and parallel processing using rayon, as sketched below.
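As a taste of the rayon direction, here is a minimal sketch (assuming cargo add rayon; the function name is mine, and par_chars comes from rayon’s ParallelString trait). Each worker thread folds characters into its own set, and the per-thread sets are then merged:
use rayon::prelude::*;

/// Sketch: collect unique characters in parallel. Each rayon worker
/// builds a private set; the partial sets are then merged pairwise.
fn unique_characters_par(s: &str) -> HashSet<char, CharHasher> {
    s.par_chars()
        .fold(HashSet::default, |mut set, c| {
            set.insert(c);
            set
        })
        .reduce(HashSet::default, |mut a, b| {
            a.extend(b);
            a
        })
}
For a dataset as small as TinyShakespeare (~1 MB), the coordination overhead may eat the gains, so it’s worth benchmarking before committing to it.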
Adding Python Types
Rust gives us a great type system to work with, but as currently configured, our Python code cannot take advantage of these types. Ideally pyo3 would generate Python type hints based on our Rust code; this has been a long-standing request.
According to the pyo3 guide, the easiest way to add Python types is to manually add them in a .pyi file:
# python/nanogpt_rs/nanogpt_rs.pyi
def unique_characters(s: str) -> set[str]: ...
We’ll now get Python type errors if we try to pass anything besides a str to unique_characters ✨.
I had to reload my VSCode window for it to start picking up the new typings.
Wrapping Up
That’s it for Part 1 of this series! You now have the tools to go forth and replace critical pieces of Python with high-performance Rust. Hopefully you learned something along the way! All of the code for this post can be found in this GitHub Repo.
Stay tuned for my next post on writing PyTorch Modules (neural network layers) in Rust!