A small GPT-style language model training prototype in Rust, using a hand-written matrix backend (column-major storage, no external tensor framework) and a FlashAttention-style attention implementation.
This project is based on:

- Andrej Karpathy's `microgpt` concept and reference implementation.
- The FlashAttention algorithm (Dao et al., 2022).
- DeepSeek's Manifold-Constrained Hyper-Connections (mHC) for residual stream management.
- Uses an explicit `Matrix` module (`src/algebra.rs`) instead of a tape/autodiff node graph.
- Keeps computations in matrix form for forward and backward passes (`src/model.rs`).
- Uses column-major matrix storage for performance.
- Implements FlashAttention-style tiled attention with online softmax — the full N×N attention matrix is never materialized.
- The backward pass recomputes attention weights on-the-fly from Q, K, V and stored per-query statistics (max, sum-of-exponentials), instead of loading stored attention probability matrices.
- Replaces standard residual connections with Manifold-Constrained Hyper-Connections (mHC): the hidden state is expanded into 2 parallel streams, mixed via a doubly-stochastic matrix (Birkhoff polytope constraint), aggregated before each sublayer, and distributed back after.
- Loads a text dataset from `input.txt` (line-based documents).
- Builds a character-level vocabulary from the dataset.
- Trains a tiny decoder-only transformer on next-token prediction.
- Prints periodic training loss.
- Runs autoregressive sampling after training to generate text.
- Data preparation (`src/main.rs`)
  - Reads non-empty lines from `input.txt`.
  - Shuffles documents with a Python-compatible MT19937 RNG.
  - Builds the charset and adds a special BOS token id.
- Model definition (`src/model.rs`)
  - Token embedding (`wte`) + positional embedding (`wpe`).
  - One attention block:
    - mHC: doubly-stochastic stream mixing → aggregation
    - RMSNorm
    - Multi-head causal self-attention (`wq`, `wk`, `wv`, `wo`) using FlashAttention
    - mHC: distribute sublayer output back to streams
  - One MLP block:
    - mHC: doubly-stochastic stream mixing → aggregation
    - RMSNorm
    - `fc1 -> ReLU -> fc2`
    - mHC: distribute sublayer output back to streams
  - Contraction from 2 streams back to 1, then final projection to vocabulary logits (`lm_head`).
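The first step of the pipeline, summing token and positional embeddings, can be sketched as follows. Only the `wte`/`wpe` names come from the text; the row-per-id layout and helper name are assumptions:

```rust
// Minimal sketch of the embedding lookup: one vector per token position,
// computed as wte[token] + wpe[position]. (Illustrative, not the real API.)
fn embed(wte: &[Vec<f32>], wpe: &[Vec<f32>], tokens: &[usize]) -> Vec<Vec<f32>> {
    tokens
        .iter()
        .enumerate()
        .map(|(pos, &tok)| {
            wte[tok]
                .iter()
                .zip(&wpe[pos])
                .map(|(a, b)| a + b)
                .collect()
        })
        .collect()
}

fn main() {
    let wte = vec![vec![1.0, 0.0], vec![0.0, 1.0]]; // vocab of 2, n_embd = 2
    let wpe = vec![vec![0.5, 0.5], vec![0.25, 0.25]]; // block size 2
    let x = embed(&wte, &wpe, &[1, 0]);
    println!("{:?}", x);
}
```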
- Manifold-Constrained Hyper-Connections (`src/model.rs`)
  - Expands the hidden state from 1 stream to 2 parallel streams after embedding.
  - Before each sublayer: mixes streams with a doubly-stochastic matrix A = [[t, 1-t], [1-t, t]] where t = sigmoid(learned_param) (Birkhoff polytope constraint, equivalent to Sinkhorn-Knopp for 2×2), then aggregates to a single stream via learned weights.
  - After each sublayer: distributes the output back to both streams via learned weights, added to the mixed residual.
  - Contracts back to 1 stream before the final projection.
  - Initialized to recover standard residual connections (alpha=[0.5,0.5], beta=[1.0,1.0], contract=[0.5,0.5]).
  - Adds 12 learnable scalar parameters (7 small matrices) with full backward pass support.
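A minimal sketch of the 2×2 doubly-stochastic mixing step, assuming plain `Vec<f32>` streams (function names are illustrative):

```rust
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// Mix two residual streams with the 2x2 doubly-stochastic matrix
// A = [[t, 1-t], [1-t, t]], t = sigmoid(p). Every row and column sums to 1,
// so A lies on the Birkhoff polytope by construction.
fn mix_streams(s0: &[f32], s1: &[f32], p: f32) -> (Vec<f32>, Vec<f32>) {
    let t = sigmoid(p);
    let m0 = s0.iter().zip(s1).map(|(a, b)| t * a + (1.0 - t) * b).collect();
    let m1 = s0.iter().zip(s1).map(|(a, b)| (1.0 - t) * a + t * b).collect();
    (m0, m1)
}

fn main() {
    // At p = 0 (t = 0.5) both output streams are the average of the inputs.
    let (m0, m1) = mix_streams(&[1.0, 0.0], &[0.0, 1.0], 0.0);
    println!("{:?} {:?}", m0, m1);
}
```

Because a single sigmoid parameterizes the whole 2×2 matrix, the constraint holds exactly without any Sinkhorn iterations.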
- FlashAttention (`src/model.rs`)
  - Tiled forward pass: processes Q, K, V in blocks (tile size 4), using online softmax to incrementally compute the attention output without allocating the full attention matrix.
  - Stores only the per-query max and sum-of-exponentials (2N floats per head) instead of the full N×N attention probability matrix.
  - Tiled backward pass: recomputes attention scores and probabilities on the fly from Q, K, V and the stored statistics, then accumulates gradients for Q, K, V in a single pass.
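The online-softmax recurrence can be sketched for a single query row with scalar values; the function name and the scalar-value simplification are assumptions, not the project's code:

```rust
// Online softmax over one query row: process scores in tiles, maintaining a
// running max `m` and sum of exponentials `l`, rescaling the accumulated
// output whenever the max grows. Only (m, l) need to be stored for backward.
fn attention_row(scores: &[f32], values: &[f32], tile: usize) -> (f32, f32, f32) {
    let (mut m, mut l, mut acc) = (f32::NEG_INFINITY, 0.0_f32, 0.0_f32);
    for (s_tile, v_tile) in scores.chunks(tile).zip(values.chunks(tile)) {
        let m_tile = s_tile.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let m_new = m.max(m_tile);
        let scale = (m - m_new).exp(); // rescale old statistics to the new max
        l *= scale;
        acc *= scale;
        for (&s, &v) in s_tile.iter().zip(v_tile) {
            let w = (s - m_new).exp();
            l += w;
            acc += w * v;
        }
        m = m_new;
    }
    (acc / l, m, l) // output plus the per-query statistics kept for backward
}

fn main() {
    let (out, m, l) = attention_row(&[1.0, 2.0, 3.0, 4.0], &[1.0, 2.0, 3.0, 4.0], 2);
    println!("out = {out:.4}, max = {m}, sum = {l:.4}");
}
```

The result matches a direct softmax over the full row, but no tile ever sees more than `tile` scores at once.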
- Math primitives (`src/algebra.rs`)
  - Column-major `Matrix` type.
  - Core ops: matmul, transpose, ReLU, softmax-by-column, scaling.
  - RMSNorm forward/backward utilities.
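A minimal illustration of the column-major layout and a naive matmul over it (the real `Matrix` API in `src/algebra.rs` may differ):

```rust
// Column-major storage: element (row, col) lives at data[col * rows + row],
// so each column is contiguous in memory.
struct Matrix {
    rows: usize,
    cols: usize,
    data: Vec<f32>, // length rows * cols, column-major
}

impl Matrix {
    fn get(&self, r: usize, c: usize) -> f32 {
        self.data[c * self.rows + r]
    }

    // Naive matmul; the inner loop walks `self` down one column, which is a
    // contiguous, cache-friendly scan in this layout.
    fn matmul(&self, other: &Matrix) -> Matrix {
        assert_eq!(self.cols, other.rows);
        let mut out = vec![0.0; self.rows * other.cols];
        for c in 0..other.cols {
            for k in 0..other.rows {
                let b = other.get(k, c);
                for r in 0..self.rows {
                    out[c * self.rows + r] += self.get(r, k) * b;
                }
            }
        }
        Matrix { rows: self.rows, cols: other.cols, data: out }
    }
}

fn main() {
    // 2x2 identity times a 2x1 column vector.
    let a = Matrix { rows: 2, cols: 2, data: vec![1.0, 0.0, 0.0, 1.0] };
    let x = Matrix { rows: 2, cols: 1, data: vec![3.0, 4.0] };
    println!("{:?}", a.matmul(&x).data);
}
```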
- Training loop (`src/main.rs`)
  - For each step, creates a token sequence `[BOS] + encoded_document + [BOS]`, truncated to the model block size before the final BOS.
  - Forward pass -> logits.
  - Cross-entropy loss and logits gradient.
  - Backward pass to compute gradients for all parameters.
  - Parameter update with an Adam-like optimizer (`step_all`).
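The loss and logits-gradient step can be sketched per position; the helper name is illustrative:

```rust
// Cross-entropy at one position: softmax the logits (max-subtracted for
// numerical stability), loss = -ln p[target], and the gradient w.r.t. the
// logits is simply p - one_hot(target).
fn cross_entropy(logits: &[f32], target: usize) -> (f32, Vec<f32>) {
    let m = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - m).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
    let loss = -probs[target].ln();
    let grad = probs
        .iter()
        .enumerate()
        .map(|(i, &p)| if i == target { p - 1.0 } else { p })
        .collect();
    (loss, grad)
}

fn main() {
    let (loss, grad) = cross_entropy(&[2.0, 0.0, 0.0], 0);
    println!("loss = {:.4}, grad = {:?}", loss, grad);
}
```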
- Inference (`src/main.rs`)
  - Starts from BOS.
  - Repeats:
    - forward pass on the current prefix
    - temperature scaling
    - softmax + weighted sampling
  - Stops on BOS or max length.
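The sampling loop's inner step (temperature scaling, softmax, weighted draw) might look like this sketch, which takes a pre-drawn uniform number in [0, 1) instead of pulling in an RNG crate:

```rust
// Divide logits by the temperature, softmax, then walk the CDF until it
// passes the uniform sample `u`. Lower temperature sharpens the distribution
// toward the argmax; higher temperature flattens it.
fn sample(logits: &[f32], temperature: f32, u: f32) -> usize {
    let m = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits
        .iter()
        .map(|&z| ((z - m) / temperature).exp())
        .collect();
    let sum: f32 = exps.iter().sum();
    let mut cdf = 0.0;
    for (i, e) in exps.iter().enumerate() {
        cdf += e / sum;
        if u < cdf {
            return i;
        }
    }
    logits.len() - 1 // guard against rounding when u is close to 1
}

fn main() {
    // With a very low temperature the draw concentrates on the largest logit.
    println!("{}", sample(&[0.0, 5.0, 1.0], 0.1, 0.5));
}
```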
```
cargo run --release
```