@severian42

Summary

This PR contributes a faithful MLX port of the BDH architecture for Apple Silicon. It mirrors the original PyTorch implementation’s math and behavior (byte-level vocab, shared weights, sparse ReLU activations, RoPE attention with Q=K, non-affine LayerNorm), packaged with training and testing scripts. The goal is to make BDH easy to train on Mac M-series devices with unified memory and Metal acceleration.

Motivation

  • Enable native, efficient BDH training on Apple Silicon using MLX.
  • Keep the reference architecture aligned across frameworks with clear equivalence notes.
  • Lower setup friction for Mac developers and researchers.

What’s included

  • mlx/bdh_mlx.py: Core MLX model (BDH) matching PyTorch semantics
  • mlx/train_mlx.py: Training loop for a byte-level language model on Hugging Face datasets (including IKM-style usage)
  • mlx/test_mlx.py: Sanity tests (config, instantiate, forward pass)
  • mlx/README-MLX.md: Detailed conversion notes, usage guide, and equivalence rationale
  • mlx/requirements-mlx.txt: Minimal dependencies for the MLX workflow (optional if a single root requirements file is preferred)

Design and equivalence notes

  • Parameter sharing: same encoder, decoder, and encoder_v reused across layers

  • Sparse activations: ReLU on latent projections to enforce non-negativity

  • Attention: RoPE applied with Q=K, causal mask with diagonal=-1 behavior

  • LayerNorm: non-affine, consistent with original

  • Output head: same shape, same initialization scale (normal std=0.02)

  • Differences are API-level only:

    • Explicit transpose permutations (e.g., transpose(0,2,1,3) instead of shorthand)
    • Reshape followed by expand_dims (equivalent to PyTorch’s reshape with an added singleton dimension)
    • Explicit mask multiply instead of in-place tril
  • Shape parity and operation order were verified at each step; generation (top-k) is also equivalent (see the sketch below)
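
A minimal sketch of the API-level differences above, assuming mlx.core / mlx.nn; the helper names (sparse_latent, masked_scores, add_singleton) and the dimensions are illustrative only, not the identifiers used in mlx/bdh_mlx.py:

```python
import mlx.core as mx
import mlx.nn as nn

def sparse_latent(x, encoder):
    # Shared encoder projection + ReLU enforces non-negative, sparse latent activations.
    return nn.relu(x @ encoder)

def masked_scores(x, rope, n_heads, head_dim):
    B, T, _ = x.shape
    # Reshape, then an explicit transpose permutation (vs. PyTorch shorthand).
    q = x.reshape(B, T, n_heads, head_dim).transpose(0, 2, 1, 3)  # (B, H, T, D)
    q = rope(q)   # rotary position embedding
    k = q         # Q = K, as in the reference implementation
    scores = q @ k.transpose(0, 1, 3, 2)                          # (B, H, T, T)
    # Explicit mask multiply instead of an in-place tril_; k=-1 reproduces
    # the strictly causal (diagonal=-1) behavior.
    mask = mx.tril(mx.ones((T, T)), k=-1)
    return scores * mask

def add_singleton(x, B, T, n):
    # Reshape followed by expand_dims, equivalent to PyTorch's reshape
    # with an added singleton dimension.
    return mx.expand_dims(x.reshape(B, T, n), axis=1)             # (B, 1, T, n)

norm = nn.LayerNorm(dims=256, affine=False)  # non-affine LayerNorm, matching the original
rope = nn.RoPE(dims=32)                      # RoPE module applied to the query/key heads
```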

Performance (Apple Silicon, indicative)

  • Runs efficiently on M1/M2/M3/M4 with unified memory and Metal acceleration
  • Tokens/sec and memory usage are comparable to or better than MPS-based PyTorch on a Mac
  • Intended for single-device training; no multi-GPU support in MLX (by design)

How to use

  • pip install -r mlx/requirements-mlx.txt
  • python mlx/train_mlx.py
  • See mlx/README-MLX.md for full instructions, dataset options, and troubleshooting; a minimal smoke test is sketched below
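
As a quick smoke test, mirroring what mlx/test_mlx.py covers (config, instantiate, forward pass), something along these lines should work; the import path and class names (bdh_mlx, BDH, BDHConfig) are assumptions here, so check mlx/bdh_mlx.py for the actual API:

```python
import mlx.core as mx
from bdh_mlx import BDH, BDHConfig  # hypothetical names; see mlx/bdh_mlx.py

config = BDHConfig()        # byte-level vocab (256 tokens) by default
model = BDH(config)

tokens = mx.array([[72, 101, 108, 108, 111]])   # "Hello" as raw bytes, shape (1, 5)
logits = model(tokens)                          # expected shape: (1, 5, vocab_size)
print(logits.shape)
```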

Status and availability

  • I’m currently training the MLX build on an Internal Knowledge Map dataset; weights should be ready in about a day, and I’ll publish checkpoints to Hugging Face once training completes.

  • A standalone repo is also available here for reference: https://github.com/severian42/BDH-MLX

Impact

  • Does not modify PyTorch codepaths
  • Lives under mlx/ to keep concerns separated
  • Adds tests and docs targeted at MLX users

Questions for maintainers

  • Would you like this to live under a subdirectory (mlx/) as proposed?
  • Preference for separate requirements file vs. consolidated root requirements?
  • Any additional tests or CI steps you’d want before merging?

Checklist

  • Model parity with PyTorch (architectural and mathematical)
  • Training and generation scripts included
  • Minimal tests included (instantiate + forward)
  • Documentation added (mlx/README-MLX.md)
  • No changes to existing PyTorch paths
