mplekh/rust-matrixmicrogpt

rust-flashmicrogpt

A small GPT-style language model training prototype in Rust, using a hand-written matrix backend (column-major storage, no external tensor framework) and a FlashAttention-style attention implementation.

Origin

This project is based on:

  • Andrej Karpathy's microgpt concept and reference implementation.
  • The FlashAttention algorithm (Dao et al., 2022).
  • DeepSeek's Manifold-Constrained Hyper-Connections (mHC) for residual stream management.

How this version is different

  • Uses an explicit Matrix module (src/algebra.rs) instead of a tape/autodiff node graph.
  • Keeps computations in matrix form for forward and backward passes (src/model.rs).
  • Uses column-major matrix storage for performance.
  • Implements FlashAttention-style tiled attention with online softmax — the full N×N attention matrix is never materialized.
  • The backward pass recomputes attention weights on-the-fly from Q, K, V and stored per-query statistics (max, sum-of-exponentials), instead of loading stored attention probability matrices.
  • Replaces standard residual connections with Manifold-Constrained Hyper-Connections (mHC): the hidden state is expanded into 2 parallel streams, mixed via a doubly-stochastic matrix (Birkhoff polytope constraint), aggregated before each sublayer, and distributed back after.

What this program does

  • Loads a text dataset from input.txt (line-based documents).
  • Builds a character-level vocabulary from the dataset.
  • Trains a tiny decoder-only transformer on next-token prediction.
  • Prints periodic training loss.
  • Runs autoregressive sampling after training to generate text.

How it works

  1. Data preparation (src/main.rs)
  • Reads non-empty lines from input.txt.
  • Shuffles documents with a Python-compatible MT19937 RNG.
  • Builds charset and adds a special BOS token id.
  2. Model definition (src/model.rs)
  • Token embedding (wte) + positional embedding (wpe).
  • One attention block:
    • mHC: doubly-stochastic stream mixing → aggregation
    • RMSNorm
    • Multi-head causal self-attention (wq, wk, wv, wo) using FlashAttention
    • mHC: distribute sublayer output back to streams
  • One MLP block:
    • mHC: doubly-stochastic stream mixing → aggregation
    • RMSNorm
    • fc1 -> ReLU -> fc2
    • mHC: distribute sublayer output back to streams
  • Contraction from 2 streams back to 1, then final projection to vocabulary logits (lm_head).
  3. Manifold-Constrained Hyper-Connections (src/model.rs)
  • Expands the hidden state from 1 stream to 2 parallel streams after embedding.
  • Before each sublayer: mixes streams with a doubly-stochastic matrix A = [[t, 1-t], [1-t, t]] where t = sigmoid(learned_param) (Birkhoff polytope constraint, equivalent to Sinkhorn-Knopp for 2×2), then aggregates to a single stream via learned weights.
  • After each sublayer: distributes the output back to both streams via learned weights, added to the mixed residual.
  • Contracts back to 1 stream before the final projection.
  • Initialized to recover standard residual connections (alpha=[0.5,0.5], beta=[1.0,1.0], contract=[0.5,0.5]).
  • Adds 12 learnable scalar parameters (7 small matrices) with full backward pass support.
  4. FlashAttention (src/model.rs)
  • Tiled forward pass: processes Q, K, V in blocks (tile size 4), using online softmax to incrementally compute attention output without allocating the full attention matrix.
  • Stores only per-query max and sum-of-exponentials (2N floats per head) instead of the full N×N attention probability matrix.
  • Tiled backward pass: recomputes attention scores and probabilities on-the-fly from Q, K, V and the stored statistics, then accumulates gradients for Q, K, V in a single pass.
  5. Math primitives (src/algebra.rs)
  • Column-major Matrix type.
  • Core ops: matmul, transpose, ReLU, softmax-by-column, scaling.
  • RMSNorm forward/backward utilities.
  6. Training loop (src/main.rs)
  • For each step, creates a token sequence:
    • [BOS] + encoded_document + [BOS]
    • truncated to the model block size before the final BOS.
  • Forward pass -> logits.
  • Cross-entropy loss and logits gradient.
  • Backward pass to compute gradients for all parameters.
  • Parameter update with Adam-like optimizer (step_all).
  7. Inference (src/main.rs)
  • Starts from BOS.
  • Repeats:
    • forward pass on current prefix
    • temperature scaling
    • softmax + weighted sampling
    • stop on BOS or max length.

Run

cargo run --release
