The power of good notation, and training distributed LLMs
I've been working with sharding LLMs at work recently. It was a good time to review some "old" classics like Mesh-TensorFlow (2018), Megatron-LM (2019), ZeRO-3 (2019), and Efficiently Scaling Transformer Inference (2022).
The first author of the inference paper has since co-founded a chip company, MatX. They recently put out a repository called seqax for LLM training, which I quite liked.
Good, simple, and general syntax can be mind-opening. Consider the terminology around parallelism strategies - I think we can all agree data/model/tensor parallelism naming is a bit of a mess. (It's a bit more complex, but I prefer the weight vs. activation partitioning terms originally used in Mesh-TensorFlow.)
I thought it'd be a good exercise in this post to:
- write out the sharded version of the Transformer equations for a model using both FSDP and Megatron-style activation partitioning.
- see how this maps to code, by annotating seqax's forward pass implementation.
- write the backward pass equivalents in equations and code.
On the notation below:
- The seqax repo uses a simple "named axes" syntax, where f32['in hidden1/d'] is a 2D weight whose hidden1 axis is split across the d device-mesh axis. The repo explains it here.
- In my equations, I'll also use the {U_d} notation to indicate an unreduced sum over d, which occurs from an einsum over a sharded dimension, and requires a reduce_scatter() afterwards.
These two together make the universe of sharding strategies in distributed LLM training quite simple to write.
(Note: I won't review when/why to use different shardings; the original papers and online blogposts should help with that. Also, I did write this by hand, so apologies if there are any mistakes!)
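To make the {U_d} notation concrete, here is a minimal NumPy sketch (not seqax code). It simulates a mesh axis d with a plain Python list of per-device arrays; the shapes, `n_dev`, and the `reduce_scatter` helper are all assumptions made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_dev = 4                       # size of the hypothetical mesh axis d
x = rng.normal(size=(8, 16))    # x[B, H]
w = rng.normal(size=(16, 12))   # w[H, N]

# Shard the contracting axis H over d: x[B, H_d] and w[H_d, N].
x_shards = np.split(x, n_dev, axis=1)
w_shards = np.split(w, n_dev, axis=0)

# Local matmul on each device. Every device ends up with a full-shaped
# but partial result: this is y[B, N]{U_d}.
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

def reduce_scatter(partials, axis):
    """Sum the per-device partials, then hand device i the i-th slice along `axis`."""
    total = np.sum(partials, axis=0)
    return np.split(total, len(partials), axis=axis)

# y[B, N_d] = ReduceScatter_d(y[B, N]{U_d})
y_shards = reduce_scatter(partials, axis=1)
y_full = np.concatenate(y_shards, axis=1)   # only for checking against x @ w
```

No single device's partial equals `x @ w`; only the sum over d does, which is why the reduction has to happen before the result is used.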
Part 1: a distributed LLM in equation form (forward pass)
Pre-Transformer
- Input assumptions: Start with input $x[B_d,T]$, where $B$ is the batch size and $T$ is the sequence length.
- One-hot mapping: Map tokens to one-hot encodings:
- $x[B_d,T,V_t] = \text{Lookup}(x[B_d,T])$, where $V$ is vocabulary size.
- Embedding projection:
- First, gather the embedding weights across the data dimension
- $W_{embed}[V_t,M] = \text{AllGather}_d(W_{embed}[V_t,M_d])$
- Matrix multiply, contracting over the $t$-sharded vocab axis (so the result is unreduced over $t$)
- $x[B_d,T,M]\{U_t\} = x[B_d,T,V_t] \cdot W_{embed}[V_t,M]$
- Reduce-scatter to land the model dimension sharded over $t$
- $x[B_d,T,M_t] = \text{ReduceScatter}_t(x[B_d,T,M]\{U_t\})$
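The embedding steps above can be sketched in NumPy, again simulating the t axis with a list of shards. This is only an illustration under my own assumptions (the `local_lookup` helper and the sizes are hypothetical, and the weights are taken as already all-gathered over d): each vocab shard looks up only the token ids it owns and contributes zeros for the rest, so the per-shard results are unreduced over t.

```python
import numpy as np

rng = np.random.default_rng(0)
V, M, n_t = 12, 8, 4
ids = rng.integers(0, V, size=(2, 5))    # x[B, T]
w_embed = rng.normal(size=(V, M))        # W_embed[V, M], already gathered over d

def local_lookup(w_shard, ids, shard_idx, shard_size):
    """Lookup against one vocab shard; ids owned by other shards give zeros."""
    local_ids = ids - shard_idx * shard_size
    in_shard = (local_ids >= 0) & (local_ids < shard_size)
    rows = w_shard[np.where(in_shard, local_ids, 0)]    # [B, T, M]
    return np.where(in_shard[..., None], rows, 0.0)

shard_size = V // n_t
w_shards = np.split(w_embed, n_t, axis=0)               # W_embed[V_t, M]

# x[B, T, M]{U_t}: one partial result per t-rank.
partials = [local_lookup(ws, ids, i, shard_size) for i, ws in enumerate(w_shards)]

x_full = np.sum(partials, axis=0)           # the "reduce" half of the reduce-scatter
x_shards = np.split(x_full, n_t, axis=-1)   # the "scatter" half: x[B, T, M_t]
```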
Transformer block: FFN
Note that $M$ is typically sharded across the data dimension (FSDP), while the inner feed-forward dimension $F$ is sharded across the tensor dimension (Megatron).
- pre-FFN LayerNorm: $LN_2[M_{t,d}]$ (sharded across both $t$ and $d$)
- Gate weight: $W_{gate}[M_d,F_t]$
- Up weight: $W_{up}[M_d,F_t]$
- Down weight: $W_{down}[M_d,F_t]$
Assuming the residual arrives as $x[B_d,T,M_t]$, the forward pass steps look like this:
- Apply pre-FFN LayerNorm (in practice an RMSNorm, followed by element-wise multiplication with the learned scale $LN_2$):
- First all-gather the residual to a replicated copy: $x[B_d,T,M] = \text{AllGather}_t(x[B_d,T,M_t])$
- $LN_2[M] = \text{AllGather}_{t,d}(LN_2[M_{t,d}])$
- $nx[B_d,T,M] = \text{RMSNorm}(x)[B_d,T,M] \odot LN_2[M]$
- Gate projection (matmul, requiring FSDP's weight gather first):
- $W_{gate}[M,F_t] = \text{AllGather}_d(W_{gate}[M_d,F_t])$
- $\text{gate\_proj}[B_d,T,F_t] = nx[B_d,T,M] \cdot W_{gate}[M,F_t]$
- Up projection:
- $W_{up}[M,F_t] = \text{AllGather}_d(W_{up}[M_d,F_t])$
- $\text{up\_proj}[B_d,T,F_t] = nx[B_d,T,M] \cdot W_{up}[M,F_t]$
- Swish activation applied on the projections:
- $y[B_d,T,F_t] = \text{Swish}(\text{gate\_proj}[B_d,T,F_t]) \odot \text{up\_proj}[B_d,T,F_t]$
- Down projection: The contracting dimension ($F$) is sharded here, so the matmul is unreduced over $t$; we reduce-scatter it back to a $t$-sharded model dimension.
- $W_{down}[M,F_t] = \text{AllGather}_d(W_{down}[M_d,F_t])$
- $\text{ffn\_out}[B_d,T,M]\{U_t\} = y[B_d,T,F_t] \cdot W_{down}[M,F_t]^T$
- $\text{ffn\_out}[B_d,T,M_t] = \text{ReduceScatter}_t(\text{ffn\_out}[B_d,T,M]\{U_t\})$
- Residual connection: both operands are already sharded over $t$, so no extra communication is needed.
- $x[B_d,T,M_t] \mathrel{+}= \text{ffn\_out}[B_d,T,M_t]$
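As a sanity check on the FFN steps above, here is a small NumPy simulation. It is a sketch under assumptions, not seqax's implementation: the 2 x 4 mesh, the `shard` and `all_gather_d` helpers, and the sizes are invented, and the batch sharding over d is left out to keep it short. It checks that gathering weights over d, doing the local matmuls, and summing over t reproduces the unsharded SwiGLU.

```python
import numpy as np

def swish(z):
    return z / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
B, T, M, F = 2, 3, 8, 16
n_d, n_t = 2, 4                        # hypothetical mesh: d=2, t=4
nx = rng.normal(size=(B, T, M))        # normalized residual nx[B, T, M]
w_gate, w_up, w_down = (rng.normal(size=(M, F)) for _ in range(3))

def shard(w):
    """Split W[M, F] into blocks W[M_d, F_t], indexed as blocks[d][t]."""
    return [np.split(rows, n_t, axis=1) for rows in np.split(w, n_d, axis=0)]

def all_gather_d(blocks, t):
    """AllGather_d: rebuild W[M, F_t] for tensor-rank t from its d-shards."""
    return np.concatenate([blocks[d][t] for d in range(n_d)], axis=0)

g_blocks, u_blocks, d_blocks = shard(w_gate), shard(w_up), shard(w_down)

partials = []                          # ffn_out[B, T, M]{U_t}, one per t-rank
for t in range(n_t):
    wg, wu, wd = (all_gather_d(b, t) for b in (g_blocks, u_blocks, d_blocks))
    y = swish(nx @ wg) * (nx @ wu)     # y[B, T, F_t], no communication needed
    partials.append(y @ wd.T)          # contracts the sharded F axis

ffn_full = np.sum(partials, axis=0)
ffn_shards = np.split(ffn_full, n_t, axis=-1)   # ReduceScatter_t -> [B, T, M_t]

reference = (swish(nx @ w_gate) * (nx @ w_up)) @ w_down.T
```

The only cross-device sum over t is after the down projection; the gate and up projections and the elementwise SwiGLU stay local to each t-rank.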
Transformer block: attention
Let's define $Q$ as the number of query heads per KV head, $K$ as the number of KV heads, and $D$ as the head dimension.
- pre-Attention LN: $LN_1[M_{t,d}]$
- Q weight: $W_Q[M_d, Q, K_t, D]$
- K weight: $W_K[M_d, K_t, D]$ (Sliced from $W_{KV}[2, M_d, K_t, D]$)
- V weight: $W_V[M_d, K_t, D]$ (Sliced from $W_{KV}[2, M_d, K_t, D]$)
- O (Output) weight: $W_O[M_d, Q, K_t, D]$
The residual arrives as $x[B_d,T,M_t]$. As in the FFN, the pre-attention RMSNorm all-gathers it over $t$ to a replicated $x[B_d,T,M]$, which is the projection input below:
- Calculate q:
- $W_Q[M, Q, K_t, D] = \text{AllGather}_d(W_Q[M_d, Q, K_t, D])$
- $q[B_d,T,Q, K_t, D] = x[B_d,T,M] \cdot W_Q[M, Q, K_t, D]$
- Calculate k: (Note: KVs may be cached and of a different length $S$, e.g., in sliding window attention).
- $W_K[M,K_t, D] = \text{AllGather}_d(W_K[M_d, K_t,D])$
- $k[B_d,S,K_t, D] = x[B_d,T,M] \cdot W_K[M,K_t,D]$
- Apply RoPE to q and k:
- $q[B_d,T,Q, K_t, D] = \text{RoPE}(q[B_d,T,Q, K_t, D])$
- $k[B_d,S,K_t, D] = \text{RoPE}(k[B_d,S,K_t, D])$
- Calculate QK scores: All dimensions remain except the head dimension $D$.
- $logits[B_d,T,S, Q,K_t] = q[B_d,T,Q, K_t, D] \cdot k[B_d,S,K_t, D]$
- Apply Mask to $logits$.
- Calculate Attention Weights: $probs = \text{softmax}(logits)$
- Calculate v:
- $W_V[M,K_t, D] = \text{AllGather}_d(W_V[M_d, K_t,D])$
- $v[B_d,S,K_t, D] = x[B_d,T,M] \cdot W_V[M,K_t,D]$
- Calculate weighted values: All dimensions remain except the KV sequence length $S$.
- $attn\_out[B_d,T,Q,K_t,D] = probs[B_d,T,S, Q,K_t] \cdot v[B_d,S,K_t,D]$
- Output projection:
- Pre-emptively all-gather to combine with the residual's $M$ dimension:
$W_O[M, Q, K_t, D] = \text{AllGather}_d(W_O[M_d, Q, K_t, D])$
- Perform the projection (which requires a ReduceScatter since the contracting dimension $K_t$ is sharded):
$attn\_out[B_d,T,M]\{U_t\} = attn\_out[B_d,T,Q,K_t,D] \cdot W_O[M, Q, K_t, D]$
- $attn\_out[B_d,T,M_t] = \text{ReduceScatter}_t(attn\_out[B_d,T,M]\{U_t\})$
- Residual connection: (both operands sharded over $t$)
- $x[B_d,T,M_t] = x[B_d,T,M_t] + attn\_out[B_d,T,M_t]$
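The attention equations above shard only the KV-head axis $K$ over $t$, and heads never interact until the output projection. A minimal NumPy sketch of that claim follows, with assumptions stated up front: the sizes and the `attention` helper are hypothetical, RoPE and the d-axis gathers are omitted, and no $1/\sqrt{D}$ scaling is applied (to mirror the equations as written).

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, M, Q, K, D, n_t = 2, 4, 8, 2, 4, 3, 2
x = rng.normal(size=(B, T, M))
w_q = rng.normal(size=(M, Q, K, D))
w_k = rng.normal(size=(M, K, D))
w_v = rng.normal(size=(M, K, D))
w_o = rng.normal(size=(M, Q, K, D))
causal = np.tril(np.ones((T, T), dtype=bool))[None, :, :, None, None]

def attention(x, w_q, w_k, w_v, w_o):
    """GQA attention over whatever slice of KV heads the weights contain."""
    q = np.einsum('btm,mqkd->btqkd', x, w_q)
    k = np.einsum('bsm,mkd->bskd', x, w_k)
    v = np.einsum('bsm,mkd->bskd', x, w_v)
    logits = np.einsum('btqkd,bskd->btsqk', q, k)
    logits = np.where(causal, logits, -1e10)
    logits = logits - logits.max(axis=2, keepdims=True)
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=2, keepdims=True)    # softmax over S
    out = np.einsum('btsqk,bskd->btqkd', probs, v)
    return np.einsum('btqkd,mqkd->btm', out, w_o)       # contracts Q, K, D

# Each t-rank holds K / n_t KV heads of every weight.
wq_s, wo_s = np.split(w_q, n_t, axis=2), np.split(w_o, n_t, axis=2)
wk_s, wv_s = np.split(w_k, n_t, axis=1), np.split(w_v, n_t, axis=1)

# attn_out[B, T, M]{U_t}: one partial per t-rank.
partials = [attention(x, wq_s[i], wk_s[i], wv_s[i], wo_s[i]) for i in range(n_t)]

attn_full = np.sum(partials, axis=0)
attn_shards = np.split(attn_full, n_t, axis=-1)   # ReduceScatter_t -> [B, T, M_t]
reference = attention(x, w_q, w_k, w_v, w_o)
```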
Post-Transformer
Finally, we project the activations back out to the vocabulary size to get our logits (and, after a softmax, our probabilities).
- Final LayerNorm: All-gather the residual over $t$ first
- $x[B_d,T,M] = \text{AllGather}_t(x[B_d,T,M_t])$, then apply RMSNorm.
- Unembedding: Matmul to get logits (requires an AllGather of the unembedding weights over $d$).
- $W_{unembed}[V_t, M] = \text{AllGather}_d(W_{unembed}[V_t, M_d])$
- Contracting over the replicated $M$ leaves the vocab axis sharded over $t$: $logits[B_d,T,V_t] = x[B_d,T,M] \cdot W_{unembed}[V_t,M]^T$.
- Compute Loss
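I haven't written out the loss step in equations. One way it could work with the vocab axis still sharded over $t$ is sketched below in NumPy. This is my own assumption about a reasonable scheme, not something taken from seqax: each of the three lists being reduced stands in for an all-reduce over $t$, and `local_target_logit` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, V, n_t = 2, 3, 12, 4
logits = rng.normal(size=(B, T, V))
targets = rng.integers(0, V, size=(B, T))
shard_size = V // n_t
shards = np.split(logits, n_t, axis=-1)          # logits[B, T, V_t]

# 1) All-reduce (max) over t, for a numerically stable softmax.
global_max = np.max([s.max(axis=-1) for s in shards], axis=0)     # [B, T]

# 2) All-reduce (sum) over t of the local softmax denominators.
sum_exp = np.sum(
    [np.exp(s - global_max[..., None]).sum(axis=-1) for s in shards], axis=0)

# 3) All-reduce (sum) over t of the target logit; only the owning shard is nonzero.
def local_target_logit(shard, shard_idx):
    local = targets - shard_idx * shard_size
    owned = (local >= 0) & (local < shard_size)
    safe = np.where(owned, local, 0)[..., None]
    picked = np.take_along_axis(shard, safe, axis=-1)[..., 0]
    return np.where(owned, picked, 0.0)

target_logit = np.sum(
    [local_target_logit(s, i) for i, s in enumerate(shards)], axis=0)

# Cross-entropy = logsumexp(logits) - logits[target], averaged over tokens.
loss = np.mean(np.log(sum_exp) + global_max - target_logit)
```

Only [B, T]-shaped scalars per token cross devices here, so the full $[B,T,V]$ logits never need to be gathered.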
Part 2: how does this map to code? annotating seqax's forward pass
# Important: defines the sharding
@pytree_dataclass
class TransformerLayer:
# LayerNorms are sharded over 2 axes
ln1: f32['d_model/t/d'] # used pre-attention
ln2: f32['d_model/t/d'] # used pre-FFN
### Attention
# Note: w_q and w_kv defined separately
# - Helps simplify usage when Q and KVs can have different lengths t and s
# - E.g. when KV caching, t=1
# Input dimensions of QKV heads sharded over d, n_kv sharded over t
# GQA: each KV has multiple heads, so w_q includes [n_q_per_kv, n_kv/t]
w_q: f32['d_model/d n_q_per_kv n_kv/t d_head']
w_kv: f32['2 d_model/d n_kv/t d_head']
w_o: f32['d_model/d n_q_per_kv n_kv/t d_head']
### FFN
# FFN's input_dim sharded over d, output dim sharded over t
# -- this means it's using FSDP (d) and Megatron (t) parallelism
# FFN uses SwiGLU, i.e. (swish(x @ w_gate) * (x @ w_up)) @ w_down^T
w_gate: f32['d_model/d d_ff/t']
w_up: f32['d_model/d d_ff/t']
w_down: f32['d_model/d d_ff/t']
@pytree_dataclass
class Model:
# Each dimension of embedding is sharded (over separate axes)
embed: f32['vocab/t d_model/d']
unembed: f32['vocab/t d_model/d']
transformer: Transformer
# Again, LN is sharded over 2 axes
final_layer_norm: f32['d_model/d/t']
@typechecked
# The ids are ofc sharded over batch dimension
def forward_pass(self, h: Hparams, ids: u32[b'B/d L'], is_seq_start: bool_[b'B/d L']) -> f32[b'B/d L V/t']:
### Initial embedding lookup
embed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(self.embed))
x = shardops.index_unreduced('[V/t] M, B/d L -> B/d L M', embed, ids)
x = shardops.psum_scatter('B/d L M -> B/d L M/t', x)
L = ids.shape[1]
# Sequence packing
# `is_seq_start` flags the first token of each new document
# - hence segment_ids will be [1, 1, 2, 2, 2, 3, ...]
segment_ids = jnp.cumsum(is_seq_start, axis=1)
segment_mask: bool_[b'B/d L L'] = segment_ids[:, :, jnp.newaxis] == segment_ids[:, jnp.newaxis, :]
# add axes for q_per_k, num_kv_heads dimensions
segment_mask: bool_[b'B/d L L 1 1'] = segment_mask[..., jnp.newaxis, jnp.newaxis]
causal_mask: bool_[b'1 L L 1 1'] = jnp.tril(jnp.ones((L, L), dtype=jnp.bool_), 0)[jnp.newaxis, ..., jnp.newaxis, jnp.newaxis]
causal_mask: bool_[b'B/d L L 1 1'] = jnp.logical_and(segment_mask, causal_mask)
rope_table = RopeTable.create(L, h)
### Transformer blocks
@explicit_activation_checkpointing
@typechecked
def loop_body(x: bf16[b'B/d L M/t'], layer_weights: TransformerLayer) -> Tuple[bf16[b'B/d L M/t'], Tuple[()]]:
# x = [bsz (sharded over d), length, dim=M (sharded over t)]
### Pre-attention RMSNorm
# FSDP's all_gather() step
ln1 = shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln1))
gx = shardops.all_gather('B/d L M/t -> B/d L M', x)
# Multiply by ln1, as it's the \gamma in LayerNorm / RMSNorm
nx = jnp.bfloat16(rms_norm(gx) * ln1) # EC: [B/d L M]
# Notice input (nx) still sharded over batch dimension for data parallelism
### Attention, using Grouped Query Attention and RoPE position embeddings.
# GQA: eg. n_kv=4 heads, n_q=8 heads --> n_q_per_kv=2 (2 Q heads per KV head)
# - shorthand Q := n_q_per_kv (defined in TransformerLayer class), and K := n_kv
# FSDP's all_gather() step before local matmul
w_q = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_q))
# save_for_backward() is just activation checkpointing
q = save_for_backward(shardops.einsum_unreduced('B/d L M, M Q K/t D -> B/d L Q K/t D', nx, w_q))
q = rope_table.apply('L D -> 1 L 1 1 D', q)
# FSDP's all_gather() step before local matmul
w_kv = shardops.all_gather('2 M/d K/t D -> 2 M K/t D', jnp.bfloat16(layer_weights.w_kv))
k, v = shardops.einsum_unreduced('B/d L M, k_v M K/t D -> k_v B/d L K/t D', nx, w_kv)
k = save_for_backward(k)
v = save_for_backward(v)
k = rope_table.apply('L d -> 1 L 1 d', k)
# Because we're doing GQA, the logits include [Q K/t]
# - basically just the attention between the i-th Q and j-th KV
# Note: Use float32 for logits
logits = shardops.einsum_unreduced(
'B/d Qlen Q K/t D, B/d Klen K/t D -> B/d Qlen Klen Q K/t', q, k, preferred_element_type=jnp.float32)
logits = jnp.where(causal_mask, logits, -1e10)
# Note: doesn't seem to have the division by sqrt(d)?
probs = jnp.bfloat16(jax.nn.softmax(logits, axis=2))
# QK @ V for final attention embeddings
attn_out = shardops.einsum_unreduced(
'B/d Qlen Klen Q K/t, B/d Klen K/t D -> B/d Qlen Q K/t D', probs, v)
# FSDP's all_gather() step before local matmul
w_o = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_o))
attn_out = shardops.einsum_unreduced('B/d Qlen Q K/t D, M Q K/t D -> B/d Qlen M', attn_out, w_o)
# Note: we are doing a matmul where a contracting dimension is sharded (K/t),
# - Thus, we'll need to do a reduce-scatter after; this is Megatron pattern
attn_out = shardops.psum_scatter('B/d Qlen M -> B/d Qlen M/t', attn_out)
x = save_for_backward(x + attn_out)
### FFN, using SwiGLU
# Pre-FFN RMSNorm
# Same RMSNorm as pre-attention
ln2 = save_for_backward(shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln2)))
gx = shardops.all_gather('B/d L M/t -> B/d L M', x)
nx = jnp.bfloat16(rms_norm(gx) * ln2)
# FSDP's all_gather() step for three weights, over input_dim (M/d -> M), before local matmul
# Note: these weight matrices are the larger ones, to save on activation checkpointing memory
# - I think it's not uncommon to only save_for_backward(gate_proj, up_proj)?
w_gate = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_gate))
gate_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_gate))
w_up = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_up))
up_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_up))
y = jax.nn.swish(gate_proj) * up_proj # EC: this * is element wise, so [B/d L F/t]
w_down = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_down))
# Note: we are doing a matmul where a contracting dimension is sharded (F/t),
# - Thus, we'll need to do a reduce-scatter after; this is Megatron pattern
ffn_out = shardops.einsum_unreduced('B/d L F/t, M F/t -> B/d L M', y, w_down)
ffn_out = shardops.psum_scatter('B/d L M -> B/d L M/t', ffn_out)
return jnp.bfloat16(x + ffn_out), ()
x, () = jax.lax.scan(loop_body, jnp.bfloat16(x), self.transformer)
### Final layernorm and output projection.
# all_gather to prepare for layer norm (actually RMSNorm)
x = shardops.all_gather('B/d L M/t -> B/d L M', x)
ln = shardops.all_gather('M/t/d -> M', jnp.float32(self.final_layer_norm))
x = jnp.bfloat16(rms_norm(x) * ln)
# all_gather to prepare for unembed
unembed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(self.unembed))
logits = shardops.einsum_unreduced('B/d L M, V/t M -> B/d L V/t', x, unembed, preferred_element_type=jnp.float32)
return logits
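The sequence-packing mask built near the top of forward_pass is easy to check by hand. Here is a small NumPy version of the same cumsum trick on a made-up packed row (the example row is my own; the two trailing head axes are dropped for readability):

```python
import numpy as np

# One packed row holding three documents of lengths 2, 3 and 1.
is_seq_start = np.array([[True, False, True, False, False, True]])
L = is_seq_start.shape[1]

# Each True bumps the running count, so tokens of one document share an id.
segment_ids = np.cumsum(is_seq_start, axis=1)          # [[1, 1, 2, 2, 2, 3]]

# Query i may attend to key j only if both sit in the same document...
segment_mask = segment_ids[:, :, None] == segment_ids[:, None, :]
# ...and j is not in the future.
causal_mask = np.tril(np.ones((L, L), dtype=bool))[None]
mask = np.logical_and(segment_mask, causal_mask)       # [B, Qlen, Klen]
```

For example, `mask[0, 2, 1]` is False: token 2 starts the second document, so it cannot see token 1 from the first one even though token 1 is in its past.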
Part 3: a distributed LLM in equation form (backward pass)
Transformer: FFN
- Initial state: Assume we are given $grad\_x[B_d,T,M]$, representing the upstream gradient with respect to the output of the FFN block.
- Compute gradients for down projection (wrt $W_{down}[M_d,F_t]$ and $y[B_d,T,F_t]$):
- Gradient wrt $y[B_d,T,F_t]$:
$W_{down}[M,F_t] = \text{AllGather}_d(W_{down}[M_d,F_t])$
$grad\_y[B_d,T,F_t] = grad\_x[B_d,T,M] \cdot W_{down}[M,F_t]$
(This $grad\_y$ now becomes the upstream gradient for the first part of the FFN.)
- Gradient wrt $W_{down}[M_d,F_t]$:
$grad\_w\_down[M,F_t]\{U_d\} = grad\_x[B_d,T,M]^T \cdot y[B_d,T,F_t]$
Since the batch dimension $B_d$ is the contracting dimension, we need a reduce-scatter:
$grad\_w\_down[M_d,F_t] = \text{ReduceScatter}_d(grad\_w\_down[M,F_t]\{U_d\})$
- Compute gradients wrt SwiGLU ($\text{Swish}(\text{gate\_proj}) \odot \text{up\_proj}$):
- Assume a function $\text{swish\_grad}(\text{gate\_proj}) = \nabla \text{swish}(\text{gate\_proj})$.
- $grad\_gate\_proj[B_d,T,F_t] = grad\_y[B_d,T,F_t] \odot \text{swish\_grad}(\text{gate\_proj}[B_d,T,F_t]) \odot \text{up\_proj}[B_d,T,F_t]$
- $grad\_up\_proj[B_d,T,F_t] = grad\_y[B_d,T,F_t] \odot \text{Swish}(\text{gate\_proj}[B_d,T,F_t])$
- Compute gradients for gate projection (wrt $W_{gate}[M_d,F_t]$ and $nx[B_d,T,M]$):
- Gradient wrt $W_{gate}[M_d,F_t]$:
$grad\_w\_gate[M,F_t]\{U_d\} = nx[B_d,T,M]^T \cdot grad\_gate\_proj[B_d,T,F_t]$
$grad\_w\_gate[M_d,F_t] = \text{ReduceScatter}_d(grad\_w\_gate[M,F_t]\{U_d\})$
- Gradient wrt $nx[B_d,T,M]$ (from gate branch):
$W_{gate}[M,F_t] = \text{AllGather}_d(W_{gate}[M_d,F_t])$
$grad\_nx_{gate}[B_d,T,M]\{U_t\} = grad\_gate\_proj[B_d,T,F_t] \cdot W_{gate}[M,F_t]^T$
$grad\_nx_{gate}[B_d,T,M] = \text{AllReduce}_t(grad\_nx_{gate}[B_d,T,M]\{U_t\})$
- Compute gradients for up projection (wrt $W_{up}[M_d,F_t]$ and $nx[B_d,T,M]$):
- Gradient wrt $W_{up}[M_d,F_t]$:
$grad\_w\_up[M,F_t]\{U_d\} = nx[B_d,T,M]^T \cdot grad\_up\_proj[B_d,T,F_t]$
$grad\_w\_up[M_d,F_t] = \text{ReduceScatter}_d(grad\_w\_up[M,F_t]\{U_d\})$
- Gradient wrt $nx[B_d,T,M]$ (from up branch):
$W_{up}[M,F_t] = \text{AllGather}_d(W_{up}[M_d,F_t])$
$grad\_nx_{up}[B_d,T,M]\{U_t\} = grad\_up\_proj[B_d,T,F_t] \cdot W_{up}[M,F_t]^T$
$grad\_nx_{up}[B_d,T,M] = \text{AllReduce}_t(grad\_nx_{up}[B_d,T,M]\{U_t\})$
- Combine gradients:
- Because the forward pass split the input $nx$ into two parallel paths (gate and up), the backward pass must sum their respective gradients:
- $grad\_nx[B_d,T,M] = grad\_nx_{gate}[B_d,T,M] + grad\_nx_{up}[B_d,T,M]$
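The weight gradients above are unreduced over $d$ because the batch axis, which is sharded over $d$, is the one being contracted. A minimal NumPy sketch for the down-projection weight gradient, with made-up sizes and the $t$ sharding of $F$ left out for brevity:

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, M, F, n_d = 4, 3, 6, 5, 2
y = rng.normal(size=(B, T, F))          # activation saved from the forward pass
grad_x = rng.normal(size=(B, T, M))     # upstream gradient wrt the FFN output

# Shard the batch axis over d: y[B_d, T, F] and grad_x[B_d, T, M].
y_shards = np.split(y, n_d, axis=0)
g_shards = np.split(grad_x, n_d, axis=0)

# Local einsum contracts over each device's own batch slice, so every
# device holds a full-shaped partial: grad_w_down[M, F]{U_d}.
partials = [np.einsum('btm,btf->mf', g, ys) for g, ys in zip(g_shards, y_shards)]

# ReduceScatter_d: sum the partials, then leave device i with its M_d rows,
# which is exactly the FSDP shard of W_down that device i owns.
grad_w_full = np.sum(partials, axis=0)
grad_w_shards = np.split(grad_w_full, n_d, axis=0)    # grad_w_down[M_d, F]

reference = np.einsum('btm,btf->mf', grad_x, y)
```

This is also why the output sharding of the reduce-scatter matches the weight's own layout: each device receives only the gradient rows for the parameters it stores.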
Part 4: implementing and annotating a backward pass
The following pattern is repeated over and over: all_gather --> local einsum --> psum_scatter (a reduce-scatter)
from typing import List
import jax
import jax.numpy as jnp
import shardops
# These should match the dimensions of the weights ofc
@pytree_dataclass
class GradTransformerLayer:
grad_ln1: f32['d_model/t/d']
grad_ln2: f32['d_model/t/d']
grad_w_q: f32['d_model/d n_q_per_kv n_kv/t d_head']
grad_w_kv: f32['2 d_model/d n_kv/t d_head']
grad_w_o: f32['d_model/d n_q_per_kv n_kv/t d_head']
grad_w_gate: f32['d_model/d d_ff/t']
grad_w_up: f32['d_model/d d_ff/t']
grad_w_down: f32['d_model/d d_ff/t']
@pytree_dataclass
class GradModel:
grad_embed: f32['vocab/t d_model/d']
grad_unembed: f32['vocab/t d_model/d']
grad_transformer: List[GradTransformerLayer]
grad_final_layer_norm: f32['d_model/d/t']
def backward_pass(model: Model, h: Hparams, ids: u32[b'B/d L'], is_seq_start: bool_[b'B/d L'],
output_grad: f32[b'B/d L V/t']) -> GradModel:
# output_grad is the "upstream gradient". We calculate "local gradients" (at the current
# operation) and then compute downstream gradients as output_grad * local gradient
# - it has the same shape as the logits, [B/d L V/t]
# Initialize gradient structures to store computed gradients
grad_model = GradModel(
grad_embed=jnp.zeros_like(model.embed),
grad_unembed=jnp.zeros_like(model.unembed),
grad_transformer=[GradTransformerLayer(...) for _ in range(len(model.transformer))],
grad_final_layer_norm=jnp.zeros_like(model.final_layer_norm)
)
# un-embedding
# EC: gradients computed here (not in a separate helper function since it's just x * unembed)
# EC: TODO -- why do we have to all_gather unembed, only to turn M to M/d in the reduce_scatter???
# Gather the unembed matrix across devices for the backward pass
unembed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(model.unembed))
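# compute gradient w.r.t. x (input to unembed): dL/dx = dL/dy * dy/dx = output_grad * unembed^T
# - note: this contracts over the sharded V/t axis, so grad_x is unreduced over t until it is summed across t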
grad_x = shardops.einsum_unreduced('B/d L V/t, V/t M -> B/d L M', output_grad, unembed)
# compute gradient w.r.t. unembed weights: dL/dunembed = x^T * dL/dy = x^T * output_grad
grad_unembed = shardops.einsum_unreduced('B/d L M, B/d L V/t -> V/t M', x, output_grad)
# scatter the unembed gradients back to their respective devices
grad_model.grad_unembed = shardops.psum_scatter('V/t M -> V/t M/d', grad_unembed)
# RMSNorm, assuming a rms_norm_backward fn exists, producing the local gradient
# Pattern: all_gather --> local einsum --> psum_scatter
ln = shardops.all_gather('M/t/d -> M', jnp.float32(model.final_layer_norm))
grad_ln, grad_x = rms_norm_backward(grad_x, x, ln)
# scatter the final layer norm gradients back to their respective devices
grad_model.grad_final_layer_norm = shardops.psum_scatter('M -> M/t/d', grad_ln)
def loop_body(grad_x, layer_idx):
# EC: IMPORTANT - from the "grad_y = ..." line, looks like grad_x is [B/d L M]
# - makes sense, grad_x := dL/dx, where x is activation (output of this layer)
# - TODO: however, matx_annotated.py seems to indicate each layer's output is actually [B/d L M/t] (with
# - M sharded over t), and only the final layer's output is all_gathered (outside the for loop)
# - Not sure how this is reconciled with having this be [M] (unsharded) within this loop_body
layer = model.transformer[layer_idx]
grad_layer = grad_model.grad_transformer[layer_idx]
### FFN's W_down
# Gather the down-projection weights of the FFN
w_down = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer.w_down))
# Compute gradient w.r.t. y (input to down-projection): dL/dy = dL/dx * dx/dy = grad_x * w_down^T
grad_y = shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', grad_x, w_down)
# Compute gradient w.r.t. w_down: dL/dw_down = y^T * dL/dx
# Note: assumes activations (e.g. y) are available/checkpointed from forward pass
grad_w_down = shardops.einsum_unreduced('B/d L F/t, B/d L M -> M F/t', y, grad_x)
# scatter w_down gradients back to their respective devices
# - remember, contracting dimension is sharded (B/d) so need to reduce_scatter
grad_layer.grad_w_down = shardops.psum_scatter('M F/t -> M/d F/t', grad_w_down)
### FFN (SwiGLU)'s W_gate and W_up
# Gather the up-projection and gate weights of the FFN
w_up = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer.w_up))
w_gate = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer.w_gate))
# Compute gradients for the SwiGLU activation
grad_gate_proj, grad_up_proj = swish_backward(grad_y, gate_proj, up_proj)
# Compute gradients w.r.t. w_up and w_gate
grad_w_up = shardops.einsum_unreduced('B/d L M, B/d L F/t -> M F/t', nx, grad_up_proj)
grad_w_gate = shardops.einsum_unreduced('B/d L M, B/d L F/t -> M F/t', nx, grad_gate_proj)
# Scatter w_up and w_gate gradients back to their respective devices
grad_layer.grad_w_up = shardops.psum_scatter('M F/t -> M/d F/t', grad_w_up)
grad_layer.grad_w_gate = shardops.psum_scatter('M F/t -> M/d F/t', grad_w_gate)
###
# Compute gradient w.r.t. nx (input to FFN): dL/dnx = dL/dy_up * dy_up/dnx + dL/dy_gate * dy_gate/dnx
# - Since nx feeds two branches (gate and up) in the forward pass, its gradient is the sum over both paths (think backprop circuit analysis)
# - Note: both einsums contract over the sharded F/t axis, so the sum is unreduced over t until it is all-reduced (the AllReduce_t step in Part 3)
grad_nx = (shardops.einsum_unreduced('B/d L F/t, M F/t -> B/d L M', grad_up_proj, w_up) +
shardops.einsum_unreduced('B/d L F/t, M F/t -> B/d L M', grad_gate_proj, w_gate))
### pre-FFN LN
# Gather the pre-FFN layer norm weights
ln2 = shardops.all_gather('M/t/d -> M', jnp.float32(layer.ln2))
# Compute gradients for the pre-FFN layer norm
grad_ln2, grad_x = rms_norm_backward(grad_nx, x, ln2)
# Scatter pre-FFN layer norm gradients back to their respective devices
grad_layer.grad_ln2 = shardops.psum_scatter('M -> M/t/d', grad_ln2)
### attn_out @ w_out
# Gather the output projection weights of the attention mechanism
w_o = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer.w_o))
# Compute gradient w.r.t. attention output: dL/dattn_out = dL/dx * dx/dattn_out = grad_x * w_o^T
grad_attn_out = shardops.einsum_unreduced('B/d Qlen M, M Q K/t D -> B/d Qlen Q K/t D', grad_x, w_o)
# Compute gradient w.r.t. w_o: dL/dw_o = attn_out^T * dL/dx
grad_w_o = shardops.einsum_unreduced('B/d Qlen Q K/t D, B/d Qlen M -> M Q K/t D', attn_out, grad_x)
# Scatter w_o gradients back to their respective devices
grad_layer.grad_w_o = shardops.psum_scatter('M Q K/t D -> M/d Q K/t D', grad_w_o)
### attn_probs and Q K activations
# Note: no actual *weight* gradients (just intermediate grads used downstream), so no psum_scatter
# Compute gradient w.r.t. attention probabilities: dL/dprobs = dL/dattn_out * v^T
grad_probs = shardops.einsum_unreduced('B/d Qlen Q K/t D, B/d Klen K/t D -> B/d Qlen Klen Q K/t', grad_attn_out, v)
# Compute gradient w.r.t. values: dL/dv = probs^T * dL/dattn_out
grad_v = shardops.einsum_unreduced('B/d Qlen Klen Q K/t, B/d Qlen Q K/t D -> B/d Klen K/t D', probs, grad_attn_out)
# Compute gradient w.r.t. logits (pre-softmax): dL/dlogits = dL/dprobs * dprobs/dlogits
grad_logits = softmax_backward(grad_probs, probs)
# Apply causal mask to logits gradients
grad_logits = jnp.where(causal_mask, grad_logits, 0)
# Compute gradient w.r.t. queries: dL/dq = dL/dlogits * k^T
grad_q = shardops.einsum_unreduced('B/d Qlen Klen Q K/t, B/d Klen K/t D -> B/d Qlen Q K/t D', grad_logits, k)
# Compute gradient w.r.t. keys: dL/dk = q^T * dL/dlogits
grad_k = shardops.einsum_unreduced('B/d Qlen Klen Q K/t, B/d Qlen Q K/t D -> B/d Klen K/t D', grad_logits, q)
### RoPE
grad_q = rope_table.apply_inverse('L D -> 1 L 1 1 D', grad_q)
grad_k = rope_table.apply_inverse('L d -> 1 L 1 d', grad_k)
### w_q and w_kv
# Gather query and key-value projection weights
w_q = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer.w_q))
w_kv = shardops.all_gather('2 M/d K/t D -> 2 M K/t D', jnp.bfloat16(layer.w_kv))
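# Remember the "+=": the q, k and v paths all read the same nx, so their gradients add up
# Compute gradient w.r.t. nx (input to attention): dL/dnx = dL/dq * w_q^T + dL/dk * w_k^T + dL/dv * w_v^T
# - Start a fresh accumulator here: this nx is the attention input, not the FFN input from earlier
grad_nx = shardops.einsum_unreduced('B/d L Q K/t D, M Q K/t D -> B/d L M', grad_q, w_q)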
grad_nx += shardops.einsum_unreduced('k_v B/d L K/t D, k_v M K/t D -> B/d L M', jnp.stack([grad_k, grad_v]), w_kv)
# Compute gradients w.r.t. w_q and w_kv
grad_w_q = shardops.einsum_unreduced('B/d L M, B/d L Q K/t D -> M Q K/t D', nx, grad_q)
grad_w_kv = shardops.einsum_unreduced('B/d L M, k_v B/d L K/t D -> k_v M K/t D', nx, jnp.stack([grad_k, grad_v]))
# Scatter w_q and w_kv gradients back to their respective devices
grad_layer.grad_w_q = shardops.psum_scatter('M Q K/t D -> M/d Q K/t D', grad_w_q)
grad_layer.grad_w_kv = shardops.psum_scatter('k_v M K/t D -> k_v M/d K/t D', grad_w_kv)
### Pre-attention layer norm
# Gather the pre-attention layer norm weights
ln1 = shardops.all_gather('M/t/d -> M', jnp.float32(layer.ln1))
# Compute gradients for the pre-attention layer norm
grad_ln1, grad_x = rms_norm_backward(grad_nx, x, ln1)
# Scatter pre-attention layer norm gradients back to their respective devices
grad_layer.grad_ln1 = shardops.psum_scatter('M -> M/t/d', grad_ln1)
return grad_x, None
### Apply the backward pass through all transformer layers in reverse order
grad_x, _ = jax.lax.scan(loop_body, grad_x, jnp.arange(len(model.transformer) - 1, -1, -1))
### Initial vocab embedding layer
# Gather the embedding matrix
embed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(model.embed))
# Compute gradient w.r.t. embeddings: dL/dembed = sum(dL/dx * one_hot(ids))
grad_embed = shardops.index_add('V/t M, B/d L -> V/t M', jnp.zeros_like(embed), ids, grad_x)
# Scatter embedding gradients back to their respective devices
grad_model.grad_embed = shardops.psum_scatter('V/t M -> V/t M/d', grad_embed)
return grad_model
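The backward pass above leans on three helpers that I assumed exist: swish_backward, softmax_backward, and rms_norm_backward. Below is a hedged NumPy sketch of what they could look like, so the math is at least checkable. The signatures simply mirror how I call them above, the `eps` value and the exact form of rms_norm are assumptions rather than seqax's definitions, and `numerical_grad` is a hypothetical finite-difference checker added only for testing.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def swish(z):
    return z * sigmoid(z)

def swish_backward(grad_y, gate_proj, up_proj):
    """Backward of y = swish(gate_proj) * up_proj."""
    s = sigmoid(gate_proj)
    d_swish = s * (1.0 + gate_proj * (1.0 - s))     # d/dz [z * sigmoid(z)]
    return grad_y * d_swish * up_proj, grad_y * swish(gate_proj)

def softmax(z, axis):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def softmax_backward(grad_probs, probs, axis=2):
    """Backward of probs = softmax(logits, axis); axis=2 is Klen in the forward pass."""
    inner = np.sum(grad_probs * probs, axis=axis, keepdims=True)
    return probs * (grad_probs - inner)

def rms_norm(x, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def rms_norm_backward(grad_nx, x, ln, eps=1e-6):
    """Backward of nx = rms_norm(x) * ln; returns (grad_ln, grad_x)."""
    r = 1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    n = x * r                                        # rms_norm(x)
    # ln is shared by every token, so its gradient sums over all leading axes.
    grad_ln = np.sum(grad_nx * n, axis=tuple(range(x.ndim - 1)))
    g = grad_nx * ln                                 # gradient wrt n
    grad_x = r * (g - n * np.mean(g * n, axis=-1, keepdims=True))
    return grad_ln, grad_x

def numerical_grad(f, z, eps=1e-6):
    """Central finite differences of a scalar function f at z (for checking only)."""
    grad = np.zeros_like(z)
    for idx in np.ndindex(z.shape):
        step = np.zeros_like(z)
        step[idx] = eps
        grad[idx] = (f(z + step) - f(z - step)) / (2 * eps)
    return grad
```

In the sharded setting, the sum over batch in `grad_ln` is another contraction over B/d, which is why the code above follows it with a psum_scatter.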