Masked Diffusion Language Model

Full Pipeline Walkthrough

Sentence: "I love deep learning" PyTorch BPE Tokenizer Transformer Encoder t = 0.5
Step 1

Tokenization — raw text → token IDs

Input sentence

"I love deep learning"
A BPE tokenizer splits words into subwords based on corpus frequency statistics. Common words stay intact; less common words break into pieces such as learn + ##ing.

Step 1a — raw BPE subword tokens

I (word)   love (word)   deep (word)   learn (subword)   ##ing (suffix)
"learning" → learn + ##ing because the tokenizer learned this split from corpus frequency. The ## prefix signals a continuation subword (no space before it).

Step 1b — add special tokens [BOS] and [EOS]

[BOS] (start)   I   love   deep   learn   ##ing   [EOS] (end)

Step 1c — pad to max sequence length 8 with [PAD]

token:  [BOS]   I     love   deep   learn  ##ing  [EOS]  [PAD]
id:     1       287   1842   3383   3483   1775   2      0

Step 1d — final token ID tensor

input_ids = [1, 287, 1842, 3383, 3483, 1775, 2, 0] # [BOS, I, love, deep, learn, ##ing, EOS, PAD]
Shape: [8] — a 1-D integer tensor of length 8
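
A minimal sketch of building this tensor in PyTorch (the IDs are the illustrative values from the walkthrough, not from a real vocabulary):

import torch

# Illustrative IDs from the walkthrough: [BOS, I, love, deep, learn, ##ing, EOS, PAD]
input_ids = torch.tensor([1, 287, 1842, 3383, 3483, 1775, 2, 0])
print(input_ids.shape)   # torch.Size([8]): a 1-D integer tensor of length 8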
Step 2

Batching — wrapping into a batch dimension

PyTorch models always expect a batch dimension as the first axis, even for batch_size=1. This lets the same code work for 1 or 256 examples.

Unsqueeze: add batch dim

input_ids = input_ids.unsqueeze(0)   # shape: [8] → [1, 8]
#                                               B  seq_len

Tensor after unsqueeze

tensor([[ 1, 287, 1842, 3383, 3483, 1775, 2, 0]])
[1, 8]

Also create padding mask

pad_mask = (input_ids != 0)   # True = real token, False = padding
# shape: [1, 8]
# [[True, True, True, True, True, True, True, False]]

If batch_size = 4 (conceptual)

batch shape: [4, 8]   # 4 sentences, each padded to length 8
# Row 0: our sentence "I love deep learning"
# Row 1: another sentence (different tokens)
# Row 2: ...
# Row 3: ...
All operations below work on shapes [B, L] and [B, L, D]. We continue with B=1 for clarity.
Step 3

Forward masking — diffusion noise process

In masked diffusion LMs, we sample a noise level t ∈ (0,1]. Each non-special token is independently replaced with [MASK] with probability t. The model must denoise back to the original.

Sample t = 0.5 (medium noise)

t = 0.5   # noise level sampled from a uniform or cosine schedule
# For each real content token (positions 1–5):
#   mask it with probability t = 0.5
# Special tokens [BOS], [EOS], [PAD] are NEVER masked.

Mask sampling (one possible outcome)

uniform_samples = [  — , 0.31, 0.78, 0.44, 0.62, 0.19,  — ,  —  ]
mask_applied    = (uniform_samples < t)   # t = 0.5
                = [  — , True, False, True, False, True,  — ,  —  ]

Original vs corrupted sequence

Original x₀

[BOS]  I  love  deep  learn  ##ing  [EOS]  [PAD]

Corrupted xₜ (model input)

[BOS]  [MASK]  love  [MASK]  learn  [MASK]  [EOS]  [PAD]

Mask boolean matrix (which positions need loss)

mask_bool = [False, True, False, True, False, True, False, False]
#            BOS    I     love   deep  learn  ##ing  EOS    PAD
# masked positions (True) → loss is computed only here
The model sees xₜ (with [MASK] tokens). It must predict the original token at every masked position. This is exactly like BERT's MLM — but now the fraction masked is controlled by the diffusion timestep t.
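
A minimal sketch of this forward-masking step, assuming the tensors defined above (the function name, argument names, and the example IDs in the comment are illustrative):

import torch

def forward_mask(x0, mask_token_id, special_ids, t):
    # Corrupt x0: replace each non-special token with [MASK] with probability t.
    u = torch.rand(x0.shape)                                 # one uniform sample per position
    maskable = ~torch.isin(x0, special_ids)                  # never mask [BOS]/[EOS]/[PAD]
    mask_bool = (u < t) & maskable                           # which positions get corrupted
    xt = torch.where(mask_bool, torch.full_like(x0, mask_token_id), x0)
    return xt, mask_bool

# e.g. xt, mask_bool = forward_mask(input_ids, mask_token_id=4,
#                                   special_ids=torch.tensor([0, 1, 2]), t=0.5)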
Step 4

Embedding layer — IDs → dense vectors

An embedding table has shape [vocab_size, d_model]. Each token ID is a row index — we simply look up the row. This is a learned lookup; during training these vectors are updated by backprop.
input_ids [1, 8] (int) → Embedding table [vocab=32000, D=256] → token_emb [1, 8, 256] (float32)

What [1, 8, 256] means

token_emb.shape = [1, 8, 256]
#                  B  L  D
#                  batch, seq_len, embedding dim

Example vectors (D=8 shown for readability, real D=256)

Token embedding vectors — first 8 dims shown
[Table: one row per token with its values in dims 0–7 and its shape; the concrete values are learned during training.]
Similar tokens (like learn and ##ing) will have similar embedding vectors after training — they cluster together in the 256-dimensional space. [MASK] has its own dedicated embedding vector.

Shape transformation at this step

[1, 8] → (embedding lookup) → [1, 8, 256]; each of the 8 tokens is now a 256-d float vector
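
A minimal sketch of this lookup in PyTorch, using the sizes above (the layer is freshly initialised here, so its values are illustrative):

import torch.nn as nn

embedding = nn.Embedding(num_embeddings=32000, embedding_dim=256)   # learned [32000, 256] table
token_emb = embedding(input_ids)                                     # [1, 8] int64 → [1, 8, 256] float32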
Step 5

Positional encoding — injecting order

Transformers are permutation-invariant by default — they don't know word order. Positional encodings solve this by adding a unique signal to each position's embedding vector.

Sinusoidal formula (Vaswani et al.)

PE(pos, 2i)   = sin( pos / 10000^(2i/D) )
PE(pos, 2i+1) = cos( pos / 10000^(2i/D) )

# pos = token position (0, 1, 2, ...)
# i   = dimension index (0, 1, ..., D/2 - 1)
# D   = embedding dimension = 256

Intuition

Each dim = one frequency

Low-index dims oscillate slowly — they encode global position. High-index dims oscillate fast — they encode fine-grained local position. Together they create a unique fingerprint per position.

Why sin/cos?

For any offset k: PE(pos+k) can be expressed as a linear function of PE(pos). The model can learn to attend by relative position just by learning a linear transform of the PE vectors.

Computed PE values for 8 positions (dims 0–3 shown)
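
A minimal sketch that computes these values directly from the formula above (the printed slice is positions 0–7, dims 0–3):

import torch

def sinusoidal_pe(max_len=8, d_model=256):
    # Sinusoidal positional encodings (Vaswani et al.), shape [max_len, d_model].
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)   # [max_len, 1]
    dim = torch.arange(0, d_model, 2, dtype=torch.float)          # even dims 0, 2, 4, ... = 2i
    div = torch.pow(10000.0, dim / d_model)                       # 10000^(2i/D)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos / div)                            # even dims get sin
    pe[:, 1::2] = torch.cos(pos / div)                            # odd dims get cos
    return pe

print(sinusoidal_pe()[:, :4])   # 8 positions, dims 0–3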

Addition step

x = token_emb + positional_encoding
# Both have shape [1, 8, 256]; element-wise addition → shape stays [1, 8, 256]

x[0, pos=2, :] = emb("deep") + PE(2, :)
#   [256-d]       [256-d]      [256-d]
[1,8,256] token_emb + [1,8,256] pos_enc = [1,8,256] → encoder input x
After this step, every token's vector carries both semantic meaning (from its word embedding) and positional identity (from PE). The encoder can now distinguish "I love" from "love I".
Step 6

Transformer encoder — self-attention

The encoder's job: for every token (especially [MASK] tokens), look at all other tokens and gather evidence to predict what the masked token should be.

Q, K, V — conceptual meaning

Query (Q)

"What am I looking for?" Each token projects itself into a query vector representing what context it needs.
Q = x · Wq shape: [1,8,64] (d_head = D/n_heads = 256/4 = 64)

Key (K)

"What do I offer?" Each token projects into a key vector — it advertises its content.
K = x · Wk shape: [1,8,64]

Value (V)

"What information do I give?" If a key matches a query, the corresponding value is sent over.
V = x · Wv shape: [1,8,64]

Attention score formula

Attention(Q, K, V) = softmax( QKᵀ / √d_k ) · V
# scale by √d_k = √64 = 8: keeps the dot-products from saturating the softmax,
# which would otherwise produce vanishing gradients

Focus: [MASK] at position 1 (was "I") attending to all tokens

The [MASK] token at position 1 has a query vector that "asks": what word fits here, given the surrounding context? It computes dot-products with every token's key vector to find how relevant each neighbor is.

Attention weight matrix — row = query token, col = key token

Read row 1 ([MASK]→I): it attends most to love (0.38) and [BOS] (0.25) — the boundary context is most useful for recovering the first content word.

What happens inside one encoder layer

1. Q, K, V = linear projections of x          # shape [1, 8, 64] each (per head)
2. scores  = Q @ Kᵀ / √64                     # shape [1, 8, 8]
3. scores  = fill -∞ at PAD positions (mask)
4. weights = softmax(scores, dim=-1)          # shape [1, 8, 8]
5. context = weights @ V                      # shape [1, 8, 64]
6. 4 heads → concat                           # → [1, 8, 256]
7. x = LayerNorm(x + MultiHead(x))            # residual connection
8. x = LayerNorm(x + FFN(x))                  # feed-forward block

After N encoder layers (e.g. N=6)

[1,8,256] → 6× Encoder Layer → [1,8,256] contextualised hidden states h
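
A minimal sketch of one attention head with PAD masking, matching the shapes above (the weight matrices are passed in; names are illustrative):

import math
import torch
import torch.nn.functional as F

def attention_head(x, Wq, Wk, Wv, pad_mask):
    # x: [B, L, D], Wq/Wk/Wv: [D, d_head], pad_mask: [B, L] (True = real token)
    Q, K, V = x @ Wq, x @ Wk, x @ Wv                                    # each [B, L, d_head]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))            # [B, L, L]
    scores = scores.masked_fill(~pad_mask[:, None, :], float("-inf"))   # block attention to PAD keys
    weights = F.softmax(scores, dim=-1)                                 # [B, L, L] attention weights
    return weights @ V                                                  # [B, L, d_head] context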
Step 7

Output layer — hidden states → token probabilities

A linear "language model head" projects each 256-d hidden state to a vector of length vocab_size (32000). Softmax converts this to a probability distribution over all tokens.
[1,8,256] → Linear(256→32000) → [1,8,32000] → softmax → [1,8,32000] logits → probs

Prediction at position 1 (was "I", now [MASK])

logits[0, 1, :]   # shape [32000] — one score per vocab token

# Top logits (before softmax):
#   token "I"    → logit = +3.8   ← highest (model is correct!)
#   token "We"   → logit = +2.1
#   token "You"  → logit = +1.4
#   token "They" → logit = +0.9

Top-5 probability distributions

[MASK] at pos 1 → "I"      [MASK] at pos 3 → "deep"      [MASK] at pos 5 → "##ing"

[Charts: each panel shows the top-5 predicted tokens and their probabilities at that masked position.]

The model only needs to be confident at masked positions. At unmasked positions, the logits are still computed (the head runs on all positions) but not included in the loss.
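
A minimal sketch of the head, where h is the contextualised hidden state tensor from Step 6:

import torch.nn as nn

lm_head = nn.Linear(256, 32000)      # projects each 256-d hidden state to vocabulary logits
logits = lm_head(h)                  # [1, 8, 256] → [1, 8, 32000]
probs = logits.softmax(dim=-1)       # per-position probability distribution over 32000 tokens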
Step 8

Loss calculation — cross-entropy + diffusion scaling

We compute cross-entropy loss only at masked positions, then scale by 1/t — the key diffusion trick that weights steps by their difficulty.

Per-token cross-entropy

CE(pos) = -log( prob(correct_token | context) )

# Masked positions only:
pos=1: correct="I"     → prob = 0.62 → CE = -log(0.62) = 0.478
pos=3: correct="deep"  → prob = 0.31 → CE = -log(0.31) = 1.171
pos=5: correct="##ing" → prob = 0.51 → CE = -log(0.51) = 0.673

Apply mask — zero out non-masked positions

Sum over masked tokens → raw loss

raw_loss = (0.478 + 0.0 + 1.171 + 0.0 + 0.673 + 0.0 + 0.0 + 0.0) = 2.322 # sum of CE at masked positions only

Scale by 1/t — the diffusion weight

Why 1/t?
At small t (e.g. t=0.1), very few tokens are masked — the task is easy. But these low-noise steps carry high information gradient about the data distribution. We upweight them with 1/t = 10× so they contribute meaningfully. At large t (e.g. t=0.9), many tokens are masked — hard task, low info. Scale 1/0.9 ≈ 1.1 leaves them nearly unchanged. This is the continuous-time diffusion ELBO weighting, analogous to SNR weighting in image diffusion.
loss = raw_loss / (num_masked_tokens * t)
     = 2.322 / (3 * 0.5)
     = 2.322 / 1.5
     = 1.548          ← final scalar loss for this training step

Full formula

L(θ, t) = (1/t) * (1/|M|) * Σ_{i∈M} CE(x₀ᵢ, p_θ(xₜ)ᵢ)

# M    = set of masked positions
# |M|  = number of masked tokens (= 3 here)
# x₀ᵢ  = original token at position i
# p_θ  = model probability distribution
# t    = noise level (0.5 here)
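
A minimal sketch of this loss, assuming logits of shape [B, L, V], original IDs x0 of shape [B, L], and the mask_bool from Step 3 (names are illustrative):

import torch
import torch.nn.functional as F

def diffusion_loss(logits, x0, mask_bool, t):
    # Per-token cross-entropy against the original tokens: [B, L]
    ce = F.cross_entropy(logits.transpose(1, 2), x0, reduction="none")
    masked_ce = ce * mask_bool.float()                 # keep only masked positions
    num_masked = mask_bool.sum().clamp(min=1)          # |M| (guard against zero masks)
    return masked_ce.sum() / (num_masked * t)          # (1/t) * (1/|M|) * Σ CE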
Step 9

Backpropagation — what gets updated

After computing the scalar loss, PyTorch's autograd traces back through every operation that produced it, computing gradients via the chain rule. Every parameter used in the forward pass receives a gradient.

Gradient flow — backward through the network

[Diagram: gradients flow backward from the loss (1.548) through the LM head (Linear 256→32k), the 6 encoder layers (Attn + FFN + LN), and the embedding table [32k, 256]; the fixed sinusoidal positional encoding receives no gradient (∂L/∂PE = 0, not learned).]

Parameters updated by the optimizer (Adam):
LM head:          W_lm [256, 32000], b_lm [32000]   (learns which directions map to each token)
Encoder weights:  Wq, Wk, Wv [256, 64] × 4 heads; Wo [256, 256]; FFN W1 [256, 1024], W2 [1024, 256]; LayerNorm γ, β [256] each   (× 6 layers)
Embedding table:  E [32000, 256]   (only rows for tokens that appeared in this batch get a gradient: a sparse update)

Adam optimizer step

optimizer.zero_grad()   # clear previous gradients
loss.backward()         # compute ∂L/∂θ for every parameter θ
optimizer.step()        # θ ← θ - lr * Adam(∂L/∂θ); lr typically 1e-4 to 3e-4

What "learning" means here

Attention weights update

Wq, Wk, Wv are nudged so that [MASK] tokens in subject position attend more to neighboring verbs and determiners — learning grammatical structure.

Embedding update

The [MASK] embedding vector is pushed to produce queries that better extract denoising context. Token embeddings for "I", "deep", "##ing" are pulled toward their correct contextual neighborhoods.
Over millions of training steps, the model learns to be an expert denoiser: given any partially-masked sentence at any noise level t, it recovers the original tokens. At inference, you start from a fully-masked sentence and iteratively denoise using learned p_θ.
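
A minimal sketch of that inference loop, assuming a model that maps token IDs [1, L] to logits [1, L, V]; the confidence-based unmasking order used here is one common choice, not the only one:

import torch

@torch.no_grad()
def iterative_denoise(model, seq_len, mask_id, bos_id, eos_id):
    # Start fully masked (except the special tokens), then unmask one position per step,
    # always committing the prediction the model is currently most confident about.
    x = torch.full((1, seq_len), mask_id, dtype=torch.long)
    x[0, 0], x[0, -1] = bos_id, eos_id
    while (x == mask_id).any():
        probs = model(x).softmax(dim=-1)              # [1, L, V]
        conf, pred = probs.max(dim=-1)                # best token and its probability per position
        conf = conf.masked_fill(x != mask_id, -1.0)   # only consider still-masked positions
        pos = conf.argmax(dim=-1)                     # most confident masked position
        x[0, pos] = pred[0, pos]                      # commit that prediction
    return x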