
Drex

Drex is an experimental transformer architecture with a four-tier memory hierarchy, implementing a validated episodic/semantic split associative memory module. The architecture was developed through 16 phases of hypothesis-driven research across 247+ experiments, with each architectural decision grounded in controlled ablation studies (SUPPORTED / INCONCLUSIVE / REFUTED verdicts, ≥2/3 seed confirmation).

The research specifically resolved a previously undocumented failure mode in EMA-based associative memory at short sequence lengths (the EMA bootstrap problem) and confirmed a validated fix, α(L) = 0.95^(96/L), that keeps memory time constants calibrated across L=16–128. The full research log is in research/experiments/ (cat1–cat48, 247+ experiment result files).

Current state (Phase 16 complete / Phase 17 in progress): The architecture is fully implemented, hardened, and training-ready. A 9-page arXiv preprint has been drafted (paper/main.tex). Exp A (baseline) is running (step ~16k/50k); Exp B (episodic memory) will auto-start on Exp A completion. Evaluation results (passkey recall, BABILong) are pending final checkpoints; that is the current milestone. See Current Results.

Architecture

Drex uses a four-tier memory hierarchy:

Layer  Mechanism                            Scope
L1     Sliding-window causal attention      In-context (short range)
L2     Infini-Attention delta-rule matrix   Cross-segment (medium range)
L3     Titans-style MLP weight snapshots    Disk (long range, async)
L4     Episodic/semantic split delta-rule   Per-segment associative recall
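To make the L1 tier concrete, here is a minimal sketch of a sliding-window causal attention mask (illustrative only; the repo's SlidingWindowAttention may construct its mask differently):

```python
import numpy as np

def sliding_window_causal_mask(seq_len: int, window: int) -> np.ndarray:
    # True where attention is allowed: position i may attend to position j
    # only if j <= i (causal) and j > i - window (limited lookback).
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)

mask = sliding_window_causal_mask(6, 3)
# Each row has at most `window` True entries, so per-token cost is O(window).
```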

L4 MemoryModule (Phase 13, validated)

The episodic/semantic memory layer is the primary research contribution. Key properties:

  • Two H/2 associative matrices: M_sem (semantic, uniform weight) and M_epi (episodic, recency-weighted writes)
  • Delta-rule update: Ξ” = (k βˆ’ MkΜ‚) βŠ— kΜ‚, written via EMA with (1βˆ’Ξ±) smoothing
  • Length-adaptive EMA: Ξ±(L) = 0.95^(96/L) β€” keeps Ο„/L β‰ˆ 0.21 constant across L=16–128, solving the EMA bootstrap failure at short sequences (Phase 11)
  • OR relative-norm write gate: fires when β€–k βˆ’ vpβ€– β‰₯ threshΒ·β€–kβ€– on either branch; thresh*=0.70 (confirmed exp_48_1, Phase 12)
  • Null retrieval gate: learned scalar g = Οƒ(wΒ·q) suppresses empty-memory reads
  • Soft concatenated retrieval: concat(r_sem, r_epi) β€” no learned combination gate (exp_38_3 ruled this out)

Validated write rates at thresh=0.70:

  • L=32: wr=0.581 (target: 0.20–0.70) βœ“
  • L=96: wr=0.421 (target: 0.15–0.50) βœ“

See ARCHITECTURE_FINDINGS.md for the full specification, confidence classifications, and the complete list of research dead ends.

Current Results

Status: no published end-to-end benchmark yet.

The architecture components are fully validated (exp_48_1, Phase 12; 241-test suite, 100% branch coverage). The production implementation trains on TinyStories with write rates in-range. The following comparisons are planned:

Experiment                Config                                                      Status
Exp A (baseline)          DrexTransformer (no MemoryModule), 256d, 4L, 512 ctx        Running (step ~16k/50k)
Exp B (episodic memory)   DrexTransformer + MemoryModule (thresh=0.70), same config   Pending (auto-starts after Exp A)

Evaluation targets: passkey recall at 2k/4k/8k/16k context; BABILong tasks 1–5 at 2k/4k/8k context. Results will be published at results/TRAINING_RUNS.md.
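For reference, a passkey-recall probe is typically built by burying a key sentence at a random depth in filler text and asking the model to reproduce it. A minimal sketch, with a hypothetical prompt format (scripts/eval_passkey.py may use a different one):

```python
import random

def make_passkey_prompt(context_chars: int, passkey: str, seed: int = 0) -> str:
    # Hypothetical probe builder: pad with filler to the target length,
    # hide the passkey sentence at a random depth, then query at the end.
    rng = random.Random(seed)
    filler = "The grass is green. The sky is blue. The sun is yellow. "
    needle = f" The pass key is {passkey}. Remember it. "
    pad = (filler * (context_chars // len(filler) + 1))[:context_chars]
    cut = rng.randrange(len(pad))
    return pad[:cut] + needle + pad[cut:] + " What is the pass key? The pass key is"

prompt = make_passkey_prompt(2048, "71432")
```

Recall is then scored by whether the model's continuation contains the passkey.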

Installation

Prerequisites

  • Python β‰₯ 3.11
  • Rust toolchain (for the drex._sys extension β€” SnapshotStore, PrefetchEngine)
  • PyTorch β‰₯ 2.3.0

Build

git clone https://github.com/wesleyscholl/drex.git
cd drex

# Install maturin (Rust-Python build tool)
pip install maturin

# Build and install the Rust extension + Python package
maturin develop --release

Development install

pip install -e ".[dev]"

Network note: If your environment has a corporate SSL proxy, training data download may fail with a certificate error. Use --no-ssl-verify (development only) or --data-file /path/to/tinystories.txt to load from a local file.

Usage

Training

Quick smoke run

python scripts/train.py \
  --steps 1000 --log-every 100 \
  --d-model 128 --n-layers 3 --n-heads 4 \
  --use-episodic-memory --no-ssl-verify

Baseline comparison (Experiment A β€” no episodic memory)

python scripts/train.py \
  --steps 50000 --log-every 200 \
  --d-model 256 --n-layers 4 --n-heads 4 --ff-mult 4 \
  --batch-size 8 --segment-len 512 --dropout 0.1 \
  --lr 3e-4 --warmup-steps 2000 \
  --val-every 1000 --val-max-chars 500000 \
  --reset-on-boundary \
  --ckpt-dir checkpoints/exp_a \
  --no-ssl-verify

Episodic memory run (Experiment B β€” thresh*=0.70)

python scripts/train.py \
  --steps 50000 --log-every 200 \
  --d-model 256 --n-layers 4 --n-heads 4 --ff-mult 4 \
  --batch-size 8 --segment-len 512 --dropout 0.1 \
  --lr 3e-4 --warmup-steps 2000 \
  --use-episodic-memory --episodic-gate-thresh 0.70 \
  --val-every 1000 --val-max-chars 500000 \
  --reset-on-boundary \
  --ckpt-dir checkpoints/exp_b \
  --no-ssl-verify

Passkey recall evaluation

python scripts/eval_passkey.py \
    --checkpoint checkpoints/exp_b/step_0050000_final.safetensors \
    --use-episodic-memory --episodic-gate-thresh 0.70 \
    --report-write-rate \
    --lengths 2048 4096 8192 16384

BABILong evaluation

python scripts/eval_babilong.py \
    --checkpoint checkpoints/exp_b/step_0050000_final.safetensors \
    --use-episodic-memory --episodic-gate-thresh 0.70 \
    --lengths 2048 4096 8192 \
    --tasks 1 2 3 4 5 --trials 10

Write-rate monitoring

from drex.models.memory import MemoryModule

# model: an initialised DrexTransformer instance
for module in model.modules():
    if isinstance(module, MemoryModule):
        wr = module.last_write_rate()
        print(f"write rate: {wr:.3f}")
        module.assert_write_rate_valid()  # raises if outside [0.10, 0.85]

Testing

# Run full test suite with 100% branch coverage requirement
PYTHONPATH=python pytest tests/python/

# Run a specific test class
PYTHONPATH=python pytest tests/python/test_memory.py::TestMemoryModule -v

241 tests, 100% branch coverage (enforced by pytest --cov configuration).

Project Structure

drex/
├── python/drex/
│   ├── models/
│   │   ├── memory.py          # MemoryModule (L4), MemoryState (L2), TitanMemory (L3)
│   │   ├── attention.py       # SlidingWindowAttention, InfiniAttention, HybridAttention
│   │   └── transformer.py     # DrexConfig, DrexLayer, DrexTransformer
│   ├── training/
│   │   ├── data.py            # SegmentDataset, collate_fn, tokenize_chars
│   │   ├── optimizer.py       # build_optimizer, cosine_schedule_with_warmup
│   │   └── trainer.py         # DrexTrainer (TBPTT, grad clip, segment loop)
│   ├── eval/
│   │   ├── passkey.py         # PasskeyBenchmark (multi-length passkey recall)
│   │   └── babilong.py        # BABILongBenchmark (5-task Q&A)
│   └── utils/
│       └── config.py          # save_checkpoint, load_checkpoint
├── src/                       # Rust source (SnapshotStore, PrefetchEngine)
├── scripts/
│   ├── train.py               # TinyStories training (write-rate monitoring, NaN guard)
│   ├── eval_passkey.py        # Passkey recall CLI (+ density sweep)
│   └── eval_babilong.py       # BABILong 5-task evaluation CLI
├── tests/python/              # 241 tests, 100% branch coverage
├── research/experiments/      # 247+ research experiments (cat1–cat48)
├── results/                   # Training run results and comparisons
├── PLAN.md                    # Implementation roadmap (Phases 1–16; Phases 17–21 planned)
├── ARCHITECTURE_FINDINGS.md   # Full spec + dead ends + confidence classifications
└── CLAUDE.md                  # Project conventions for AI collaboration

Research Summary

16 phases of hypothesis-driven experimentation established the architecture:

  • Phases 1–4: Established delta-rule update, ELU+1 feature map, L2/L3 baseline
  • Phases 5–6: Ruled out offline consolidation, hierarchical routing
  • Phases 7–8: Confirmed outer-product write, eliminated bidirectional rule
  • Phases 9–10: Confirmed relative-norm gate at thresh=0.40; ruled out regularisation and two-phase training
  • Phase 11 (exp_47): Discovered and resolved EMA bootstrap failure at L≀32 β€” Ξ±(L)=0.95^(96/L) keeps Ο„/L β‰ˆ 0.21 constant across all sequence lengths
  • Phase 12 (exp_48): Confirmed thresh*=0.70 for OR-gate full system; wr(L=32)=0.581, wr(L=96)=0.421 β€” both within validated target ranges
  • Phase 13: Production MemoryModule implementation; 197-test suite, 100% coverage
  • Phase 14: Training script integration; write-rate monitoring, validation loss, BABILong CLI, passkey density sweep
  • Phase 15: Stability hardening β€” NaN guard, TBPTT boundary reset, vectorized write loop, F.normalize eps fix; 199 tests total, 100% branch coverage
  • Phase 16: Pre-publication hardening β€” write loop CPU backend + detach (543β†’2,310 tok/s at L=512), output norm_out LayerNorm, multi-seed ablations (full-seq-residual INCONCLUSIVE; last-layer-only same quality at 2.7Γ— speedup; null-gate +0.30 ppl), checkpoint resume optimizer/scheduler fix, arXiv paper draft (paper/main.tex); 241 tests, 100% branch coverage

All architectural decisions have evidence trails in research/experiments/. Dead ends are documented in ARCHITECTURE_FINDINGS.md §9.

License

MIT
