
Baantu Research: Hybrid KAN-Transformer

Investigating Kolmogorov-Arnold Networks for Language Model Reasoning

This research project explores whether replacing standard MLP layers with KAN (Kolmogorov-Arnold Networks) in transformer architectures improves reasoning capabilities. Built on top of nanochat by Andrej Karpathy.

The Experiment

We hypothesize that KAN's learnable activation functions (B-splines on edges) may provide advantages for mathematical reasoning and function approximation over the fixed ReLU² activations used in standard MLPs.

Key Questions:

  • Does replacing MLP layers with KAN improve performance on reasoning tasks?
  • Which layer positions benefit most from KAN layers?
  • What is the trade-off between KAN's expressiveness and computational cost?

Architecture Comparison

================================================================================
                    STANDARD TRANSFORMER (All MLP)
================================================================================

     ┌─────────────────────────────────────────────┐
     │           TRANSFORMER BLOCK (x20)           │
     │                                             │
     │  ┌─────────────────────────────────────┐    │
     │  │         ATTENTION                   │    │
     │  └─────────────────────────────────────┘    │
     │                  │                          │
     │                  v                          │
     │  ┌─────────────────────────────────────┐    │
     │  │            MLP                      │    │
     │  │   Input ──> [ReLU²] ──> Output      │    │  <-- FIXED activation
     │  │                                     │    │
     │  └─────────────────────────────────────┘    │
     └─────────────────────────────────────────────┘

================================================================================
                 HYBRID KAN-TRANSFORMER (Last 2 layers = KAN)
================================================================================

     ┌─────────────────────────────────────────────┐
     │      TRANSFORMER BLOCKS 0-17 (MLP)          │
     │      Standard MLP with fixed ReLU²          │
     └─────────────────────────────────────────────┘
                         │
                         v
     ┌─────────────────────────────────────────────┐
     │      TRANSFORMER BLOCKS 18-19 (KAN)         │  <-- THE DIFFERENCE
     │                                             │
     │  ┌─────────────────────────────────────┐    │
     │  │           KAN LAYER                 │    │
     │  │   Input ──> [φ(x)] ──> Output       │    │  <-- LEARNABLE activations
     │  │        (B-spline functions)         │    │      (adapts during training)
     │  └─────────────────────────────────────┘    │
     └─────────────────────────────────────────────┘

MLP vs KAN: The Core Difference

MLP (Multi-Layer Perceptron):
==============================
     x1 ───┐
           ├───O──[ReLU²]──O───> output
     x2 ───┤
           ├───────────────────────────
     x3 ───┘

  - Weights on EDGES (learnable)
  - Activation on NODES (fixed ReLU²)
  - Same activation for all inputs
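
  In code, the fixed-activation MLP looks like this minimal nanochat-style sketch
  (illustrative only; the exact implementation in kan_transformer/gpt.py may differ):

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

  class MLP(nn.Module):
      # Standard feed-forward block: learnable weights on edges,
      # a fixed ReLU^2 activation applied at the hidden nodes.
      def __init__(self, n_embd: int):
          super().__init__()
          self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=False)
          self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=False)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          x = self.c_fc(x)
          x = F.relu(x).square()  # same activation for every unit; never adapts
          return self.c_proj(x)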


KAN (Kolmogorov-Arnold Network):
=================================
     x1 ──[φ₁(x)]──┐
                   ├───O───> output
     x2 ──[φ₂(x)]──┤
                   ├───────────────
     x3 ──[φ₃(x)]──┘

  - Activation on EDGES (learnable B-splines!)
  - Sum on NODES (just addition)
  - Each edge learns its own activation function
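
  The KAN counterpart puts the learnable function on each edge and reduces the node
  to a plain sum. Below is a minimal Cox-de Boor B-spline sketch for illustration
  (in the spirit of the acknowledged efficient-kan project; names and details are
  assumptions, not the repository's API):

  import torch
  import torch.nn as nn

  class KANLinear(nn.Module):
      # Each input->output edge applies its own learnable spline
      # phi(x) = sum_k c_k * B_k(x); the output node simply sums over inputs.
      def __init__(self, in_features, out_features, grid_size=5, spline_order=3,
                   grid_range=(-1.0, 1.0)):
          super().__init__()
          self.spline_order = spline_order
          # Uniform knot grid, extended by spline_order knots on each side
          h = (grid_range[1] - grid_range[0]) / grid_size
          grid = torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0]
          self.register_buffer("grid", grid)
          # One coefficient vector per edge: (out, in, grid_size + spline_order)
          self.spline_weight = nn.Parameter(
              0.1 * torch.randn(out_features, in_features, grid_size + spline_order))

      def b_splines(self, x):
          # x: (batch, in) -> B-spline basis values (batch, in, grid_size + spline_order)
          grid, x = self.grid, x.unsqueeze(-1)
          bases = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)  # order-0 basis
          for k in range(1, self.spline_order + 1):                # Cox-de Boor recursion
              left = (x - grid[:-(k + 1)]) / (grid[k:-1] - grid[:-(k + 1)]) * bases[..., :-1]
              right = (grid[k + 1:] - x) / (grid[k + 1:] - grid[1:-k]) * bases[..., 1:]
              bases = left + right
          return bases

      def forward(self, x):
          # Evaluate every edge's spline, then sum contributions over the input dimension
          return torch.einsum("bik,oik->bo", self.b_splines(x), self.spline_weight)

  layer = KANLinear(8, 4)
  out = layer(torch.randn(2, 8).tanh())  # tanh keeps inputs inside the spline grid
  print(out.shape)                       # torch.Size([2, 4])

  In the hybrid model, a feed-forward block built from layers like this replaces the
  MLP in the selected transformer blocks.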

KAN Placement Options

We support multiple KAN layer placement strategies via --kan_layers:

Option    Layers            Hypothesis
last2     Final 2 layers    KAN helps final reasoning (default)
first2    First 2 layers    KAN helps feature extraction
every4    Every 4th layer   Distributed KAN throughout
middle2   Middle 2 layers   KAN helps intermediate processing
all       All layers        Full KAN transformer
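
As a rough illustration, each option maps to block indices of the depth-20 model
roughly as follows (a hypothetical helper; the actual parsing lives in
scripts/base_train.py and may differ, e.g. for middle2):

  def kan_layer_indices(option: str, depth: int = 20) -> list[int]:
      # Translate a --kan_layers option into the transformer blocks that use KAN.
      if option == "last2":
          return [depth - 2, depth - 1]     # e.g. [18, 19]
      if option == "first2":
          return [0, 1]
      if option == "middle2":
          mid = depth // 2
          return [mid - 1, mid]             # assumed placement, e.g. [9, 10]
      if option == "every4":
          return list(range(3, depth, 4))   # e.g. [3, 7, 11, 15, 19]
      if option == "all":
          return list(range(depth))
      raise ValueError(f"unknown kan_layers option: {option}")

  print(kan_layer_indices("every4"))  # [3, 7, 11, 15, 19]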

Configuration

# Standard MLP baseline
python -m scripts.base_train --use_kan=False

# Hybrid KAN-Transformer (last 2 layers)
python -m scripts.base_train --use_kan=True --kan_layers=last2

# Full KAN transformer
python -m scripts.base_train --use_kan=True --kan_layers=all

# With gradient checkpointing (saves memory)
python -m scripts.base_train --use_kan=True --use_grad_checkpoint=True
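
The --use_grad_checkpoint flag trades compute for memory: activations inside
checkpointed blocks are recomputed during the backward pass instead of being
stored, which matters because KAN layers materialize a per-edge spline-basis
tensor. In PyTorch this is typically done with torch.utils.checkpoint
(an illustrative sketch, not the repository's exact code):

  import torch
  import torch.nn as nn
  from torch.utils.checkpoint import checkpoint

  class CheckpointedBlock(nn.Module):
      # Wraps a transformer block so its activations are recomputed in backward
      # rather than cached, lowering peak memory at the cost of extra compute.
      def __init__(self, block: nn.Module):
          super().__init__()
          self.block = block

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          if self.training:
              return checkpoint(self.block, x, use_reentrant=False)
          return self.block(x)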

Quick Start

Environment Setup

# Clone the repo
git clone https://github.com/stchakwdev/nanochat.git
cd nanochat

# Install uv package manager
curl -LsSf https://astral.sh/uv/install.sh | sh

# Create environment and install dependencies
uv venv
source .venv/bin/activate
uv sync --extra gpu

# Build Rust tokenizer
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml

Training

Full Pipeline (Hybrid KAN-Transformer on 8xH100):

# Download training data
python -m kan_transformer.dataset -n 240

# Train tokenizer
python -m scripts.tok_train --max_chars=2000000000

# Base model pretraining with KAN
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=20 --use_kan=True --kan_layers=last2

# Midtraining (conversation format + tools)
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train

# Supervised Fine-tuning
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft

# Evaluate
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft

Inference

# CLI chat
python -m scripts.chat_cli

# Web UI
python -m scripts.chat_web

Ablation Study Plan

Experiment  KAN Layers        Est. Params  Hypothesis
Baseline    None (MLP only)   561M         Control group
Last2       18, 19            ~800M        Best for reasoning
First2      0, 1              ~800M        Best for features
Every4      3, 7, 11, 15, 19  ~1B          Balanced approach
All         0-19              ~1.4B        Maximum expressiveness

Results

Results will be posted once the experiments are complete.

Evaluation Tasks

  • GSM8K: Grade school math reasoning
  • HumanEval: Python code generation
  • ARC-Easy/Challenge: Science reasoning
  • MMLU: Multi-subject knowledge

File Structure

.
├── kan_transformer/
│   ├── gpt.py              # GPT + KAN_MLP implementation
│   ├── engine.py           # Inference with KV cache
│   ├── dataloader.py       # Streaming data loader
│   ├── tokenizer.py        # BPE tokenizer
│   └── ...
├── scripts/
│   ├── base_train.py       # Pretraining (with KAN options)
│   ├── mid_train.py        # Midtraining
│   ├── chat_sft.py         # Supervised finetuning
│   ├── chat_eval.py        # Evaluation
│   └── ...
├── tasks/                   # Evaluation datasets
├── rustbpe/                 # Rust BPE tokenizer
└── tests/

Citation

This Work

@misc{baantu-kan-transformer,
  author = {Samuel Tchakwera and Baantu Research},
  title = {Hybrid KAN-Transformer: Investigating Learnable Activations for LLM Reasoning},
  year = {2025},
  publisher = {GitHub},
  url = {https://github.com/stchakwdev/nanochat}
}

Original Codebase

This project builds upon nanochat by Andrej Karpathy:

@misc{nanochat,
  author = {Andrej Karpathy},
  title = {nanochat: The best ChatGPT that $100 can buy},
  year = {2025},
  publisher = {GitHub},
  url = {https://github.com/karpathy/nanochat}
}

KAN Architecture

@misc{liu2024kan,
  title = {KAN: Kolmogorov-Arnold Networks},
  author = {Ziming Liu and others},
  year = {2024},
  eprint = {2404.19756},
  archivePrefix = {arXiv}
}

Acknowledgements

  • Andrej Karpathy for the excellent nanochat foundation
  • AthanasiosDelis for efficient-kan
  • Ziming Liu et al. for the original KAN paper
  • HuggingFace for datasets and model hosting

License

MIT
