Drex is an experimental transformer architecture with a four-tier memory hierarchy, implementing a validated episodic/semantic split associative memory module. The architecture was developed through 16 phases of hypothesis-driven research across 247+ experiments, with each architectural decision grounded in controlled ablation studies (SUPPORTED / INCONCLUSIVE / REFUTED verdicts, ≥2/3 seed confirmation).
The research specifically resolved a previously undocumented failure mode in EMA-based associative memory at short sequence lengths (the EMA bootstrap problem) and confirmed a validated fix, α(L) = 0.95^(96/L), that keeps memory time constants calibrated across L=16–128. The full research log is in research/experiments/ (cat1–cat48, 247+ experiment result files).
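The calibration claim can be checked numerically. A minimal sketch, assuming the common EMA time-constant definition τ = −1/ln α (a convention chosen here for illustration, not taken from the repo):

```python
import math

def alpha(L: int) -> float:
    """Length-adaptive EMA coefficient from the text: alpha(L) = 0.95**(96/L)."""
    return 0.95 ** (96 / L)

def tau_over_L(L: int) -> float:
    # EMA time constant tau = -1/ln(alpha); with alpha(L) above, tau/L is
    # exactly constant: -1 / (96 * ln 0.95) for every L.
    return (-1.0 / math.log(alpha(L))) / L

for L in (16, 32, 64, 96, 128):
    print(f"L={L:3d}  alpha={alpha(L):.4f}  tau/L={tau_over_L(L):.4f}")
```

Under this definition the ratio comes out ≈0.20 for all L, consistent with the ≈0.21 figure in the text up to the choice of time-constant convention.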
Current state (Phase 16 complete / Phase 17 in progress): The architecture is fully implemented, hardened, and training-ready. A 9-page arXiv preprint has been drafted (paper/main.tex). Exp A (baseline) is running (step ~16k/50k); Exp B (episodic memory) will auto-start on Exp A completion. Evaluation results (passkey recall, BABILong) are pending final checkpoints; that is the current milestone. See Current Results.
Drex uses a four-tier memory hierarchy:
| Layer | Mechanism | Scope |
|---|---|---|
| L1 | Sliding-window causal attention | In-context (short range) |
| L2 | Infini-Attention delta-rule matrix | Cross-segment (medium range) |
| L3 | Titans-style MLP weight snapshots | Disk (long range, async) |
| L4 | Episodic/semantic split delta-rule | Per-segment associative recall |
The episodic/semantic memory layer is the primary research contribution. Key properties:
- Two H/2 associative matrices: `M_sem` (semantic, uniform weight) and `M_epi` (episodic, recency-weighted writes)
- Delta-rule update: `Δ = (k − M k̃) ⊗ k̃`, written via EMA with `(1 − α)` smoothing
- Length-adaptive EMA: `α(L) = 0.95^(96/L)` keeps τ/L ≈ 0.21 constant across L=16–128, solving the EMA bootstrap failure at short sequences (Phase 11)
- OR relative-norm write gate: fires when `‖k − v_p‖ ≥ thresh·‖k‖` on either branch; thresh* = 0.70 (confirmed exp_48_1, Phase 12)
- Null retrieval gate: learned scalar `g = σ(w·q)` suppresses empty-memory reads
- Soft concatenated retrieval: `concat(r_sem, r_epi)`; no learned combination gate (exp_38_3 ruled this out)
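To make the write and read paths concrete, here is a minimal NumPy sketch of the ideas above. The function names are hypothetical, the key normalization and the exact placement of the EMA are assumptions, and this is not the repo's `MemoryModule` API:

```python
import numpy as np

def delta_rule_write(M, k, alpha, thresh=0.70, eps=1e-6):
    """Illustrative delta-rule EMA write for one associative matrix (M_sem or M_epi).

    M:     (d, d) associative matrix
    k:     (d,) key vector; here the key doubles as the write target
    alpha: length-adaptive EMA coefficient, e.g. alpha(L) = 0.95**(96/L)
    """
    k_tilde = k / (np.linalg.norm(k) + eps)      # normalized key (assumption)
    v_pred = M @ k_tilde                         # current retrieval for this key
    err = k - v_pred                             # prediction error ("surprise")
    # Relative-norm write gate: only write when the error is large vs. the key
    if np.linalg.norm(err) >= thresh * np.linalg.norm(k):
        delta = np.outer(err, k_tilde)           # delta-rule increment
        M = alpha * M + (1.0 - alpha) * delta    # EMA with (1 - alpha) smoothing
    return M

def null_gated_read(M_sem, M_epi, q, w):
    """Soft concatenated retrieval with a learned null-retrieval gate g = sigmoid(w.q)."""
    g = 1.0 / (1.0 + np.exp(-np.dot(w, q)))      # suppresses empty-memory reads
    return g * np.concatenate([M_sem @ q, M_epi @ q])
```

On an empty matrix every key is maximally surprising, so the first write always fires; as the memory fills, only keys the matrix fails to reconstruct pass the gate.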
Validated write rates at thresh=0.70:
- L=32: wr=0.581 (target: 0.20–0.70) ✓
- L=96: wr=0.421 (target: 0.15–0.50) ✓
See ARCHITECTURE_FINDINGS.md for the full specification, confidence classifications, and the complete list of research dead ends.
Status: no published end-to-end benchmark yet.
The architecture components are fully validated (exp_48_1, Phase 12; 241-test suite, 100% branch coverage). The production implementation trains on TinyStories with write rates in range. The following comparisons are planned:
| Experiment | Config | Status |
|---|---|---|
| Exp A β baseline | DrexTransformer (no MemoryModule), 256d, 4L, 512 ctx | Pending |
| Exp B β episodic memory | DrexTransformer + MemoryModule (thresh=0.70), same config | Pending |
Evaluation targets: passkey recall at 2k/4k/8k/16k context; BABILong tasks 1β5 at 2k/4k/8k context. Results will be published at results/TRAINING_RUNS.md.
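For background, a passkey-recall probe buries a key sentence at a random depth in filler text and asks the model to repeat it. A minimal, hypothetical generator (illustrative only; not the repo's PasskeyBenchmark or `scripts/eval_passkey.py`):

```python
import random

def make_passkey_prompt(context_len_chars, passkey, rng=random.Random(0)):
    """Build a passkey prompt of roughly context_len_chars characters.

    The needle is placed at a random depth; recall is typically scored by
    exact match on the model's continuation after the final cue.
    """
    filler = "The grass is green. The sky is blue. The sun is yellow. "
    needle = f"The passkey is {passkey}. Remember it. "
    question = "What is the passkey? The passkey is"
    n_fill = max(1, (context_len_chars - len(needle) - len(question)) // len(filler))
    pos = rng.randrange(n_fill + 1)  # random insertion depth
    body = filler * pos + needle + filler * (n_fill - pos)
    return body + question

prompt = make_passkey_prompt(2048, passkey=73914)
```

Sweeping `context_len_chars` over the target lengths (2k/4k/8k/16k) turns this into a recall-vs-length curve.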
- Python ≥ 3.11
- Rust toolchain (for the `drex._sys` extension: SnapshotStore, PrefetchEngine)
- PyTorch ≥ 2.3.0
```bash
git clone https://github.com/wesleyscholl/drex.git
cd drex

# Install maturin (Rust-Python build tool)
pip install maturin

# Build and install the Rust extension + Python package
maturin develop --release
pip install -e ".[dev]"
```

Network note: If your environment has a corporate SSL proxy, the training data download may fail with a certificate error. Use `--no-ssl-verify` (development only) or `--data-file /path/to/tinystories.txt` to load from a local file.
Quick training run (smoke test):

```bash
python scripts/train.py \
  --steps 1000 --log-every 100 \
  --d-model 128 --n-layers 3 --n-heads 4 \
  --use-episodic-memory --no-ssl-verify
```

Exp A (baseline, 50k steps):

```bash
python scripts/train.py \
  --steps 50000 --log-every 200 \
  --d-model 256 --n-layers 4 --n-heads 4 --ff-mult 4 \
  --batch-size 8 --segment-len 512 --dropout 0.1 \
  --lr 3e-4 --warmup-steps 2000 \
  --val-every 1000 --val-max-chars 500000 \
  --reset-on-boundary \
  --ckpt-dir checkpoints/exp_a \
  --no-ssl-verify
```

Exp B (episodic memory):

```bash
python scripts/train.py \
  --steps 50000 --log-every 200 \
  --d-model 256 --n-layers 4 --n-heads 4 --ff-mult 4 \
  --batch-size 8 --segment-len 512 --dropout 0.1 \
  --lr 3e-4 --warmup-steps 2000 \
  --use-episodic-memory --episodic-gate-thresh 0.70 \
  --val-every 1000 --val-max-chars 500000 \
  --reset-on-boundary \
  --ckpt-dir checkpoints/exp_b \
  --no-ssl-verify
```

Passkey evaluation:

```bash
python scripts/eval_passkey.py \
  --checkpoint checkpoints/exp_b/step_0050000_final.safetensors \
  --use-episodic-memory --episodic-gate-thresh 0.70 \
  --report-write-rate \
  --lengths 2048 4096 8192 16384
```

BABILong evaluation:

```bash
python scripts/eval_babilong.py \
  --checkpoint checkpoints/exp_b/step_0050000_final.safetensors \
  --use-episodic-memory --episodic-gate-thresh 0.70 \
  --lengths 2048 4096 8192 \
  --tasks 1 2 3 4 5 --trials 10
```

Write-rate monitoring from Python:

```python
from drex.models.memory import MemoryModule

for module in model.modules():
    if isinstance(module, MemoryModule):
        wr = module.last_write_rate()
        module.assert_write_rate_valid()  # raises if outside [0.10, 0.85]
```

Running the tests:

```bash
# Run full test suite with 100% branch coverage requirement
PYTHONPATH=python pytest tests/python/

# Run a specific test class
PYTHONPATH=python pytest tests/python/test_memory.py::TestMemoryModule -v
```

241 tests, 100% branch coverage (enforced by pytest --cov configuration).
```
drex/
├── python/drex/
│   ├── models/
│   │   ├── memory.py         # MemoryModule (L4), MemoryState (L2), TitanMemory (L3)
│   │   ├── attention.py      # SlidingWindowAttention, InfiniAttention, HybridAttention
│   │   └── transformer.py    # DrexConfig, DrexLayer, DrexTransformer
│   ├── training/
│   │   ├── data.py           # SegmentDataset, collate_fn, tokenize_chars
│   │   ├── optimizer.py      # build_optimizer, cosine_schedule_with_warmup
│   │   └── trainer.py        # DrexTrainer (TBPTT, grad clip, segment loop)
│   ├── eval/
│   │   ├── passkey.py        # PasskeyBenchmark (multi-length passkey recall)
│   │   └── babilong.py       # BABILongBenchmark (5-task Q&A)
│   └── utils/
│       └── config.py         # save_checkpoint, load_checkpoint
├── src/                      # Rust source (SnapshotStore, PrefetchEngine)
├── scripts/
│   ├── train.py              # TinyStories training (write-rate monitoring, NaN guard)
│   ├── eval_passkey.py       # Passkey recall CLI (+ density sweep)
│   └── eval_babilong.py      # BABILong 5-task evaluation CLI
├── tests/python/             # 241 tests, 100% branch coverage
├── research/experiments/     # 247+ research experiments (cat1–cat48)
├── results/                  # Training run results and comparisons
├── PLAN.md                   # Implementation roadmap (Phases 1–16; Phases 17–21 planned)
├── ARCHITECTURE_FINDINGS.md  # Full spec + dead ends + confidence classifications
└── CLAUDE.md                 # Project conventions for AI collaboration
```
16 phases of hypothesis-driven experimentation established the architecture:
- Phases 1–4: Established delta-rule update, ELU+1 feature map, L2/L3 baseline
- Phases 5–6: Ruled out offline consolidation, hierarchical routing
- Phases 7–8: Confirmed outer-product write, eliminated bidirectional rule
- Phases 9–10: Confirmed relative-norm gate at thresh=0.40; ruled out regularisation and two-phase training
- Phase 11 (exp_47): Discovered and resolved EMA bootstrap failure at L≤32; `α(L) = 0.95^(96/L)` keeps τ/L ≈ 0.21 constant across all sequence lengths
- Phase 12 (exp_48): Confirmed thresh*=0.70 for OR-gate full system; wr(L=32)=0.581, wr(L=96)=0.421, both within validated target ranges
- Phase 13: Production `MemoryModule` implementation; 197-test suite, 100% coverage
- Phase 14: Training script integration; write-rate monitoring, validation loss, BABILong CLI, passkey density sweep
- Phase 15: Stability hardening: NaN guard, TBPTT boundary reset, vectorized write loop, `F.normalize` eps fix; 199 tests total, 100% branch coverage
- Phase 16: Pre-publication hardening: write loop CPU backend + detach (543→2,310 tok/s at L=512), output `norm_out` LayerNorm, multi-seed ablations (full-seq-residual INCONCLUSIVE; last-layer-only same quality at 2.7× speedup; null-gate +0.30 ppl), checkpoint resume optimizer/scheduler fix, arXiv paper draft (paper/main.tex); 241 tests, 100% branch coverage
All architectural decisions have evidence trails in research/experiments/. Dead ends are documented in ARCHITECTURE_FINDINGS.md §9.
MIT