Layer Sharding for Large-Scale Training with Muon

May 15, 2025

There's a lot of excitement around the Muon optimizer as a replacement for Adam and AdamW, but it's challenging to scale it up to larger training runs. As we showed in our recent paper, Muon confers batch-size advantages beyond AdamW, offering a much wider range of resources you can deploy on the pre-training workload for a better time to target loss. In this post, we describe our sharding strategy, which scales Muon further than the others we've seen discussed.

Muon has a larger overhead than Adam

The final test of an optimizer is how fast it reaches a given loss. Traditionally, we think of this in terms of tokens or, equivalently, flops, and Muon indeed uses fewer tokens and flops to reach that loss. However, the true cost of training a model is hours of compute. With Muon, you can use larger batches, and while at small scale you might not notice it, at large scale the bigger batch size translates into significantly better MFU.

However, Muon requires significantly more calculations for parameter updates than Adam, and in some configurations, this cost can be a significant portion of the total training time.

Suppose we have a tensor X with dimensions (L, P, Q), where L is the number of layers and P and Q are the "in" and "out" dimensions of the weights. For the largest weight tensors (the MLP weights), P and Q are the model dimension d and 4d, respectively. We want to compute the following Newton-Schulz iteration (omitting several elementwise operations that are not expensive):

for _ in range(5):
  A = X @ X^T # (L, P, Q) x (L, Q, P) => (L, P, P) with cost 2LPQP flops
  B = A @ A   # (L, P, P) x (L, P, P) => (L, P, P) with cost 2LPPP flops
  X = B @ X   # (L, P, P) x (L, P, Q) => (L, P, Q) with cost 2LPQP flops

This costs 5*2LPP(2Q + P) flops in total, which (since P <= Q) is sometimes rounded up to 30LPPQ flops.
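
For reference, here is a minimal runnable JAX sketch of this loop, batched over the layer dimension. The function name, the initial normalization, and the quintic coefficients are our own additions (the coefficients are the commonly published Newton-Schulz constants, which we otherwise lump into the cheap elementwise work); none of them change the flop count above.

import jax.numpy as jnp

def newton_schulz(X, steps=5, eps=1e-7):
  # X: (L, P, Q) stack of per-layer matrices, with P <= Q assumed.
  a, b, c = 3.4445, -4.7750, 2.0315  # commonly used quintic coefficients (assumed)
  # Normalize each layer's matrix so the iteration converges.
  X = X / (jnp.linalg.norm(X, axis=(-2, -1), keepdims=True) + eps)
  for _ in range(steps):
    A = X @ jnp.swapaxes(X, -1, -2)  # (L, P, P), 2LPQP flops
    B = b * A + c * (A @ A)          # (L, P, P), 2LPPP flops
    X = a * X + B @ X                # (L, P, Q), 2LPQP flops
  return X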

By comparison, the forward and backward passes for this same weight tensor cost 6LPQB flops, where B is the number of tokens in the batch. This means the ratio of Muon flops to model flops (not counting attention) is 30LPPQ/6LPQB == 5P/B, where P is no larger than (and generally close to) d. So the equation to remember is that Muon costs 5d/B extra flops for every model flop.
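
To make that concrete, here is the back-of-the-envelope arithmetic with illustrative numbers: d = 16,384 is Llama 405B's model dimension, and the 16M-token batch is an assumption for the sake of the example.

d = 16_384        # model dimension (Llama 405B's hidden size)
B = 16_000_000    # tokens per batch (illustrative)
print(5 * d / B)  # ~0.005: Muon adds about 0.5% extra flops per model flop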

This leads to claims like this:

[Chart: runtime analysis of the Muon optimizer]

However, this assumes we can perfectly shard the Muon computation across all 16,384 GPUs used to train the model. That's unrealistic, but let's get as close as possible. For reference, Llama 405B used 8-way Tensor Parallelism (TP), 16-way Pipeline Parallelism (PP), and 128-way Fully-Sharded Data Parallelism (FSDP) for the bulk of its training. If, for example, we end up having to duplicate the computation across the FSDP dimension, that takes our cost from 0.5% to 64% of the model flops (0.5% duplicated 128 times), which is likely too expensive to use.

We have a few options for how to shard this. First, observe that PP is a free way to shard the Muon computation: L is a batch dimension, so we can operate on it in parallel. However, FSDP typically shards the P dimension and TP typically shards the Q dimension, and those are contracting dimensions in these matmuls, so sharding them can introduce expensive communication overheads.
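
To make the sketches later in this post concrete, assume a JAX device mesh with one named axis per form of parallelism. The axis names ('pp', 'fsdp', 'tp') and the layout below are illustrative choices for exposition, not something any framework requires:

import numpy as np
import jax
from jax.sharding import Mesh

# Hypothetical mesh; Llama 405B's configuration would be pp=16, fsdp=128, tp=8.
# Here we simply put every available device on the 'fsdp' axis so the code runs anywhere.
devices = np.array(jax.devices()).reshape(1, -1, 1)  # (pp, fsdp, tp)
mesh = Mesh(devices, axis_names=('pp', 'fsdp', 'tp'))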

Strategy #1: Replicated computation

Suppose the tensor X is sharded S ways across the P dimension (the Q case is very similar). In Llama 405B, S is between 16 and 128 for the various weights. Then, one option (as described in this paper) is to all-gather X and replicate all the computation on each device (except that for the very last matmul you can save a few flops by computing only your shard):

# let X' be our shard of X with shape (L, P/S, Q)
X = all_gather(X')
for _ in range(5):
  A = X @ X^T # (L, P, Q) x (L, Q, P) => (L, P, P) with cost 2LPQP flops
  B = A @ A   # (L, P, P) x (L, P, P) => (L, P, P) with cost 2LPPP flops
  X = B @ X   # (L, P, P) x (L, P, Q) => (L, P, Q) with cost 2LPQP flops (except the last iteration, which computes only our shard: => (L, P/S, Q) with cost 2LPQP/S flops)

This costs one all-gather of the entire model (it can be done layer-by-layer to reduce peak memory, and if you're in pure Data Parallelism you can omit the all-gather completely), but every device performs almost exactly the same number of flops as the unsharded computation, which can be far too many.
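
Here is one way this could look in JAX with shard_map: a sketch, not the implementation behind the measurements below, using the illustrative mesh and 'fsdp' axis defined earlier and omitting Muon's cheap elementwise work. The function name replicated_muon is made up for exposition.

from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec
from jax.experimental.shard_map import shard_map

@partial(shard_map, mesh=mesh, in_specs=PartitionSpec(None, 'fsdp', None),
         out_specs=PartitionSpec(None, 'fsdp', None))
def replicated_muon(x_shard):  # x_shard: (L, P/S, Q)
  x = jax.lax.all_gather(x_shard, 'fsdp', axis=1, tiled=True)  # (L, P, Q)
  for _ in range(5):
    a = x @ jnp.swapaxes(x, -1, -2)  # (L, P, P), replicated across the 'fsdp' group
    b = a @ a                        # (L, P, P)
    x = b @ x                        # (L, P, Q)
  # Return only our own rows; a more careful version would compute just these
  # rows in the final matmul to realize the small flop saving mentioned above.
  rows = x_shard.shape[1]
  start = jax.lax.axis_index('fsdp') * rows
  return jax.lax.dynamic_slice_in_dim(x, start, rows, axis=1)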

For Llama 405B's standard configuration, this would duplicate the computation 1024 times (8-way TP times 128-way FSDP), which means that for every model flop spent in the forward and backward pass, you have to spend 5 flops in Muon. That's not acceptable.

Here's a profile of the Muon section on an MI300X node using this strategy. Note that Stream #20, which handles most of the networking, is active only at the beginning, when it all-gathers the weights.

[Figure: profile of the Muon section using the replicated computation strategy]

Strategy #2: Sharded matmul

Another choice is to shard the matmuls directly. This is what JAX chose by default for our naive implementation, and it's also discussed in this post.

# let X' be our shard of X with shape (L, P/S, Q)
for _ in range(5):
  X = all_gather(X')
  A' = X' @ X^T # (L, P/S, Q) x (L, Q, P) => (L, P/S, P) with cost 2LPQP/S flops
  A = all_gather(A')
  B' = A' @ A   # (L, P/S, P) x (L, P, P) => (L, P/S, P) with cost 2LPPP/S flops
  X' = B' @ X   # (L, P/S, P) x (L, P, Q) => (L, P/S, Q) with cost 2LPQP/S flops

This uses many fewer flops: exactly a factor of 1/S on every matmul. However, the communication cost is now quite large: on each of the five iterations, we have to all-gather X (which is the size of the entire model) and A (which is around 1/4 the size of the model). Across all five iterations, this is 6.25 all-gathers of tensors the size of the entire model. Remember that the normal roofline estimates for how small your batch size can be under DP/FSDP assume one all-gather and one reduce-scatter (the same cost as an all-gather) per step. Adding six more means your minimum batch size must be at least 4x larger to use DP/FSDP, and that's if you can perfectly overlap all communication and computation, which would be very difficult.
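
For completeness, here is the corresponding shard_map sketch (same illustrative mesh, imports, and axis name as the previous sketch; sharded_matmul_muon is again a made-up name), which makes the two all-gathers in every iteration explicit:

@partial(shard_map, mesh=mesh, in_specs=PartitionSpec(None, 'fsdp', None),
         out_specs=PartitionSpec(None, 'fsdp', None))
def sharded_matmul_muon(x_shard):  # x_shard: (L, P/S, Q)
  for _ in range(5):
    x = jax.lax.all_gather(x_shard, 'fsdp', axis=1, tiled=True)    # (L, P, Q)
    a_shard = x_shard @ jnp.swapaxes(x, -1, -2)                    # (L, P/S, P)
    a = jax.lax.all_gather(a_shard, 'fsdp', axis=1, tiled=True)    # (L, P, P)
    b_shard = a_shard @ a                                          # (L, P/S, P)
    x_shard = b_shard @ x                                          # (L, P/S, Q)
  return x_shard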

While the flop cost for Llama 405B now matches the advertised 0.5%, the communication burden is large enough to make this infeasible.

Here's a profile on an MI300X node with this strategy. Note that Stream #20 is active dozens of times (10 times per tensor, and there are ~14 separate tensors in this particular model). You can't tell from this picture, but that communication is not well overlapped either.

[Figure: profile on an MI300X node with the sharded matmul strategy]

Strategy #3: Layer sharding

The reason we have to do so much communication is that we're sharding across a contracting dimension. What about the non-contracting dimension L? Pipeline parallelism shards across L, but in many mid-size training runs we try to avoid pipeline parallelism because it brings a whole host of complexity to the forward and backward passes: the layers have a data dependency between them, since the output of one layer is the input to the next.

However, for the Muon computations, there is no data dependency between layers. This means we can consider re-sharding X along the layer dimension, doing the computation, and then re-sharding it back to its original sharding.

# let X' be our shard of X with shape (L, P/S, Q)
X'' = all_to_all(X') # (L, P/S, Q) => (L/S, P, Q)
for _ in range(5):
  A'' = X'' @ X''^T  # (L/S, P, Q) x (L/S, Q, P) => (L/S, P, P) with cost 2LPQP/S flops
  B'' = A'' @ A''    # (L/S, P, P) x (L/S, P, P) => (L/S, P, P) with cost 2LPPP/S flops
  X'' = B'' @ X''    # (L/S, P, P) x (L/S, P, Q) => (L/S, P, Q) with cost 2LPQP/S flops
X' = all_to_all(X'') # (L/S, P, Q) => (L, P/S, Q)

This has the same flop cost as the "sharded matmul" case: a factor of 1/S on every matmul. However, it requires much less communication. Instead of five (P, Q) all-gathers and five (P, P) all-gathers, it costs two (P, Q) all-to-alls. On TPUs, an all-to-all is about 4x faster than an all-gather, and on GPUs the gap is even larger within a node, because each GPU has a direct link to every other GPU in the node.
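
And here is a shard_map sketch of layer sharding itself, again with the illustrative mesh and 'fsdp' axis from earlier and a made-up function name. The two all_to_all calls are the only communication. Note that this sketch assumes the layer count divides evenly by the axis size; otherwise the layer dimension has to be padded, as discussed below.

from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec
from jax.experimental.shard_map import shard_map

@partial(shard_map, mesh=mesh, in_specs=PartitionSpec(None, 'fsdp', None),
         out_specs=PartitionSpec(None, 'fsdp', None))
def layer_sharded_muon(x_shard):  # x_shard: (L, P/S, Q)
  # Re-shard from the P dimension onto the layer dimension:
  # split dim 0 (layers) across 'fsdp' and gather dim 1 (P) back.
  x = jax.lax.all_to_all(x_shard, 'fsdp', split_axis=0, concat_axis=1,
                         tiled=True)             # (L/S, P, Q)
  for _ in range(5):
    a = x @ jnp.swapaxes(x, -1, -2)              # (L/S, P, P)
    b = a @ a                                    # (L/S, P, P)
    x = b @ x                                    # (L/S, P, Q)
  # Re-shard back to the original layout.
  return jax.lax.all_to_all(x, 'fsdp', split_axis=1, concat_axis=0,
                            tiled=True)          # (L, P/S, Q)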

Here's a profile on an MI300X node with this strategy. Note that Stream #20 is only occasionally active (in theory, twice per tensor instead of 10 times).

[Figure: profile on an MI300X node with the layer sharding strategy]

Eventually, you run out of layers: Llama 405B uses 126 layers, which caps layer sharding at 126-way. With some effort, you can probably get another factor of 4 by assigning the weight tensors within a layer to separate GPUs: each of the three MLP tensors on its own GPU, plus one more to handle all the attention projections. If you do all of this, the expected performance hit from Muon is around 16%. This is still quite high, but it's on the edge of feasibility.
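
For reference, here is the rough arithmetic behind that 16% estimate, reusing the ~0.5% flop baseline from earlier and the speculative 4-way within-layer split from the paragraph above:

gpus = 16_384
shards = 126 * 4            # 126-way layer sharding x ~4 tensors per layer
replicas = gpus / shards    # ~32.5 copies of each Muon computation
print(replicas * 0.005)     # ~0.16, i.e. roughly a 16% flop overhead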

There are potentially additional gains available: for example, remember how Muon lets you use a larger batch size? If it lets you double your batch size, then that increases the model flops without increasing Muon's flops, so your performance hit is now 8%. The Llama paper actually quantifies the all-in MFU difference of doubling the batch size. It goes from 41% to 43%, which is a 4.8% increase, leaving a remaining penalty of 3.2%. That might just be small enough to swallow.

Summary and observed times

To summarize the options explored here:

                       flops (vs replicated) | comms (#all-gathers of the model)
replicated computation          1            | 1 (or 0 if we were already in DP)
sharded matmul                  1/S          | 6.25
layer sharding                  1/S          | 0.5 TPU, ~0.1 GPU

That's the theory, but it's worth measuring actual times in a typical run. We implemented these approaches in JAX and ran them on TPUs and on MI300X GPUs, using a 2B model with a 24k-token per-device batch size. That may sound like a small batch, but remember that Llama 405B was trained with a 1k-token per-device batch size.

On a v5p-8 TPU (so 4-way distributed):

                 muon time | total step time
replicated:        386 ms  |     2.388s
sharded matmul:    255 ms  |     2.161s
layer sharding:    155 ms  |     2.105s

Ideally, we would see layer sharding take 25% of the time of replicated; it actually takes 40%. The discrepancy is almost entirely because the communication and computation do not get overlapped at all -- with some work, this could be improved.

Scaling up to a v5p-16 for layer sharding reduces the Muon time to 75 ms, which is exactly what we would predict -- the total amount of communication and computation is the same, but we now have twice the network bandwidth and twice the processing power.

On MI300X (8-way distributed):

                 muon time | total step time
replicated:       1171 ms  |     2.720s (43%)
sharded matmul:    522 ms  |     2.200s (26%)
layer sharding:    189 ms  |     1.840s (10%)

Ideally, we would see layer sharding take 12.5% of the time of replicated; it actually takes 16%. This is almost entirely because this particular model has 13 layers: when we re-shard 8 ways across the layer dimension, it gets padded to 16, which wastes 3 slots. Muon would run at the same speed if you added 3 more layers (we verified this experimentally -- the ratio of layer sharding to replicated was observed to be 13.7%).

Separately, these measurements were taken while computing Muon in fp32, which is expensive and likely unnecessary. Since the layer-sharded computation is mostly compute-bound, dropping to a lower precision would make the actual cost of Muon significantly lower.

Future concerns

Using layer sharding for a training run like 405B is definitely pushing the limits of feasibility. If you can train with fewer GPUs, such as the 2k used by DeepSeek-V3, that helps a lot.

On the other hand, Mixture-of-Experts models are increasingly common, and they hurt quite a bit here. If you only activate 1/10 of your parameters, your model flops are 1/10 as large, but you still have to run Muon across all of your parameters, because across the batch some tokens will be routed to each expert. This means the ratio of Muon flops to model flops is 10 times worse than for a dense model with the same width. MoE models can usually use larger batch sizes, but not 10x larger.

All the ideas above compute the full version of Muon. To go further than this, it may be necessary to consider alterations to Muon to reduce the size of the matrix that goes through the Newton-Schulz iterations.

In the meantime, layer sharding can help scale up Muon to significantly larger training runs than were previously practical.

Resources