The power of good notation, and training distributed LLMs

I've been working on sharding LLMs at work recently. It was a good time to review some "old" classics like Mesh-TensorFlow (2018), Megatron-LM (2019), ZeRO-3 (2019), and Efficiently Scaling Transformer Inference (2022).

The first author of the inference paper has since co-founded a chip company, MatX. They recently put out a repository called seqax for LLM training, which I quite liked.

Good, simple, and general syntax can be mind-opening. Consider the terminology around parallelism strategies: I think we can all agree that the data/model/tensor parallelism naming is a bit of a mess. (It's a bit more complex, but I prefer the weight vs. activation partitioning terms originally used in Mesh-TensorFlow.)

I thought it'd be a good exercise in this post to:

  1. write out the sharded version of the Transformer equations for a model using both FSDP and Megatron-style activation partitioning.
  2. see how this maps to code, by annotating seqax's forward pass implementation.
  3. write the backward pass equivalents in equations and code.

On the notation below:

  1. The seqax repo uses a simple "named axes" syntax, where f32['in hidden1/d'] is a 2D weight whose hidden1 axis is split across the d device-mesh axis. The repo explains it here.
  2. In my equations, I'll also use the $\{U_d\}$ notation to indicate an unreduced sum over $d$: it arises from an einsum that contracts over a dimension sharded on $d$, and requires a reduce_scatter() (or, where the result must stay replicated, an all-reduce) afterwards.

These two together make the universe of sharding strategies in distributed LLM training quite simple to write.
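To make $\{U_d\}$ concrete, here is a minimal NumPy sketch that simulates a mesh axis as a Python list of per-device arrays (an assumption for illustration only; seqax runs this on real devices through JAX). Contracting over a sharded axis leaves every device with a full-shaped but partial sum, and a reduce-scatter sums those partials while leaving the result sharded.

```python
import numpy as np

N_D = 4  # size of the simulated mesh axis d; each list entry below is one device's local array

def reduce_scatter(partials, axis):
    """Sum the per-device partial results, then hand device i the i-th slice along `axis`."""
    total = np.sum(partials, axis=0)
    return np.split(total, len(partials), axis=axis)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))   # x[B, M]
w = rng.normal(size=(16, 12))  # w[M, F]

# Shard the contracting axis M over d: each device holds x[B, M_d] and w[M_d, F]
x_shards = np.split(x, N_D, axis=1)
w_shards = np.split(w, N_D, axis=0)

# Local einsum on each device gives y[B, F]{U_d}: a partial sum over that device's slice of M
unreduced = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

# ReduceScatter_d: sum the partials over d and shard F over d, giving y[B, F_d]
y_shards = reduce_scatter(unreduced, axis=1)

y_ref = x @ w  # unsharded reference
```

Each `unreduced[i]` already has the output's full shape, which is exactly why the $\{U_d\}$ tag is useful: the shape alone doesn't tell you the values still need summing.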

(Note: I won't review when/why to use different shardings; the original papers and online blog posts should help with that. Also, I did write this by hand, so apologies if there are any mistakes!)

Part 1: a distributed LLM in equation form (forward pass)

Pre-Transformer

  1. Input assumptions: Start with input $x[B_d,T]$, where $B$ is the batch size and $T$ is the sequence length.
  2. One-hot mapping: Map tokens to one-hot encodings:
    • $x[B_d,T,V_t] = \text{Lookup}(x[B_d,T])$, where $V$ is the vocabulary size.
  3. Embedding projection:
    • First, gather the embedding weights across the data dimension
      • $W_{embed}[V_t,M] = \text{AllGather}_d(W_{embed}[V_t,M_d])$
    • Matrix multiply, contracting over the $t$-sharded vocab axis (so the result is unreduced over $t$)
      • $x[B_d,T,M]\{U_t\} = x[B_d,T,V_t] \cdot W_{embed}[V_t,M]$
    • Reduce-scatter to land the model dimension sharded over $t$
      • $x[B_d,T,M_t] = \text{ReduceScatter}_t(x[B_d,T,M]\{U_t\})$
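A minimal NumPy sketch of the embedding steps above, simulating the $t$ axis as a list of devices and leaving out the batch/$d$ sharding. Rather than materializing the one-hot tensor, each device looks up only the token ids that fall inside its vocab shard and writes zeros elsewhere; this gives the same unreduced sum as the one-hot matmul and appears to be what seqax's `index_unreduced` computes (the helper `local_lookup` is my own, hypothetical name).

```python
import numpy as np

N_T = 2  # size of the simulated tensor-parallel mesh axis t
V, M, B, T = 8, 6, 3, 5

rng = np.random.default_rng(0)
w_embed = rng.normal(size=(V, M))         # W_embed[V, M], i.e. after the AllGather over d
ids = rng.integers(0, V, size=(B, T))     # token ids x[B, T]

w_shards = np.split(w_embed, N_T, axis=0)  # W_embed[V_t, M]
v_per_shard = V // N_T

def local_lookup(w_shard, shard_idx, ids):
    """Partial embedding x[B, T, M]{U_t}: rows for ids outside this vocab shard are zero."""
    local_ids = ids - shard_idx * v_per_shard
    in_shard = (local_ids >= 0) & (local_ids < v_per_shard)
    rows = w_shard[np.clip(local_ids, 0, v_per_shard - 1)]  # [B, T, M]
    return np.where(in_shard[..., None], rows, 0.0)

unreduced = [local_lookup(w, i, ids) for i, w in enumerate(w_shards)]

# ReduceScatter_t: sum the partials over t, then shard M over t -> x[B, T, M_t]
x_shards = np.split(np.sum(unreduced, axis=0), N_T, axis=-1)

# Reference: the one-hot matmul from the equations, unsharded
one_hot = np.eye(V)[ids]   # [B, T, V]
x_ref = one_hot @ w_embed  # [B, T, M]
```

Every token id lives in exactly one vocab shard, so after the sum each position holds exactly one embedding row.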

Transformer block: FFN

Note that in the weights, the model dimension $M$ is sharded across the data axis $d$ (FSDP), while the inner feed-forward dimension $F$ is sharded across the tensor axis $t$ (Megatron).

  • pre-FFN LayerNorm: $LN_2[M_{t,d}]$ (sharded across both $t$ and $d$)
  • Gate weight: $W_{gate}[M_d,F_t]$
  • Up weight: $W_{up}[M_d,F_t]$
  • Down weight: $W_{down}[M_d,F_t]$

Assuming the residual arrives as $x[B_d,T,M_t]$, the forward pass steps look like this:

  1. Apply the pre-FFN norm (RMSNorm, followed by an element-wise scale by the learned gain $LN_2$):
    • First all-gather the residual to a replicated copy: $x[B_d,T,M] = \text{AllGather}_t(x[B_d,T,M_t])$
    • $LN_2[M] = \text{AllGather}_{t,d}(LN_2[M_{t,d}])$
    • $nx[B_d,T,M] = \text{RMSNorm}(x)[B_d,T,M] \odot LN_2[M]$
  2. Gate projection (matmul, requiring FSDP's weight gather first):
    • $W_{gate}[M,F_t] = \text{AllGather}_d(W_{gate}[M_d,F_t])$
    • $\text{gate\_proj}[B_d,T,F_t] = nx[B_d,T,M] \cdot W_{gate}[M,F_t]$
  3. Up projection:
    • $W_{up}[M,F_t] = \text{AllGather}_d(W_{up}[M_d,F_t])$
    • $\text{up\_proj}[B_d,T,F_t] = nx[B_d,T,M] \cdot W_{up}[M,F_t]$
  4. Swish activation applied on the projections:
    • $y[B_d,T,F_t] = \text{Swish}(\text{gate\_proj}[B_d,T,F_t]) \odot \text{up\_proj}[B_d,T,F_t]$
  5. Down projection: The contracting dimension ($F$) is sharded here, so the matmul is unreduced over $t$; we reduce-scatter it back to a $t$-sharded model dimension.
    • $W_{down}[M,F_t] = \text{AllGather}_d(W_{down}[M_d,F_t])$
    • $\text{ffn\_out}[B_d,T,M]\{U_t\} = y[B_d,T,F_t] \cdot W_{down}[M,F_t]^T$
    • $\text{ffn\_out}[B_d,T,M_t] = \text{ReduceScatter}_t(\text{ffn\_out}[B_d,T,M]\{U_t\})$
  6. Residual connection: both operands are already sharded over $t$, so no extra communication is needed.
    • $x[B_d,T,M_t] \mathrel{+}= \text{ffn\_out}[B_d,T,M_t]$
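To check that this sharding is exact rather than approximate, here is a NumPy sketch that runs the steps above with $F$ split over a simulated $t$ axis and compares the result against an unsharded SwiGLU FFN. It assumes the replicated $nx[B,T,M]$ is already available (RMSNorm and the FSDP gather over $d$ are left out) and stores $W_{down}$ as $[M,F]$, matching the weight shapes above.

```python
import numpy as np

N_T = 4  # simulated tensor-parallel axis t
B, T, M, F = 2, 3, 8, 16
rng = np.random.default_rng(1)

def swish(z):
    return z / (1.0 + np.exp(-z))  # z * sigmoid(z)

nx = rng.normal(size=(B, T, M))   # nx[B, T, M], replicated over t
w_gate = rng.normal(size=(M, F))  # W_gate[M, F]
w_up = rng.normal(size=(M, F))    # W_up[M, F]
w_down = rng.normal(size=(M, F))  # W_down[M, F], used transposed

# Unsharded reference: (Swish(nx @ W_gate) * (nx @ W_up)) @ W_down^T
ref = (swish(nx @ w_gate) * (nx @ w_up)) @ w_down.T

# Megatron-style: each t-device owns an F_t slice of all three weights
partials = []
for g, u, dn in zip(np.split(w_gate, N_T, axis=1),
                    np.split(w_up, N_T, axis=1),
                    np.split(w_down, N_T, axis=1)):
    y_local = swish(nx @ g) * (nx @ u)  # y[B, T, F_t]: element-wise, so no communication
    partials.append(y_local @ dn.T)     # ffn_out[B, T, M]{U_t}

# ReduceScatter_t over M -> ffn_out[B, T, M_t]
ffn_out_shards = np.split(np.sum(partials, axis=0), N_T, axis=-1)
```

The only communication is the final reduce-scatter: Swish and the element-wise product never mix different $F$ columns, so each device can finish its slice of $y$ locally.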

Transformer block: attention

Let's define $Q$ as the number of query heads per KV head, $K$ as the number of KV heads, and $D$ as the head dimension.

  • pre-Attention LN: $LN_1[M_{t,d}]$
  • Q weight: $W_Q[M_d, Q, K_t, D]$
  • K weight: $W_K[M_d, K_t, D]$ (Sliced from $W_{KV}[2, M_d, K_t, D]$)
  • V weight: $W_V[M_d, K_t, D]$ (Sliced from $W_{KV}[2, M_d, K_t, D]$)
  • O (Output) weight: $W_O[M_d, Q, K_t, D]$

The residual arrives as $x[B_d,T,M_t]$. As in the FFN, the pre-attention RMSNorm all-gathers it over $t$ to a replicated $x[B_d,T,M]$ (then normalizes it and scales it by $LN_1$); for brevity, $x$ below refers to this normalized, replicated input to the projections:

  1. Calculate q:
    • $W_Q[M, Q, K_t, D] = \text{AllGather}_d(W_Q[M_d, Q, K_t, D])$
    • $q[B_d,T,Q, K_t, D] = x[B_d,T,M] \cdot W_Q[M, Q, K_t, D]$
  2. Calculate k: (Note: KVs may be cached and of a different length $S$, e.g., in sliding window attention).
    • $W_K[M,K_t, D] = \text{AllGather}_d(W_K[M_d, K_t,D])$
    • $k[B_d,S,K_t, D] = x[B_d,T,M] \cdot W_K[M,K_t,D]$
  3. Apply RoPE to q and k:
    • $q[B_d,T,Q, K_t, D] = \text{RoPE}(q[B_d,T,Q, K_t, D])$
    • $k[B_d,S,K_t, D] = \text{RoPE}(k[B_d,S,K_t, D])$
  4. Calculate QK scores: All dimensions remain except the head dimension $D$.
    • $logits[B_d,T,S, Q,K_t] = q[B_d,T,Q, K_t, D] \cdot k[B_d,S,K_t, D]$
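  5. Apply the causal mask (and, with sequence packing, the segment mask) to $logits$.
  6. Calculate attention weights, normalizing over the KV sequence axis $S$: $probs = \text{softmax}_S(logits)$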
  7. Calculate v:
    • $W_V[M,K_t, D] = \text{AllGather}_d(W_V[M_d, K_t,D])$
    • $v[B_d,S,K_t, D] = x[B_d,T,M] \cdot W_V[M,K_t,D]$
  8. Calculate weighted values: All dimensions remain except the KV sequence length $S$.
    • $attn\_out[B_d,T,Q,K_t,D] = probs[B_d,T,S, Q,K_t] \cdot v[B_d,S,K_t,D]$
  9. Output projection:
    • All-gather over $d$ first (FSDP's weight gather):
      $W_O[M, Q, K_t, D] = \text{AllGather}_d(W_O[M_d, Q, K_t, D])$
    • Perform the projection (which requires a ReduceScatter, since one of the contracting dimensions, $K_t$, is sharded):
      $attn\_out[B_d,T,M]\{U_t\} = attn\_out[B_d,T,Q,K_t,D] \cdot W_O[M, Q, K_t, D]$
    • $attn\_out[B_d,T,M_t] = \text{ReduceScatter}_t(attn\_out[B_d,T,M]\{U_t\})$
  10. Residual connection: (both operands sharded over $t$)
    • $x[B_d,T,M_t] = x[B_d,T,M_t] + attn\_out[B_d,T,M_t]$
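The same check for attention: a NumPy sketch of the GQA forward pass above, run once unsharded and once with the KV heads ($K$) split over a simulated $t$ axis. Like the equations, it omits the $1/\sqrt{D}$ scaling, and it drops the batch's $d$ sharding for brevity; the function name `attention` is my own.

```python
import numpy as np

N_T = 2
B, T, M = 2, 5, 8
Q, K, D = 2, 4, 3  # query heads per KV head, KV heads, head dim
rng = np.random.default_rng(2)

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(x, w_q, w_k, w_v, w_o, mask):
    """GQA forward for whichever KV heads are present in the given weights."""
    q = np.einsum('btm,mqkd->btqkd', x, w_q)
    k = np.einsum('bsm,mkd->bskd', x, w_k)
    v = np.einsum('bsm,mkd->bskd', x, w_v)
    logits = np.einsum('btqkd,bskd->btsqk', q, k)
    logits = np.where(mask[None, :, :, None, None], logits, -1e10)
    probs = softmax(logits, axis=2)  # normalize over the key axis S
    out = np.einsum('btsqk,bskd->btqkd', probs, v)
    return np.einsum('btqkd,mqkd->btm', out, w_o)  # contracts over the (local) K heads

x = rng.normal(size=(B, T, M))
w_q = rng.normal(size=(M, Q, K, D))
w_k = rng.normal(size=(M, K, D))
w_v = rng.normal(size=(M, K, D))
w_o = rng.normal(size=(M, Q, K, D))
causal = np.tril(np.ones((T, T), dtype=bool))

ref = attention(x, w_q, w_k, w_v, w_o, causal)

# Each t-device owns K / N_T KV heads, plus the Q heads grouped with them
partials = []
for i in range(N_T):
    local = [np.split(w_q, N_T, axis=2)[i], np.split(w_k, N_T, axis=1)[i],
             np.split(w_v, N_T, axis=1)[i], np.split(w_o, N_T, axis=2)[i]]
    partials.append(attention(x, *local, causal))  # attn_out[B, T, M]{U_t}

# ReduceScatter_t -> attn_out[B, T, M_t]
attn_out_shards = np.split(np.sum(partials, axis=0), N_T, axis=-1)
```

Heads never interact until the output projection, which is why sharding $K$ needs no communication until that final contraction.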

Post-Transformer

Finally, we project the activations back out to the vocabulary size to get the logits (and, from them, the token probabilities).

  1. Final LayerNorm: All-gather the residual over $t$ first
    • $x[B_d,T,M] = \text{AllGather}_t(x[B_d,T,M_t])$, then apply RMSNorm.
  2. Unembedding: Matmul to get logits (requires an AllGather of the unembedding weights over $d$).
    • $W_{unembed}[V_t, M] = \text{AllGather}_d(W_{unembed}[V_t, M_d])$
    • Contracting over the replicated $M$ leaves the vocab axis sharded over $t$: $logits[B_d,T,V_t] = x[B_d,T,M] \cdot W_{unembed}[V_t,M]^T$.
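  3. Compute the loss. Since the logits stay sharded as $V_t$, the softmax normalizer (a log-sum-exp over $V$) needs a reduction across $t$, unless the logits are all-gathered first.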
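A sketch of the loss step with the logits kept sharded as $V_t$. This is my own NumPy illustration, not seqax's code; the recipe follows the vocab-parallel cross-entropy idea from Megatron-LM: an all-reduce (max) over $t$ for a stable shift, an all-reduce (sum) of each shard's sum of exponentials, and an all-reduce (sum) in which only the shard owning each target id contributes that id's logit.

```python
import numpy as np

N_T = 4
B, T, V = 2, 3, 16
rng = np.random.default_rng(3)
logits = rng.normal(size=(B, T, V)) * 5.0
targets = rng.integers(0, V, size=(B, T))

shards = np.split(logits, N_T, axis=-1)  # logits[B, T, V_t]
v_per = V // N_T

# AllReduce(max) over t: a shared shift for a numerically stable log-sum-exp
global_max = np.max([s.max(axis=-1) for s in shards], axis=0)  # [B, T]
# AllReduce(sum) over t of each shard's sum of exponentials
sum_exp = np.sum([np.exp(s - global_max[..., None]).sum(axis=-1) for s in shards], axis=0)
lse = global_max + np.log(sum_exp)  # [B, T]

# AllReduce(sum) over t: only the shard that owns the target id contributes its logit
target_logit = np.zeros((B, T))
for i, s in enumerate(shards):
    local = targets - i * v_per
    owned = (local >= 0) & (local < v_per)
    picked = np.take_along_axis(s, np.clip(local, 0, v_per - 1)[..., None], axis=-1)[..., 0]
    target_logit += np.where(owned, picked, 0.0)

loss = np.mean(lse - target_logit)

# Unsharded reference
ref_lse = np.log(np.exp(logits).sum(axis=-1))
ref_loss = np.mean(ref_lse - np.take_along_axis(logits, targets[..., None], axis=-1)[..., 0])
```

Only $[B,T]$-sized tensors cross devices here, never the full $[B,T,V]$ logits.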

Part 2: how does this map to code? annotating seqax's forward pass

# Important: defines the sharding
@pytree_dataclass
class TransformerLayer:
  # LayerNorms are sharded over 2 axes
  ln1: f32['d_model/t/d']  # used pre-attention
  ln2: f32['d_model/t/d']  # used pre-FFN

  ### Attention
  # Note: w_q and w_kv defined separately
  # - Helps simplify usage when Q and KVs can have different lengths (T and S in the equations)
  # - E.g. when decoding with a KV cache, T=1
  # Input dimensions of QKV heads sharded over d, n_kv sharded over t
  # GQA: each KV head is shared by several Q heads, so w_q includes [n_q_per_kv, n_kv/t]
  w_q: f32['d_model/d n_q_per_kv n_kv/t d_head']
  w_kv: f32['2 d_model/d n_kv/t d_head']
  w_o: f32['d_model/d n_q_per_kv n_kv/t d_head']

  ### FFN
  # FFN's input_dim sharded over d, output dim sharded over t
  # -- this means it's using FSDP (d) and Megatron (t) parallelism
  # FFN uses SwiGLU, i.e. (Swish(x @ w_gate) * (x @ w_up)) @ w_down.T
  w_gate: f32['d_model/d d_ff/t']
  w_up: f32['d_model/d d_ff/t']
  w_down: f32['d_model/d d_ff/t']


@pytree_dataclass
class Model:
  # Each dimension of embedding is sharded (over separate axes)
  embed: f32['vocab/t d_model/d']
  unembed: f32['vocab/t d_model/d']
  transformer: Transformer
  # Again, LN is sharded over 2 axes
  final_layer_norm: f32['d_model/d/t']

  @typechecked
  # The ids are ofc sharded over batch dimension
  def forward_pass(self, h: Hparams, ids: u32[b'B/d L'], is_seq_start: bool_[b'B/d L']) -> f32[b'B/d L V/t']:

    ### Initial embedding lookup
    embed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(self.embed))
    x = shardops.index_unreduced('[V/t] M, B/d L -> B/d L M', embed, ids)
    x = shardops.psum_scatter('B/d L M -> B/d L M/t', x)

    L = ids.shape[1]
    # Sequence packing
    # `is_seq_start` flags the first token of each new document
    # - hence segment_ids will be [1, 1, 2, 2, 2, 3, ...]
    segment_ids = jnp.cumsum(is_seq_start, axis=1)
    segment_mask: bool_[b'B/d L L'] = segment_ids[:, :, jnp.newaxis] == segment_ids[:, jnp.newaxis, :]
    # add axes for q_per_k, num_kv_heads dimensions
    segment_mask: bool_[b'B/d L L 1 1'] = segment_mask[..., jnp.newaxis, jnp.newaxis]
    causal_mask: bool_[b'1 L L 1 1'] = jnp.tril(jnp.ones((L, L), dtype=jnp.bool_), 0)[jnp.newaxis, ..., jnp.newaxis, jnp.newaxis]
    causal_mask: bool_[b'B/d L L 1 1'] = jnp.logical_and(segment_mask, causal_mask)

    rope_table = RopeTable.create(L, h)

    ### Transformer blocks
    @explicit_activation_checkpointing
    @typechecked
    def loop_body(x: bf16[b'B/d L M/t'], layer_weights: TransformerLayer) -> Tuple[bf16[b'B/d L M/t'], Tuple[()]]:
      # x = [bsz (sharded over d), length, dim=M (sharded over t)]

      ### Pre-attention RMSNorm
      # FSDP's all_gather() step
      ln1 = shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln1))
      gx = shardops.all_gather('B/d L M/t -> B/d L M', x)
      # Multiply by ln1, as it's the \gamma in LayerNorm / RMSNorm
      nx = jnp.bfloat16(rms_norm(gx) * ln1)  # EC: [B/d L M]
      # Notice input (nx) still sharded over batch dimension for data parallelism

      ### Attention, using Grouped Query Attention and RoPE position embeddings.
      # GQA: e.g. n_kv=4 heads, n_q=8 heads --> n_q_per_kv=2 (2 Q heads per KV head)
      # - shorthand Q := n_q_per_kv (defined in TransformerLayer class), and K := n_kv

      # FSDP's all_gather() step before local matmul
      w_q = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_q))
      # save_for_backward() marks an activation to keep (rather than recompute) for the backward pass under activation checkpointing
      q = save_for_backward(shardops.einsum_unreduced('B/d L M, M Q K/t D -> B/d L Q K/t D', nx, w_q))
      q = rope_table.apply('L D -> 1 L 1 1 D', q)

      # FSDP's all_gather() step before local matmul
      w_kv = shardops.all_gather('2 M/d K/t D -> 2 M K/t D', jnp.bfloat16(layer_weights.w_kv))
      k, v = shardops.einsum_unreduced('B/d L M, k_v M K/t D -> k_v B/d L K/t D', nx, w_kv)
      k = save_for_backward(k)
      v = save_for_backward(v)
      k = rope_table.apply('L d -> 1 L 1 d', k)

      # Because we're doing GQA, the logits include [Q K/t]
      # - basically just the attention between the i-th Q and j-th KV
      # Note: Use float32 for logits
      logits = shardops.einsum_unreduced(
        'B/d Qlen Q K/t D, B/d Klen K/t D -> B/d Qlen Klen Q K/t', q, k, preferred_element_type=jnp.float32)
      logits = jnp.where(causal_mask, logits, -1e10)
      # Note: doesn't seem to have the division by sqrt(d)?
      probs = jnp.bfloat16(jax.nn.softmax(logits, axis=2))
      # QK @ V for final attention embeddings
      attn_out = shardops.einsum_unreduced(
        'B/d Qlen Klen Q K/t, B/d Klen K/t D -> B/d Qlen Q K/t D', probs, v)

      # FSDP's all_gather() step before local matmul
      w_o = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_o))
      attn_out = shardops.einsum_unreduced('B/d Qlen Q K/t D, M Q K/t D -> B/d Qlen M', attn_out, w_o)
      # Note: we are doing a matmul where a contracting dimension is sharded (K/t),
      # - Thus, we'll need to do a reduce-scatter after; this is Megatron pattern
      attn_out = shardops.psum_scatter('B/d Qlen M -> B/d Qlen M/t', attn_out)

      x = save_for_backward(x + attn_out)

      ### FFN, using SwiGLU

      # Pre-FFN RMSNorm
      # Same RMSNorm as pre-attention
      ln2 = save_for_backward(shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln2)))
      gx = shardops.all_gather('B/d L M/t -> B/d L M', x)
      nx = jnp.bfloat16(rms_norm(gx) * ln2)

      # FSDP's all_gather() step for three weights, over input_dim (M/d -> M), before local matmul
      # Note: these are the largest weight matrices, and their [B/d L F/t] outputs are typically the largest activations, so what we save here dominates activation-checkpointing memory
      # - I think it's not uncommon to only save_for_backward(gate_proj, up_proj)?
      w_gate = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_gate))
      gate_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_gate))
      w_up = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_up))
      up_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_up))
      y = jax.nn.swish(gate_proj) * up_proj  # EC: this * is element wise, so [B/d L F/t]
      w_down = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_down))
      # Note: we are doing a matmul where a contracting dimension is sharded (F/t),
      # - Thus, we'll need to do a reduce-scatter after; this is Megatron pattern
      ffn_out = shardops.einsum_unreduced('B/d L F/t, M F/t -> B/d L M', y, w_down)
      ffn_out = shardops.psum_scatter('B/d L M -> B/d L M/t', ffn_out)

      return jnp.bfloat16(x + ffn_out), ()

    x, () = jax.lax.scan(loop_body, jnp.bfloat16(x), self.transformer)

    ### Final layernorm and output projection.
    # all_gather to prepare for layer norm (actually RMSNorm)
    x = shardops.all_gather('B/d L M/t -> B/d L M', x)
    ln = shardops.all_gather('M/t/d -> M', jnp.float32(self.final_layer_norm))
    x = jnp.bfloat16(rms_norm(x) * ln)
    # all_gather to prepare for unembed
    unembed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(self.unembed))
    logits = shardops.einsum_unreduced('B/d L M, V/t M -> B/d L V/t', x, unembed, preferred_element_type=jnp.float32)

    return logits
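
The sequence-packing mask built near the top of forward_pass is easy to check in isolation. This is a NumPy re-creation of those lines (np instead of jnp, and without the two trailing size-1 head axes) on a single packed row holding three documents:

```python
import numpy as np

# One packed row (B=1, L=6): documents start at positions 0, 2 and 5
is_seq_start = np.array([[1, 0, 1, 0, 0, 1]], dtype=bool)
L = is_seq_start.shape[1]

segment_ids = np.cumsum(is_seq_start, axis=1)                      # [[1, 1, 2, 2, 2, 3]]
segment_mask = segment_ids[:, :, None] == segment_ids[:, None, :]  # [B, L, L]
causal = np.tril(np.ones((L, L), dtype=bool))[None]                # [1, L, L]
mask = np.logical_and(segment_mask, causal)  # mask[b, i, j]: query i may attend to key j
```

Query 3 sits in the second document (positions 2 to 4), so it may attend only to keys 2 and 3: the segment mask blocks the first document and the causal mask blocks key 4.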

Part 3: a distributed LLM in equation form (backward pass)

Transformer: FFN

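  1. Initial state: Assume we are given $grad\_x[B_d,T,M]$, the upstream gradient with respect to the FFN's output before its reduce-scatter. The backward of $\text{ReduceScatter}_t$ is $\text{AllGather}_t$, so this is just the residual-stream gradient $grad\_x[B_d,T,M_t]$ all-gathered over $t$.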
  2. Compute gradients for down projection (wrt $W_{down}[M_d,F_t]$ and $y[B_d,T,F_t]$):
    • Gradient wrt $y[B_d,T,F_t]$:
      $W_{down}[M,F_t] = \text{AllGather}_d(W_{down}[M_d,F_t])$
      $grad\_y[B_d,T,F_t] = grad\_x[B_d,T,M] \cdot W_{down}[M,F_t]$
      (This $grad\_y$ now becomes the upstream gradient for the first part of the FFN.)
    • Gradient wrt $W_{down}[M_d,F_t]$ (contracting over $B$ and $T$):
      $grad\_w\_down[M,F_t]\{U_d\} = grad\_x[B_d,T,M]^T \cdot y[B_d,T,F_t]$
      Since the contraction runs over the $d$-sharded batch dimension, the result is unreduced over $d$ and needs a reduce-scatter:
      $grad\_w\_down[M_d,F_t] = \text{ReduceScatter}_d(grad\_w\_down[M,F_t]\{U_d\})$
  3. Compute gradients wrt SwiGLU ($\text{Swish}(\text{gate\_proj}) \odot \text{up\_proj}$):
    • Let $\text{swish\_grad}(z) = \sigma(z)\,(1 + z\,(1 - \sigma(z)))$ denote the element-wise derivative of $\text{Swish}(z) = z\,\sigma(z)$, where $\sigma$ is the sigmoid.
    • $grad\_gate\_proj[B_d,T,F_t] = grad\_y[B_d,T,F_t] \odot \text{swish\_grad}(\text{gate\_proj}[B_d,T,F_t]) \odot \text{up\_proj}[B_d,T,F_t]$
    • $grad\_up\_proj[B_d,T,F_t] = grad\_y[B_d,T,F_t] \odot \text{Swish}(\text{gate\_proj}[B_d,T,F_t])$
  4. Compute gradients for gate projection (wrt $W_{gate}[M_d,F_t]$ and $nx[B_d,T,M]$):
    • Gradient wrt $W_{gate}[M_d,F_t]$:
      $grad\_w\_gate[M,F_t]\{U_d\} = nx[B_d,T,M]^T \cdot grad\_gate\_proj[B_d,T,F_t]$
      $grad\_w\_gate[M_d,F_t] = \text{ReduceScatter}_d(grad\_w\_gate[M,F_t]\{U_d\})$
    • Gradient wrt $nx[B_d,T,M]$ (from gate branch):
      $W_{gate}[M,F_t] = \text{AllGather}_d(W_{gate}[M_d,F_t])$
      $grad\_nx_{gate}[B_d,T,M]\{U_t\} = grad\_gate\_proj[B_d,T,F_t] \cdot W_{gate}[M,F_t]^T$
      $grad\_nx_{gate}[B_d,T,M] = \text{AllReduce}_t(grad\_nx_{gate}[B_d,T,M]\{U_t\})$
  5. Compute gradients for up projection (wrt $W_{up}[M_d,F_t]$ and $nx[B_d,T,M]$):
    • Gradient wrt $W_{up}[M_d,F_t]$:
      $grad\_w\_up[M,F_t]\{U_d\} = nx[B_d,T,M]^T \cdot grad\_up\_proj[B_d,T,F_t]$
      $grad\_w\_up[M_d,F_t] = \text{ReduceScatter}_d(grad\_w\_up[M,F_t]\{U_d\})$
    • Gradient wrt $nx[B_d,T,M]$ (from up branch):
      $W_{up}[M,F_t] = \text{AllGather}_d(W_{up}[M_d,F_t])$
      $grad\_nx_{up}[B_d,T,M]\{U_t\} = grad\_up\_proj[B_d,T,F_t] \cdot W_{up}[M,F_t]^T$
      $grad\_nx_{up}[B_d,T,M] = \text{AllReduce}_t(grad\_nx_{up}[B_d,T,M]\{U_t\})$
  6. Combine gradients:
    • Because the forward pass split the input $nx$ into two parallel paths (gate and up), the backward pass must sum their respective gradients:
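    • $grad\_nx[B_d,T,M] = grad\_nx_{gate}[B_d,T,M] + grad\_nx_{up}[B_d,T,M]$
  7. Residual connection: backpropagate $grad\_nx$ through the pre-FFN RMSNorm, and add the result to $grad\_x$, because $x$ fed both the FFN and the skip connection.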
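These equations can be sanity-checked numerically. The NumPy sketch below implements steps 2 to 6 (with $F$ split over a simulated $t$ axis for $grad\_nx$, and the $d$ axis left out) and compares them against central finite differences of the scalar $\sum(\text{ffn\_out} \odot grad\_x)$, whose gradient with respect to $\text{ffn\_out}$ is exactly $grad\_x$. The helper `numeric_grad` is my own.

```python
import numpy as np

N_T = 2  # simulated tensor-parallel axis t
B, T, M, F = 2, 3, 4, 6
rng = np.random.default_rng(4)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def swish(z):
    return z * sigmoid(z)

def swish_grad(z):
    s = sigmoid(z)
    return s * (1.0 + z * (1.0 - s))  # derivative of z * sigmoid(z)

def ffn(nx, w_gate, w_up, w_down):
    return (swish(nx @ w_gate) * (nx @ w_up)) @ w_down.T  # W_down stored as [M, F]

nx = rng.normal(size=(B, T, M))
w_gate, w_up, w_down = (rng.normal(size=(M, F)) for _ in range(3))
grad_x = rng.normal(size=(B, T, M))  # upstream gradient w.r.t. ffn_out

gate_proj, up_proj = nx @ w_gate, nx @ w_up  # forward activations reused below

# Steps 2-5
grad_y = grad_x @ w_down                                                    # [B, T, F]
grad_w_down = np.einsum('btm,btf->mf', grad_x, swish(gate_proj) * up_proj)  # contracts B, T
grad_gate_proj = grad_y * swish_grad(gate_proj) * up_proj
grad_up_proj = grad_y * swish(gate_proj)
grad_w_gate = np.einsum('btm,btf->mf', nx, grad_gate_proj)

# Steps 4-6 for nx with F split over t: each device's two-branch sum is unreduced over t,
# and AllReduce_t adds the partials
slices = zip(np.split(grad_gate_proj, N_T, axis=-1), np.split(w_gate, N_T, axis=1),
             np.split(grad_up_proj, N_T, axis=-1), np.split(w_up, N_T, axis=1))
grad_nx = np.sum([gg @ wg.T + gu @ wu.T for gg, wg, gu, wu in slices], axis=0)

def numeric_grad(f, arr, eps=1e-6):
    """Central differences of scalar f() w.r.t. each entry of arr (mutated, then restored)."""
    g = np.zeros_like(arr)
    for idx in np.ndindex(arr.shape):
        old = arr[idx]
        arr[idx] = old + eps
        up = f()
        arr[idx] = old - eps
        down = f()
        arr[idx] = old
        g[idx] = (up - down) / (2 * eps)
    return g

def loss():
    # Scalar whose gradient w.r.t. ffn_out is exactly grad_x
    return np.sum(ffn(nx, w_gate, w_up, w_down) * grad_x)
```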

Part 4: implementing and annotating a backward pass

The following pattern is repeated over and over: all_gather --> local einsum --> psum_scatter (i.e. a reduce-scatter)
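Here is that pattern in isolation for a single linear layer, as a NumPy sketch with a simulated $d$ axis (plain lists of per-device arrays; an assumption for illustration). Both the batch and the weight's $M$ axis are sharded over $d$: the weight is all-gathered for the local einsum, and its gradient, which is unreduced over $d$ because the contraction runs over the $d$-sharded batch, is reduce-scattered back to $M_d$.

```python
import numpy as np

N_D = 4
B, M, F = 8, 8, 5
rng = np.random.default_rng(5)

x = rng.normal(size=(B, M))         # activations x[B, M]
w = rng.normal(size=(M, F))         # weight W[M, F]
grad_out = rng.normal(size=(B, F))  # upstream gradient dL/d(x @ W)

x_shards = np.split(x, N_D, axis=0)         # x[B_d, M]: data parallel
w_shards = np.split(w, N_D, axis=0)         # W[M_d, F]: FSDP storage
g_shards = np.split(grad_out, N_D, axis=0)  # grad_out[B_d, F]

# 1. all_gather: every device rebuilds the full W[M, F] just for this layer
w_full = np.concatenate(w_shards, axis=0)

# 2. local einsum: the forward output and this device's share of dL/dW
outs = [xs @ w_full for xs in x_shards]                               # out[B_d, F]
grad_w_unreduced = [xs.T @ gs for xs, gs in zip(x_shards, g_shards)]  # dW[M, F]{U_d}

# 3. psum_scatter: sum the partial gradients over d and keep only this device's M_d slice
grad_w_shards = np.split(np.sum(grad_w_unreduced, axis=0), N_D, axis=0)  # dW[M_d, F]
```

Each device only ever stores its $M_d$ slice of both the weight and its gradient; the full $W$ exists transiently, for one layer at a time.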


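from typing import List

import jax
import jax.numpy as jnp
import shardops
# Model, Hparams, pytree_dataclass, and the f32/u32/bool_ shape types are assumed to come from seqax (imports omitted)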

# These should match the dimensions of the weights ofc
@pytree_dataclass
class GradTransformerLayer:
    grad_ln1: f32['d_model/t/d']
    grad_ln2: f32['d_model/t/d']
    grad_w_q: f32['d_model/d n_q_per_kv n_kv/t d_head']
    grad_w_kv: f32['2 d_model/d n_kv/t d_head']
    grad_w_o: f32['d_model/d n_q_per_kv n_kv/t d_head']
    grad_w_gate: f32['d_model/d d_ff/t']
    grad_w_up: f32['d_model/d d_ff/t']
    grad_w_down: f32['d_model/d d_ff/t']

@pytree_dataclass
class GradModel:
    grad_embed: f32['vocab/t d_model/d']
    grad_unembed: f32['vocab/t d_model/d']
    grad_transformer: List[GradTransformerLayer]
    grad_final_layer_norm: f32['d_model/d/t']

def backward_pass(model: Model, h: Hparams, ids: u32[b'B/d L'], is_seq_start: bool_[b'B/d L'], 
                  output_grad: f32[b'B/d L V/t']) -> GradModel:
    # output_grad is the "upstream gradient". We calculate "local gradients" (at a current
    # operation) and then compute downstream gradients as output_grad * local gradient
    # - output_grad has the same shape as the logits, [B/d L V/t]

    # Initialize gradient structures to store computed gradients
    grad_model = GradModel(
        grad_embed=jnp.zeros_like(model.embed),
        grad_unembed=jnp.zeros_like(model.unembed),
        grad_transformer=[GradTransformerLayer(...) for _ in range(len(model.transformer))],
        grad_final_layer_norm=jnp.zeros_like(model.final_layer_norm)
    )

    # un-embedding
    # EC: gradients computed here (not in a separate helper function since it's just x * unembed)
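    # EC: the all_gather of unembed is only needed for grad_x, which contracts over V/t and needs full M rows
    # - grad_unembed contracts over the d-sharded batch, so it is unreduced over d; psum_scatter both sums
    #   it over d and re-shards it to M/d (the FSDP reduce-scatter)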
    # Gather the unembed matrix across devices for the backward pass
    unembed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(model.unembed))
    # compute gradient w.r.t. x (input to unembed): logits = x @ unembed^T, so dL/dx = output_grad @ unembed
    # - contracting over V/t leaves grad_x unreduced over t
    grad_x = shardops.einsum_unreduced('B/d L V/t, V/t M -> B/d L M', output_grad, unembed)
    # compute gradient w.r.t. unembed weights ([V, M]): dL/dunembed = output_grad^T @ x
    grad_unembed = shardops.einsum_unreduced('B/d L M, B/d L V/t -> V/t M', x, output_grad)
    # scatter the unembed gradients back to their respective devices
    grad_model.grad_unembed = shardops.psum_scatter('V/t M -> V/t M/d', grad_unembed)

    # RMSNorm, assuming a rms_norm_backward fn exists, producing the local gradient
    # Pattern: all_gather --> local einsum --> psum_scatter
    ln = shardops.all_gather('M/t/d -> M', jnp.float32(model.final_layer_norm))
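    grad_ln, grad_x = rms_norm_backward(grad_x, x, ln)
    # grad_x is still unreduced over t (V/t was contracted above), which is fine for the linear RMSNorm
    # backward. Reduce-scatter it over t (the transpose of the forward all_gather), then all_gather
    # so that loop_body receives a full, replicated [B/d L M]
    grad_x = shardops.psum_scatter('B/d L M -> B/d L M/t', grad_x)
    grad_x = shardops.all_gather('B/d L M/t -> B/d L M', grad_x)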
    # scatter the final layer norm gradients back to their respective devices
    grad_model.grad_final_layer_norm = shardops.psum_scatter('M -> M/t/d', grad_ln)

    def loop_body(grad_x, layer_idx):
        # EC: grad_x must be [B/d L M] here, fully reduced over t (i.e. replicated across t)
        # - Each layer's forward output is [B/d L M/t], produced by psum_scatter('B/d L M -> B/d L M/t')
        # - The transpose of a psum_scatter is an all_gather, so the gradient w.r.t. the unreduced
        #   [B/d L M] ffn_out is just the [B/d L M/t] residual gradient all-gathered over t
        # - Carrying the all-gathered version through the loop is equivalent; it just uses more memory

        layer = model.transformer[layer_idx]
        grad_layer = grad_model.grad_transformer[layer_idx]

        ### FFN's W_down
        # Gather the down-projection weights of the FFN
        w_down = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer.w_down))
        # Compute gradient w.r.t. y (input to down-projection): ffn_out = y @ w_down^T with w_down stored
        # as [M, F], so dL/dy = grad_x @ w_down
        grad_y = shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', grad_x, w_down)
        # Compute gradient w.r.t. w_down: dL/dw_down = grad_x^T @ y, summed over B and L
        # Note: assumes activations (e.g. y) are available/checkpointed from forward pass
        grad_w_down = shardops.einsum_unreduced('B/d L F/t, B/d L M -> M F/t', y, grad_x)
        # scatter w_down gradients back to their respective devices
        # - remember, contracting dimension is sharded (B/d) so need to reduce_scatter
        grad_layer.grad_w_down = shardops.psum_scatter('M F/t -> M/d F/t', grad_w_down)

        ### FFN (SwiGLU)'s W_gate and W_up
        # Gather the up-projection and gate weights of the FFN
        w_up = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer.w_up))
        w_gate = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer.w_gate))
        # Compute gradients for the SwiGLU activation
        grad_gate_proj, grad_up_proj = swish_backward(grad_y, gate_proj, up_proj)
        # Compute gradients w.r.t. w_up and w_gate
        grad_w_up = shardops.einsum_unreduced('B/d L M, B/d L F/t -> M F/t', nx, grad_up_proj)
        grad_w_gate = shardops.einsum_unreduced('B/d L M, B/d L F/t -> M F/t', nx, grad_gate_proj)        
        # Scatter w_up and w_gate gradients back to their respective devices
        grad_layer.grad_w_up = shardops.psum_scatter('M F/t -> M/d F/t', grad_w_up)
        grad_layer.grad_w_gate = shardops.psum_scatter('M F/t -> M/d F/t', grad_w_gate)

        ### Gradient w.r.t. nx
        # Compute gradient w.r.t. nx (input to FFN): dL/dnx = dL/dy_up * dy_up/dnx + dL/dy_gate * dy_gate/dnx
        # - Since nx feeds 2 different branches in the forward pass, its gradient is the sum over both paths (think backprop circuit analysis)
        # - Each einsum contracts over F/t, so grad_nx is unreduced over t
        grad_nx = (shardops.einsum_unreduced('B/d L F/t, M F/t -> B/d L M', grad_up_proj, w_up) + 
                   shardops.einsum_unreduced('B/d L F/t, M F/t -> B/d L M', grad_gate_proj, w_gate))

        ### pre-FFN LN
        # Gather the pre-FFN layer norm weights
        ln2 = shardops.all_gather('M/t/d -> M', jnp.float32(layer.ln2))
        # Compute gradients for the pre-FFN layer norm
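        # x here is the FFN's input, i.e. the residual stream after attention (checkpointed in the forward)
        grad_ln2, grad_x_ffn = rms_norm_backward(grad_nx, x, ln2)
        # grad_x_ffn is unreduced over t (F/t was contracted above): sum it over t, then add it to the
        # residual-path gradient, because x fed both the FFN and the skip connection
        grad_x_ffn = shardops.all_gather('B/d L M/t -> B/d L M',
                                         shardops.psum_scatter('B/d L M -> B/d L M/t', grad_x_ffn))
        grad_x = grad_x + grad_x_ffn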
        # Scatter pre-FFN layer norm gradients back to their respective devices
        grad_layer.grad_ln2 = shardops.psum_scatter('M -> M/t/d', grad_ln2)

        ### attn_out @ w_out
        # Gather the output projection weights of the attention mechanism
        w_o = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer.w_o))
        # Compute gradient w.r.t. attention output: dL/dattn_out = dL/dx * dx/dattn_out = grad_x * w_o^T
        grad_attn_out = shardops.einsum_unreduced('B/d Qlen M, M Q K/t D -> B/d Qlen Q K/t D', grad_x, w_o)
        # Compute gradient w.r.t. w_o: dL/dw_o = attn_out^T * dL/dx
        grad_w_o = shardops.einsum_unreduced('B/d Qlen Q K/t D, B/d Qlen M -> M Q K/t D', attn_out, grad_x)
        # Scatter w_o gradients back to their respective devices
        grad_layer.grad_w_o = shardops.psum_scatter('M Q K/t D -> M/d Q K/t D', grad_w_o)

        ### attn_probs and Q K activations
        # Note: no actual *weight* gradients (just intermediate grads used downstream), so no psum_scatter
        # Compute gradient w.r.t. attention probabilities: dL/dprobs = dL/dattn_out * v^T
        grad_probs = shardops.einsum_unreduced('B/d Qlen Q K/t D, B/d Klen K/t D -> B/d Qlen Klen Q K/t', grad_attn_out, v)
        # Compute gradient w.r.t. values: dL/dv = probs^T * dL/dattn_out
        grad_v = shardops.einsum_unreduced('B/d Qlen Klen Q K/t, B/d Qlen Q K/t D -> B/d Klen K/t D', probs, grad_attn_out)
        # Compute gradient w.r.t. logits (pre-softmax): dL/dlogits = dL/dprobs * dprobs/dlogits
        grad_logits = softmax_backward(grad_probs, probs)
        # Apply causal mask to logits gradients
        grad_logits = jnp.where(causal_mask, grad_logits, 0)
        # Compute gradient w.r.t. queries: dL/dq = dL/dlogits * k^T
        grad_q = shardops.einsum_unreduced('B/d Qlen Klen Q K/t, B/d Klen K/t D -> B/d Qlen Q K/t D', grad_logits, k)
        # Compute gradient w.r.t. keys: dL/dk = q^T * dL/dlogits
        grad_k = shardops.einsum_unreduced('B/d Qlen Klen Q K/t, B/d Qlen Q K/t D -> B/d Klen K/t D', grad_logits, q)

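        ### RoPE
        # RoPE rotates pairs of head dims by a position-dependent angle; a rotation's transpose is its
        # inverse, so the backward applies the inverse rotation (apply_inverse is an assumed helper)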
        grad_q = rope_table.apply_inverse('L D -> 1 L 1 1 D', grad_q)
        grad_k = rope_table.apply_inverse('L d -> 1 L 1 d', grad_k)

        ### w_q and w_kv
        # Gather query and key-value projection weights
        w_q = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer.w_q))
        w_kv = shardops.all_gather('2 M/d K/t D -> 2 M K/t D', jnp.bfloat16(layer.w_kv))
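        # Compute gradient w.r.t. the pre-attention nx (a different tensor from the FFN's nx):
        # dL/dnx = dL/dq * w_q^T + dL/dk * w_k^T + dL/dv * w_v^T
        # - start fresh rather than accumulating onto the FFN's grad_nx; then "+=" the k/v paths
        # - contracting over K/t leaves grad_nx unreduced over t
        grad_nx = shardops.einsum_unreduced('B/d L Q K/t D, M Q K/t D -> B/d L M', grad_q, w_q)
        grad_nx += shardops.einsum_unreduced('k_v B/d L K/t D, k_v M K/t D -> B/d L M', jnp.stack([grad_k, grad_v]), w_kv)
        # Compute gradients w.r.t. w_q and w_kv (nx here is the pre-attention nx from the forward pass)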
        grad_w_q = shardops.einsum_unreduced('B/d L M, B/d L Q K/t D -> M Q K/t D', nx, grad_q)
        grad_w_kv = shardops.einsum_unreduced('B/d L M, k_v B/d L K/t D -> k_v M K/t D', nx, jnp.stack([grad_k, grad_v]))
        # Scatter w_q and w_kv gradients back to their respective devices
        grad_layer.grad_w_q = shardops.psum_scatter('M Q K/t D -> M/d Q K/t D', grad_w_q)
        grad_layer.grad_w_kv = shardops.psum_scatter('k_v M K/t D -> k_v M/d K/t D', grad_w_kv)

        ### Pre-attention layer norm
        # Gather the pre-attention layer norm weights
        ln1 = shardops.all_gather('M/t/d -> M', jnp.float32(layer.ln1))
        # Compute gradients for the pre-attention layer norm
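        # x here is the layer's input (checkpointed in the forward)
        grad_ln1, grad_x_attn = rms_norm_backward(grad_nx, x, ln1)
        # As in the FFN: reduce over t (K/t was contracted above), then add to the residual-path gradient
        grad_x_attn = shardops.all_gather('B/d L M/t -> B/d L M',
                                          shardops.psum_scatter('B/d L M -> B/d L M/t', grad_x_attn))
        grad_x = grad_x + grad_x_attn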
        # Scatter pre-attention layer norm gradients back to their respective devices
        grad_layer.grad_ln1 = shardops.psum_scatter('M -> M/t/d', grad_ln1)

        return grad_x, None

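    ### Apply the backward pass through all transformer layers in reverse order
    # Note: this is pseudocode. Under jax.lax.scan, layer_idx is a traced value, so indexing a Python
    # list and mutating grad_layer won't work; a real version would scan over the stacked layer weights
    # with reverse=True and return each layer's gradients as the scan's per-step output
    grad_x, _ = jax.lax.scan(loop_body, grad_x, jnp.arange(len(model.transformer) - 1, -1, -1))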

    ### Initial vocab embedding layer
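    # A lookup's gradient doesn't depend on the embedding values, so this all_gather only
    # provides the [V/t M] shape for the accumulator below
    embed = shardops.all_gather('V/t M/d -> V/t M', jnp.bfloat16(model.embed))
    # Compute gradient w.r.t. embeddings: dL/dembed = sum(dL/dx * one_hot(ids))
    # - index_add is assumed to scatter-add each token's grad_x row into its vocab row, for the rows
    #   in this device's V/t shard; summing over B/d leaves the result unreduced over d
    grad_embed = shardops.index_add('V/t M, B/d L -> V/t M', jnp.zeros_like(embed), ids, grad_x)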
    # Scatter embedding gradients back to their respective devices
    grad_model.grad_embed = shardops.psum_scatter('V/t M -> V/t M/d', grad_embed)

    return grad_model
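
The backward pass above assumes helpers such as `rms_norm_backward(grad, x, gamma)` exist. Here is one possible NumPy version, checked against finite differences; the argument order and the `(grad_gamma, grad_x)` return order follow the calls above, while the body and the epsilon value are my own assumptions. The last lines show that it is linear in the upstream gradient, which is why the code can run it on a gradient that is still unreduced over $t$ and reduce afterwards.

```python
import numpy as np

EPS = 1e-6  # assumed epsilon inside the RMS

def rms_norm(x):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + EPS)

def rms_norm_backward(grad, x, gamma):
    """Backward of y = rms_norm(x) * gamma. Returns (grad_gamma, grad_x)."""
    r = 1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + EPS)
    x_hat = x * r
    h = grad * gamma
    grad_gamma = np.sum(grad * x_hat, axis=tuple(range(grad.ndim - 1)))  # sum over B and L
    grad_x = r * (h - x_hat * np.mean(h * x_hat, axis=-1, keepdims=True))
    return grad_gamma, grad_x

rng = np.random.default_rng(6)
x = rng.normal(size=(2, 3, 8))
gamma = rng.normal(size=8)
grad = rng.normal(size=(2, 3, 8))
grad_gamma, grad_x = rms_norm_backward(grad, x, gamma)

def numeric_grad(f, arr, eps=1e-6):
    """Central differences of scalar f() w.r.t. each entry of arr (mutated, then restored)."""
    g = np.zeros_like(arr)
    for idx in np.ndindex(arr.shape):
        old = arr[idx]
        arr[idx] = old + eps
        up = f()
        arr[idx] = old - eps
        down = f()
        arr[idx] = old
        g[idx] = (up - down) / (2 * eps)
    return g

def loss():
    return np.sum(rms_norm(x) * gamma * grad)

# Linearity in `grad`: the backward of two t-partials, summed, equals the backward of their sum
part_a = rng.normal(size=grad.shape)
part_b = grad - part_a
ga, xa = rms_norm_backward(part_a, x, gamma)
gb, xb = rms_norm_backward(part_b, x, gamma)
```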