
Conversation

@mntss (Contributor) commented Jun 6, 2025

Description

Adds support for QK (Query-Key) normalization in attention layers and implements weight conversion for Qwen3, enabling TransformerLens to load Qwen3 models, which use QK normalization.
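
For context, here is a minimal sketch of what QK normalization does inside an attention layer. The shapes and names are illustrative, not the exact TransformerLens implementation: an RMSNorm with a learned per-dimension scale (shared across heads, as in Qwen3) is applied to the query and key vectors of each head before the attention scores are computed.

import torch

def rms_norm(x, weight, eps=1e-6):
    # Normalize over the last (head) dimension, then apply a learned scale.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

# Illustrative shapes: [batch, seq, n_heads, d_head]
batch, seq, n_heads, d_head = 1, 4, 8, 64
q = torch.randn(batch, seq, n_heads, d_head)
k = torch.randn(batch, seq, n_heads, d_head)

# Qwen3-style QK norm: one learned scale vector of size d_head,
# applied to q and k before rotary embeddings / attention.
q_norm_w = torch.ones(d_head)
k_norm_w = torch.ones(d_head)

q = rms_norm(q, q_norm_w)
k = rms_norm(k, k_norm_w)

scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / d_head**0.5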

Type of change

  • New feature (non-breaking change which adds functionality)

Testing

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)

MODEL_NAME = "Qwen/Qwen3-4B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype="float32", trust_remote_code=True
).to("cuda")
hf_model.eval()

# Load without weight processing so the residual stream matches HF exactly.
tl_model = HookedTransformer.from_pretrained_no_processing(MODEL_NAME).to("cuda").to(torch.float32)
tl_model.eval()

tokens = tokenizer("Hello world", return_tensors="pt")["input_ids"].to("cuda")

with torch.no_grad():
    hf_out = hf_model(tokens, output_hidden_states=True)
    tl_logits, cache = tl_model.run_with_cache(
        tokens, names_filter=lambda name: "resid_pre" in name or "resid_post" in name
    )

# HF hidden_states[i] is the input to block i (index 0 is the embedding output);
# the final entry is taken after the model's final norm, so compare it against
# ln_final applied to the last block's resid_post.
for i in range(tl_model.cfg.n_layers + 1):
    if i < tl_model.cfg.n_layers:
        diff = (cache[f"blocks.{i}.hook_resid_pre"] - hf_out.hidden_states[i]).abs().max()
    else:
        diff = (
            (tl_model.ln_final(cache[f"blocks.{i-1}.hook_resid_post"]) - hf_out.hidden_states[-1])
            .abs()
            .max()
        )
    print(f"Layer {i}: max hidden state diff {diff.item():.6f}")

logit_diff = (tl_logits - hf_out.logits).abs().max()
print(f"Max logit diff: {logit_diff.item():.6f}")

Output:

➜ python /workspace/TransformerLens/debugging/qwen3_hf_compare.py
Loading checkpoint shards: 100%|█████████████████| 3/3 [00:29<00:00,  9.68s/it]
WARNING:root:Loading model Qwen/Qwen3-4B requires setting trust_remote_code=True
WARNING:root:Loading model Qwen/Qwen3-4B state dict requires setting trust_remote_code=True
Loading checkpoint shards: 100%|█████████████████| 3/3 [00:31<00:00, 10.35s/it]
Loaded pretrained model Qwen/Qwen3-4B into HookedTransformer
Moving model to device:  cuda
Changing model dtype to torch.float32
Layer 0: max hidden state diff 0.000000
Layer 1: max hidden state diff 0.000003
Layer 2: max hidden state diff 0.000027
Layer 3: max hidden state diff 0.000016
Layer 4: max hidden state diff 0.000013
Layer 5: max hidden state diff 0.000014
Layer 6: max hidden state diff 0.000032
Layer 7: max hidden state diff 0.006836
Layer 8: max hidden state diff 0.006836
Layer 9: max hidden state diff 0.006836
Layer 10: max hidden state diff 0.006836
Layer 11: max hidden state diff 0.006836
Layer 12: max hidden state diff 0.006836
Layer 13: max hidden state diff 0.006836
Layer 14: max hidden state diff 0.006836
Layer 15: max hidden state diff 0.006836
Layer 16: max hidden state diff 0.006836
Layer 17: max hidden state diff 0.006836
Layer 18: max hidden state diff 0.006836
Layer 19: max hidden state diff 0.006836
Layer 20: max hidden state diff 0.006836
Layer 21: max hidden state diff 0.006836
Layer 22: max hidden state diff 0.006836
Layer 23: max hidden state diff 0.006836
Layer 24: max hidden state diff 0.006836
Layer 25: max hidden state diff 0.006836
Layer 26: max hidden state diff 0.006836
Layer 27: max hidden state diff 0.006836
Layer 28: max hidden state diff 0.006836
Layer 29: max hidden state diff 0.005859
Layer 30: max hidden state diff 0.005859
Layer 31: max hidden state diff 0.005859
Layer 32: max hidden state diff 0.005859
Layer 33: max hidden state diff 0.005859
Layer 34: max hidden state diff 0.005859
Layer 35: max hidden state diff 0.004883
Layer 36: max hidden state diff 0.000130
Max logit diff: 0.000053
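
The remaining per-layer differences are at float32 numerical-noise level. As an optional pass/fail variant of the same comparison (an illustrative addition, not part of this PR), torch.testing.assert_close can assert that the final logits agree within a tolerance:

torch.testing.assert_close(tl_logits, hf_out.logits, rtol=1e-4, atol=1e-3)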

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@bryce13950 bryce13950 changed the base branch from main to dev June 9, 2025 19:29
@bryce13950 bryce13950 merged commit 4b69393 into TransformerLensOrg:dev Jun 10, 2025
11 checks passed