Diffusion Language Models

What are Language Diffusion Models?

Instead of predicting the next token, a diffusion LLM must "de-mask" a corrupted sequence of text.

LLaDA: Large Language Diffusion with mAsking

The paper introduces masked diffusion models.

These generate text by iteratively denoising a corrupted sequence.

[MASK] [MASK]  [MASK] [MASK] [MASK] [MASK] [EOS]
[MASK] [MASK]  [MASK] France [MASK] [MASK] [EOS]
The    [MASK]  [MASK] France [MASK] [MASK] [EOS]
The    capital [MASK] France [MASK] [MASK] [EOS]
The    capital [MASK] France is     [MASK] [EOS]
The    capital [MASK] France is     Paris  [EOS]
The    capital of     France is     Paris  [EOS]

At inference time, this lets the model produce multiple tokens per step.
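
As a rough sketch of what a multi-token step could look like with a Hugging Face masked-LM (illustrative only, not the repository's decoding code; top_k controls how many masked positions are committed per forward pass):

import torch

@torch.no_grad()
def demask_step(model, tokenizer, ids: list[int], top_k: int = 4) -> list[int]:
    """Fill in the `top_k` masked positions the model is most confident about."""
    device = next(model.parameters()).device
    logits = model(input_ids=torch.tensor([ids], device=device)).logits[0]
    probs = logits.softmax(dim=-1)
    mask_pos = (torch.tensor(ids, device=device) == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    if len(mask_pos) == 0:
        return ids  # nothing left to de-mask
    conf, pred = probs[mask_pos].max(dim=-1)               # top prediction + confidence per masked slot
    for i in conf.topk(min(top_k, len(mask_pos))).indices:  # commit only the most confident guesses
        ids[mask_pos[i].item()] = pred[i].item()
    return ids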

Training

The core idea is very similar to BERT's masked-language-modelling objective.

My minimal training setup is at train_dllm.py. It takes a pretrained BERT model and finetunes it with a variable amount of masking.

Run with ./train_dllm.py (8 H100s).

We corrupt part of the input sequence with [MASK] tokens, and train the model to predict these tokens using cross-entropy loss (CEL).

However, while BERT is trained with a fixed 15% mask rate, LLaDA uses variable 15-99% masking.

A disadvantage of BERT-like models: they only get training signal from the masked tokens.

The [MASK] of France [MASK] Paris [EOS]
       ^                ^
  Only the masked tokens contribute to the loss
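
A minimal sketch of one training step under this scheme (illustrative; the actual setup is in train_dllm.py, and the 15-99% range is taken from the text above):

import torch
import torch.nn.functional as F

def masked_diffusion_loss(model, input_ids, mask_token_id):
    """Mask a random fraction of each sequence and train on the masked positions only."""
    # Per-sequence mask rate, anywhere from 15% to 99% (vs. BERT's fixed 15%).
    mask_rate = torch.empty(input_ids.size(0), 1, device=input_ids.device).uniform_(0.15, 0.99)
    is_masked = torch.rand(input_ids.shape, device=input_ids.device) < mask_rate
    corrupted = input_ids.masked_fill(is_masked, mask_token_id)
    # Only masked positions carry training signal, as with BERT.
    labels = input_ids.masked_fill(~is_masked, -100)
    logits = model(input_ids=corrupted).logits
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)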

A Minimal Inference Setup

import sys
import time
import torch

@torch.no_grad()
def iter_mask_decode(model, tokenizer, prompt: str, answer_length: int = 32):
    # Create initial sequence with masked tokens
    assistant_message = tokenizer.mask_token * answer_length
    toks_dict = tokenizer.apply_chat_template(
        [{"content": prompt}, {"content": assistant_message}],
        return_dict=True, return_assistant_tokens_mask=True, return_tensors="pt")
    
    ids = toks_dict['input_ids'][0].tolist()
    assistant_mask = toks_dict['assistant_masks']
    answer_start = assistant_mask[0].nonzero().min().item()
    
    device = next(model.parameters()).device
    
    # Each step: fill in the single masked position where the model is most confident.
    for step in range(answer_length):
        logits = model(input_ids=torch.tensor([ids]).to(device)).logits
        probs = torch.softmax(logits[0], dim=-1)

        mask_positions = (torch.tensor(ids) == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
        if len(mask_positions) == 0:
            break

        mask_probs = probs[mask_positions]
        confidence_scores = mask_probs.max(dim=-1)[0]
        best_idx = confidence_scores.argmax()

        pos = mask_positions[best_idx]
        new_token = mask_probs[best_idx].argmax().item()
        ids[pos] = new_token
        
        yield new_token, pos.item() - answer_start

def demo_inference(model, tokenizer, prompt: str, answer_length: int = 25, delay=0.1):
    def _print_step(resp, n_clear):
        """1) move to start, 2) blank the full width, 3) move back, 4) write new text"""
        resp = resp.encode('unicode_escape').decode('ascii')
        blank = " " * n_clear
        sys.stdout.write("\r" + blank + "\r" + resp)
        sys.stdout.flush()
        return len(resp)
    
    print(f"User: {prompt}")
    print("Assistant: ", end="")
    
    tokens = ["[MASK]"] * answer_length
    n_clear = 0
    
    for new_token, pos in iter_mask_decode(model, tokenizer, prompt, answer_length):
        tokens[pos] = tokenizer.decode(new_token)
        resp = "".join(tokens)
        n_clear = _print_step(resp, n_clear)
        time.sleep(delay)
    
    print()
import os, random, itertools, math, torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

device = (
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

model_id = "tommyp111/modernbert-dllm-tulu"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device).eval()

prompt = "What is the meaning of life?"

Inference

"Greedy strategy"

Iteratively "de-mask" the masked position where the model is most confident, committing its top prediction.

demo_inference(model, tokenizer, prompt, answer_length=25, delay=0.25)
User: What is the meaning of life?
\nThe meaning of life is to live, to love, to laugh, to cry, to dream, to hope,[SEP]                                                              

Block Generation with KV Cache

Key limitation: KV caching is not possible, since attention is all-to-all, so every position's keys and values can change whenever a token is de-masked.

This is mitigated by generating "blocks" of tokens at a time, autoregressively; a finished block is never revisited, so it can be cached.

E.g.: to generate 128 tokens, we autoregressively generate 4 blocks of 32.
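
A rough sketch of the control flow (demask_block is a hypothetical helper that fully de-masks one block given the committed prefix; the real KV-cache handling would live inside the model's attention layers):

def generate_blockwise(demask_block, total_len: int = 128, block_size: int = 32) -> list[int]:
    """Generate `total_len` tokens as autoregressive blocks of `block_size` tokens.
    Finished blocks never change, so their key/value states could be cached."""
    prefix: list[int] = []
    for _ in range(total_len // block_size):
        prefix += demask_block(prefix, block_size)  # de-mask one block conditioned on the prefix
    return prefix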

Remasking

A key limitation of masked diffusion LMs is that they have no self-correction: once a token has been de-masked, it is never revisited.

The authors suggest confidence-based remasking.
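
A minimal sketch of one way to do this (illustrative, not the paper's exact procedure): after de-masking, re-mask the committed positions the model is now least confident about. In practice this would be restricted to the generated span rather than the prompt.

import torch

@torch.no_grad()
def remask_low_confidence(model, tokenizer, ids: list[int], n_remask: int = 4) -> list[int]:
    """Re-mask the `n_remask` committed tokens the model is least confident about."""
    device = next(model.parameters()).device
    ids_t = torch.tensor([ids], device=device)
    probs = model(input_ids=ids_t).logits[0].softmax(dim=-1)
    # Confidence the model assigns to each token currently in the sequence.
    token_conf = probs[torch.arange(len(ids), device=device), ids_t[0]]
    token_conf[ids_t[0] == tokenizer.mask_token_id] = 1.0  # never pick already-masked positions
    for pos in token_conf.topk(n_remask, largest=False).indices.tolist():
        ids[pos] = tokenizer.mask_token_id
    return ids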

Diffusion Duality (Duo)

Primary issue with masked diffusion models: no ability to self-correct.

At inference, the model starts from random tokens that are iteratively refined.

  • This allows it to self-correct its previous predictions, as sketched below!
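
A toy sketch of that refinement loop (not Duo's actual sampler; model is assumed to be any masked-LM-style network returning per-position logits):

import torch

@torch.no_grad()
def refine_from_noise(model, vocab_size: int, length: int = 32, steps: int = 16):
    """Start from random token ids and repeatedly replace every position with the
    model's current argmax prediction; earlier guesses can be overwritten,
    which is what allows self-correction."""
    device = next(model.parameters()).device
    ids = torch.randint(0, vocab_size, (1, length), device=device)
    for _ in range(steps):
        ids = model(input_ids=ids).logits.argmax(dim=-1)
    return ids[0]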

See training script: train_flowlm.py

Duo's Approach:

  1. Corrupt one-hot tokens with a variable amount of Gaussian noise.
  2. Convert back to one-hot (with argmax); some tokens will have flipped. This is our corrupted sequence.
  3. Train model to restore original sequence.
import numpy as np
import torch
input_ids = torch.arange(4)
temperature = 1
np.set_printoptions(precision=2)

def one_hot(x):
    out = torch.nn.functional.one_hot(x)
    print(f"one_hot:\n{out.numpy()}\n")
    return out

def softmax(x):
    out = torch.nn.functional.softmax(x.to(torch.float32), dim=-1)
    print(f"sotmax:\n{out.numpy()}\n")
    return out

def randn_like(x):
    out = torch.randn_like(x.to(torch.float32))
    print(f"random noise:\n{out.numpy()}\n")
    return out
noise_weight = random.uniform(0.25, 0.5) # variable amount of noise
clean_weight = 1 - noise_weight

hot = one_hot(input_ids)
w = clean_weight * hot + noise_weight * randn_like(hot)
soft_latents = softmax(w / temperature)

print(f"noise: {noise_weight:.3f} | input: {input_ids.tolist()} | corrupted: {soft_latents.argmax(-1).tolist()}")
one_hot:
[[1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]
 [0 0 0 1]]

random noise:
[[-0.18 -0.95  0.94  0.8 ]
 [ 0.03  0.75 -0.32 -0.17]
 [-1.3   0.49  1.63 -1.23]
 [ 0.52  0.05  0.06  0.13]]

softmax:
[[0.34 0.14 0.27 0.26]
 [0.19 0.47 0.17 0.18]
 [0.11 0.2  0.58 0.11]
 [0.23 0.19 0.2  0.38]]

noise: 0.364 | input: [0, 1, 2, 3] | corrupted: [0, 1, 2, 3]

"Soft" Tokens

To stabilize early training, we use a high-temperature softmax to give a spread over the noisy tokens, which provides a richer signal.

The temperature is annealed toward 0 over the course of 500K steps, at which point the softmax becomes equivalent to argmax.

Early in training (temperature = 1e-3)
softmax(noisy_one_hot_tokens / temperature) => [0.001, 0.92, 0.005, 0.074]

Late training (temperature = 1e-8)
softmax(noisy_one_hot_tokens / temperature) ~= argmax => [0, 1, 0, 0]
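
A sketch of such a schedule (the 500K-step horizon and the early/late temperatures are taken from the text above; the log-linear decay shape is an assumption):

import math

def temperature_at(step: int, total_steps: int = 500_000,
                   t_start: float = 1e-3, t_end: float = 1e-8) -> float:
    """Log-linear anneal from t_start down to an effectively-zero t_end over total_steps."""
    frac = min(step / total_steps, 1.0)
    return math.exp((1 - frac) * math.log(t_start) + frac * math.log(t_end))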

Loss

By default, hard one-hot targets and CEL cause very high-variance gradients, and training doesn't work very well.

E.g.: for two rows with 15% and 99% noise, the 99% row is much harder and will dominate the gradient.

  • Use a Symmetric KL divergence between the target labels and logits.
    • Similar to CEL, but also penalises over-predicting wrong labels.
  • Use a scaling factor based on the amount of noise per row, so all rows contribute roughly equal loss.
p_prob = softmax(logits)    # predicted distribution
q_prob = one_hot(input_ids) # ground-truth distribution
sym_kl = symmetric_kl_divergence(p_prob, q_prob)

# Scale each row's loss by a factor derived from its noise level,
# so that no single noise level dominates the gradient.
view_scale = noise_weight / (vocab_size * clean_weight + noise_weight)
loss = (view_scale * sym_kl).sum()
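
symmetric_kl_divergence is not defined in the snippet above; a straightforward version (assuming both arguments are already probability distributions over the vocabulary) might look like:

import torch

def symmetric_kl_divergence(p_prob, q_prob, eps: float = 1e-8):
    """KL(p || q) + KL(q || p), summed over the vocabulary dimension, per row."""
    p = p_prob.float().clamp_min(eps)
    q = q_prob.float().clamp_min(eps)
    kl_pq = (p * (p / q).log()).sum(dim=-1)
    kl_qp = (q * (q / p).log()).sum(dim=-1)
    return kl_pq + kl_qp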
