Adding a Custom Parallel Strategy (Expert Parallel)
This guide walks you through implementing a custom parallel strategy in LMMs Engine using Qwen3-MoE as a concrete example. We’ll cover expert parallelism, which is particularly useful for mixture-of-experts (MoE) models.
Important
Custom parallelization strategies can only be used with the FSDP2 trainer. Ensure your training configuration uses FSDP2 for distributed training.
Overview
Expert parallelism distributes MoE experts across multiple devices to reduce memory pressure and improve throughput. The implementation involves five key steps:
1. Create the model folder in the parallel directory
2. Define a parallel style class
3. Implement parallelization functions
4. Apply FSDP2 (Fully Sharded Data Parallel)
5. Register the parallelization function
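To make the memory argument concrete, the sketch below counts the parameters held by the experts and how that count drops per rank once experts are split across an expert parallel group. It assumes each expert has three bias-free projections (gate, up, down); the function names and the model dimensions in the demo are illustrative and not taken from LMMs Engine.

```python
def expert_param_count(num_layers, num_experts, hidden_size, moe_intermediate_size):
    """Total expert parameters, assuming gate/up/down projections without bias."""
    per_expert = 3 * hidden_size * moe_intermediate_size
    return num_layers * num_experts * per_expert


def expert_param_count_per_rank(
    num_layers, num_experts, hidden_size, moe_intermediate_size, ep_size
):
    """Expert parameters held by one rank when experts are split evenly over ep_size ranks."""
    if num_experts % ep_size != 0:
        raise ValueError(
            f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size})"
        )
    local_experts = num_experts // ep_size
    return expert_param_count(
        num_layers, local_experts, hidden_size, moe_intermediate_size
    )


if __name__ == "__main__":
    # Illustrative dimensions only (assumed, not read from a real config).
    shape = dict(num_layers=48, num_experts=128, hidden_size=2048, moe_intermediate_size=768)
    total = expert_param_count(**shape)
    per_rank = expert_param_count_per_rank(**shape, ep_size=8)
    bytes_per_param = 2  # bfloat16
    print(f"all experts:  {total * bytes_per_param / 2**30:.1f} GiB")
    print(f"per EP rank:  {per_rank * bytes_per_param / 2**30:.1f} GiB")
```

This only accounts for the expert parallel split; FSDP2 sharding, applied later in this guide, reduces the per-device footprint further.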
Step 1: Create the Model Folder
Create a new directory for your custom parallel strategy under src/lmms_engine/parallel/:
mkdir -p src/lmms_engine/parallel/qwen3_moe
Create three files:
- style.py - Define your parallel style
- parallelize.py - Implement parallelization functions
- __init__.py - Export public interfaces
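If you prefer to script the layout, the following sketch creates the folder and the three empty files. The helper name `scaffold_parallel_package` is hypothetical; only the directory and file names come from this guide.

```python
from pathlib import Path


def scaffold_parallel_package(repo_root, model_name):
    """Create src/lmms_engine/parallel/<model_name>/ with empty module files.

    Returns the list of files that were newly created; existing files are left untouched.
    """
    package_dir = Path(repo_root) / "src" / "lmms_engine" / "parallel" / model_name
    package_dir.mkdir(parents=True, exist_ok=True)

    created = []
    for filename in ("__init__.py", "style.py", "parallelize.py"):
        path = package_dir / filename
        if not path.exists():
            path.touch()
            created.append(path)
    return created


if __name__ == "__main__":
    for path in scaffold_parallel_package(".", "qwen3_moe"):
        print(f"created {path}")
```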
Step 2: Define Your Parallel Style
The parallel style class extends PyTorch’s ParallelStyle and defines how to distribute parameters across the device mesh.
File: ``src/lmms_engine/parallel/qwen3_moe/style.py``
Key components:
from typing import Optional

from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor import DeviceMesh, Shard, Replicate, DTensor, Placement


class Qwen3MoeParallelStyle(ParallelStyle):
    """
    Custom parallel style for Qwen3MoE expert parallelism.
    Handles token dispatching and combining across expert parallel ranks.
    """

    def __init__(
        self,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
    ) -> None:
        super().__init__()
        # Define how inputs/outputs should be distributed (default: shard along dim 0)
        self.input_layouts = (input_layouts or Shard(0),)
        self.output_layouts = (output_layouts or Shard(0),)
        self.use_local_output = use_local_output
The parallel style must implement these key methods:
- ``_input_fn``: Preprocesses inputs before they reach the module (token dispatching)
- ``_output_fn``: Postprocesses outputs after the module (token combining)
- ``_partition_fn``: Distributes module parameters across the device mesh
- ``_apply``: Applies the parallel style to the module
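The order in which these four methods run can be illustrated with a toy stand-in. The sketch below is not a PyTorch `ParallelStyle`; `ToyParallelStyle` is a hypothetical class that wraps a plain callable and only records the call order: parameters are partitioned once when the style is applied, and the input and output functions then run around every forward call. In PyTorch, `_apply` would typically delegate this wiring to `distribute_module` rather than wrap the module by hand.

```python
class ToyParallelStyle:
    """Records the order in which the hooks run. Illustration only."""

    def __init__(self):
        self.calls = []

    def _partition_fn(self, name, module, mesh):
        # Real implementation: replace parameters with sharded DTensors.
        self.calls.append(f"partition:{name}")

    def _input_fn(self, inputs, mesh):
        # Real implementation: dispatch tokens to the ranks that own their experts.
        self.calls.append("input")
        return inputs

    def _output_fn(self, outputs, mesh):
        # Real implementation: send expert outputs back to the originating ranks.
        self.calls.append("output")
        return outputs

    def _apply(self, module, mesh):
        # Partitioning happens once, at apply time.
        self._partition_fn("experts", module, mesh)

        def wrapped(*inputs):
            # Input/output hooks run on every forward call.
            inputs = self._input_fn(inputs, mesh)
            outputs = module(*inputs)
            return self._output_fn(outputs, mesh)

        return wrapped


def double(x):
    """Stand-in for the expert module's forward pass."""
    return 2 * x


style = ToyParallelStyle()
parallel_double = style._apply(double, mesh=None)
result = parallel_double(21)
```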
Example from Qwen3MoE:
@staticmethod
def _partition_fn(name, mod, device_mesh):
    """Distribute expert parameters across the device mesh."""
    if isinstance(mod, Qwen3MoeExperts):
        expert_parallel_dim = 0
        # Distribute each expert parameter
        mod.register_parameter(
            "up_proj",
            nn.Parameter(
                distribute_tensor(
                    mod.up_proj,
                    device_mesh,
                    [Shard(expert_parallel_dim)],
                )
            ),
        )
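Sharding along dimension 0 of a stacked expert tensor means each rank keeps a contiguous block of experts. The sketch below computes those blocks in plain Python, assuming chunk-style splitting (each rank takes `ceil(num_experts / ep_size)` experts until none remain), which is how dimension-0 sharding behaves to the best of my understanding. The function name is hypothetical.

```python
import math


def shard_dim0_ranges(num_experts, ep_size):
    """Return the [start, end) expert index range owned by each EP rank.

    Assumes chunk-style splitting: every rank takes ceil(num_experts / ep_size)
    experts until none remain, so trailing ranks may hold fewer (or zero) experts.
    """
    chunk = math.ceil(num_experts / ep_size)
    ranges = []
    for rank in range(ep_size):
        start = min(rank * chunk, num_experts)
        end = min(start + chunk, num_experts)
        ranges.append((start, end))
    return ranges


if __name__ == "__main__":
    for rank, (start, end) in enumerate(shard_dim0_ranges(128, 8)):
        print(f"rank {rank}: experts [{start}, {end})")
```

In practice you would choose an expert parallel size that divides the number of experts, so that every rank holds the same number of experts.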
The _input_fn and _output_fn handle token routing:
def _input_fn(self, inputs, mesh: DeviceMesh):
    """
    Dispatch tokens to appropriate experts.
    Uses all-to-all communication for token exchange.
    """
    routed_input, num_tokens_per_expert = inputs
    # Assumption: the expert parallel world size is the size of the EP mesh
    ep_world_size = mesh.size()
    if ep_world_size > 1:
        # Perform token dispatching across expert parallel group
        (
            routed_input,
            input_splits,
            output_splits,
            num_tokens_per_expert_group,
        ) = _token_dispatch(routed_input, num_tokens_per_expert)
        # Store metadata for output combining
        self.input_splits = input_splits
        self.output_splits = output_splits
    return routed_input


def _output_fn(self, output, mesh: DeviceMesh):
    """Combine outputs from all experts back to original token order."""
    # Same assumption as in _input_fn
    ep_world_size = mesh.size()
    if ep_world_size > 1:
        output = _token_combine(output, self.input_splits, self.output_splits)
    return output
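The dispatch and combine steps can be simulated on a single process to see what the split sizes mean. The sketch below models every rank as a Python list and assumes tokens on each rank are already sorted by global expert index and that experts are assigned to ranks in contiguous, equal-sized blocks. The helper names are hypothetical; this is not the `_token_dispatch` or `_token_combine` implementation.

```python
def compute_input_splits(num_tokens_per_expert, ep_size):
    """How many tokens this rank sends to each EP rank."""
    experts_per_rank = len(num_tokens_per_expert) // ep_size
    return [
        sum(num_tokens_per_expert[r * experts_per_rank:(r + 1) * experts_per_rank])
        for r in range(ep_size)
    ]


def simulated_all_to_all(send_buffers, send_splits):
    """Single-process stand-in for an all-to-all collective.

    send_buffers[src] is the token list on rank src; send_splits[src][dst] is the
    number of consecutive tokens rank src sends to rank dst.
    Returns (recv_buffers, recv_splits), where recv_splits[dst][src] is the number
    of tokens rank dst received from rank src.
    """
    ep_size = len(send_buffers)
    recv_buffers = [[] for _ in range(ep_size)]
    recv_splits = [[0] * ep_size for _ in range(ep_size)]
    for src in range(ep_size):
        offset = 0
        for dst in range(ep_size):
            count = send_splits[src][dst]
            recv_buffers[dst].extend(send_buffers[src][offset:offset + count])
            recv_splits[dst][src] = count
            offset += count
    return recv_buffers, recv_splits


def dispatch(tokens_per_rank, counts_per_rank, ep_size):
    """Send every token to the rank that owns its expert."""
    input_splits = [compute_input_splits(c, ep_size) for c in counts_per_rank]
    dispatched, output_splits = simulated_all_to_all(tokens_per_rank, input_splits)
    return dispatched, input_splits, output_splits


def combine(expert_outputs, output_splits):
    """Reverse the dispatch: swapping the split roles returns tokens to their source rank."""
    combined, _ = simulated_all_to_all(expert_outputs, output_splits)
    return combined
```

Because combining is the same collective with the roles of `input_splits` and `output_splits` exchanged, the style only needs to remember the two split lists between `_input_fn` and `_output_fn`.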
Step 3: Implement Parallelization Functions
File: ``src/lmms_engine/parallel/qwen3_moe/parallelize.py``
This file contains the main parallelization logic:
from typing import Optional

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module

from .style import Qwen3MoeParallelStyle


def apply_qwen3_moe_parallel(
    model: Qwen3MoeForCausalLM,
    ep_mesh: DeviceMesh,
    tp_mesh: Optional[DeviceMesh] = None,
    **kwargs,
):
    """
    Apply expert parallel style to model layers.

    Args:
        model: The model to parallelize
        ep_mesh: Expert parallel device mesh
        tp_mesh: Tensor parallel device mesh (if supported)
    """
    for decoder_layer in model.model.layers:
        module = decoder_layer.mlp
        ep_plan = Qwen3MoeParallelStyle()
        # Only the experts submodule is parallelized over the EP mesh
        parallelize_module(
            module.experts,
            device_mesh=ep_mesh,
            parallelize_plan=ep_plan,
        )
The key function is apply_qwen3_moe_parallelize_fn, which orchestrates the entire process:
def apply_qwen3_moe_parallelize_fn(
    model: Qwen3MoeForCausalLM,
    train_args: TrainingArguments,
    **kwargs,
):
    """
    Main entry point for applying parallelization to Qwen3MoE.
    This function is called by the training system.
    """
    ep_size = process_group_manager.ep_size

    # Step 1: Stack expert parameters into efficient format
    stack_expert_params(model)
    full_state_dict = model.state_dict()

    # Step 2: Apply expert parallel style (if ep_size > 1)
    if ep_size > 1:
        ep_mesh = process_group_manager.device_mesh["ep"]
        apply_qwen3_moe_parallel(model, ep_mesh=ep_mesh, **kwargs)

    # Step 3: Apply FSDP2
    apply_qwen3_moe_fsdp2(model, train_args, **kwargs)

    # Step 4: Restore full state dict
    fsdp2_load_full_state_dict(model, full_state_dict)
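The stacking in step 1 is what makes a single `Shard(0)` placement sufficient: once all experts live in one tensor per projection, splitting along the first dimension splits by expert. The sketch below shows the idea on a plain dictionary, using nested lists in place of tensors. The key pattern (`...experts.<index>.<projection>.weight`) and the function name are assumptions for illustration; this is not the `stack_expert_params` implementation.

```python
import re
from collections import defaultdict

# Assumed key layout: "<prefix>experts.<index>.<projection>.weight"
EXPERT_KEY = re.compile(r"^(?P<prefix>.*experts)\.(?P<idx>\d+)\.(?P<proj>\w+)\.weight$")


def stack_expert_weights(state_dict):
    """Group per-expert weights into one stacked entry per projection.

    Non-expert entries are copied unchanged. Each stacked entry is a list whose
    first dimension is the expert index (a stand-in for a stacked tensor).
    """
    stacked = {}
    grouped = defaultdict(dict)
    for key, value in state_dict.items():
        match = EXPERT_KEY.match(key)
        if match is None:
            stacked[key] = value
        else:
            stacked_key = f"{match['prefix']}.{match['proj']}"
            grouped[stacked_key][int(match["idx"])] = value

    for stacked_key, by_index in grouped.items():
        # Raises KeyError if expert indices are not contiguous from 0.
        stacked[stacked_key] = [by_index[i] for i in range(len(by_index))]
    return stacked
```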
Step 4: Apply FSDP2
After defining the parallel style, apply Fully Sharded Data Parallel (FSDP2) for distributed training:
import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.tensor import Shard


def apply_qwen3_moe_fsdp2(
    model: Qwen3MoeForCausalLM,
    train_args: TrainingArguments,
    **kwargs,
):
    """Apply FSDP2 sharding to model."""
    # Configure mixed precision
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        output_dtype=torch.float32,
    )
    dp_mesh = process_group_manager.device_mesh["fsdp"]
    fsdp_kwargs = {
        "reshard_after_forward": True,
        "mp_policy": mp_policy,
        "mesh": dp_mesh,
    }

    # Expert layers use a different mesh when expert parallelism is enabled
    ep_size = process_group_manager.ep_size
    expert_fsdp_kwargs = dict(fsdp_kwargs)
    if ep_size > 1:
        expert_fsdp_kwargs["mesh"] = process_group_manager.device_mesh["dp_shard_mod_ep"]
        # Dim 0 is already sharded by expert parallelism, so FSDP2 shards dim 1
        expert_fsdp_kwargs["shard_placement_fn"] = lambda p: Shard(1)

    for decoder_layer in model.model.layers:
        # Shard expert layers
        if ep_size > 1:
            fully_shard(decoder_layer.mlp, **expert_fsdp_kwargs)
        # Shard other layers
        fully_shard(decoder_layer.self_attn, **fsdp_kwargs)

    fully_shard(model.model.embed_tokens, **fsdp_kwargs)
    fully_shard(model, **fsdp_kwargs)
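The relationship between the "ep" and "dp_shard_mod_ep" mesh dimensions can be pictured as a 2D grid of ranks. The sketch below builds such a grid in plain Python and looks up the two groups a given rank belongs to. The row-major layout with "ep" as the inner dimension is an assumption made for illustration; the actual layout is determined by how `process_group_manager` constructs its device mesh.

```python
def build_ep_mesh(world_size, ep_size):
    """Arrange ranks in a (dp_shard_mod_ep, ep) grid, row-major (assumed layout)."""
    if world_size % ep_size != 0:
        raise ValueError(
            f"world_size ({world_size}) must be divisible by ep_size ({ep_size})"
        )
    rows = world_size // ep_size
    return [[row * ep_size + col for col in range(ep_size)] for row in range(rows)]


def groups_for_rank(mesh, rank):
    """Return the EP group (row) and the dp_shard_mod_ep group (column) of a rank."""
    for row in mesh:
        if rank in row:
            col = row.index(rank)
            return {
                # Ranks that exchange tokens and together hold all experts.
                "ep": list(row),
                # Ranks that hold the same experts and shard them with FSDP2.
                "dp_shard_mod_ep": [other_row[col] for other_row in mesh],
            }
    raise ValueError(f"rank {rank} is not part of the mesh")


if __name__ == "__main__":
    mesh = build_ep_mesh(world_size=16, ep_size=8)
    print(groups_for_rank(mesh, rank=10))
```

Under this layout, non-expert parameters are sharded over all ranks through the "fsdp" mesh, while expert parameters are sharded only across the ranks in the same column.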
Step 5: Register Your Parallelization Function
Register your parallelization function in the main registry so it can be discovered by the training system.
File: ``src/lmms_engine/parallel/parallelize.py``
from .qwen3_moe.parallelize import apply_qwen3_moe_parallelize_fn

MODEL_TO_PARALLEL_METHOD = {
    "qwen3_moe": apply_qwen3_moe_parallelize_fn,
}


def apply_parallelize(model, model_type, train_args: TrainingArguments, **kwargs):
    """
    Apply parallelization based on model type.

    Args:
        model: The model to parallelize
        model_type: Key in MODEL_TO_PARALLEL_METHOD (e.g., "qwen3_moe")
        train_args: Training configuration

    Raises:
        ValueError: If model_type is not supported
    """
    if model_type not in MODEL_TO_PARALLEL_METHOD:
        raise ValueError(f"Model type {model_type} not supported")
    return MODEL_TO_PARALLEL_METHOD[model_type](model, train_args, **kwargs)
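The dispatch behaviour of the registry can be exercised in isolation. The sketch below reproduces the lookup with a stub in place of a real parallelization function, so it runs without a model or a distributed setup. The `register_parallel_method` decorator is a hypothetical convenience, not part of LMMs Engine, where the dictionary is edited directly as shown above.

```python
MODEL_TO_PARALLEL_METHOD = {}


def register_parallel_method(model_type):
    """Hypothetical decorator that adds a function to the registry."""

    def decorator(fn):
        if model_type in MODEL_TO_PARALLEL_METHOD:
            raise ValueError(f"Model type {model_type} is already registered")
        MODEL_TO_PARALLEL_METHOD[model_type] = fn
        return fn

    return decorator


@register_parallel_method("toy_moe")
def apply_toy_parallelize_fn(model, train_args, **kwargs):
    """Stub: marks the model instead of sharding it."""
    model["parallelized_with"] = "toy_moe"
    return model


def apply_parallelize(model, model_type, train_args, **kwargs):
    """Look up the parallelization function for model_type and call it."""
    if model_type not in MODEL_TO_PARALLEL_METHOD:
        raise ValueError(f"Model type {model_type} not supported")
    return MODEL_TO_PARALLEL_METHOD[model_type](model, train_args, **kwargs)


if __name__ == "__main__":
    print(apply_parallelize({}, "toy_moe", train_args=None))
```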
File: ``src/lmms_engine/parallel/qwen3_moe/__init__.py``
from .parallelize import apply_qwen3_moe_parallel
from .style import Qwen3MoeParallelStyle
__all__ = ["apply_qwen3_moe_parallel", "Qwen3MoeParallelStyle"]
Key Utilities
Helper functions for token routing are typically placed in a shared utils module:
File: ``src/lmms_engine/parallel/expert_parallel/utils.py``
- ``_token_dispatch``: Distributes tokens to experts across ranks using all-to-all communication
- ``_token_combine``: Combines tokens back from all experts
- ``_compute_permute_indices``: Computes token permutation indices for efficient routing
These functions handle the communication and coordination between expert parallel ranks.
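To see why a permutation is needed, note that after an all-to-all exchange the tokens on a rank arrive grouped by source rank first and local expert second, whereas the expert computation wants all tokens for the same local expert to be contiguous. The sketch below computes such a reordering in plain Python. The function names and the flat, source-major layout of `tokens_per_expert_group` are assumptions for illustration; this is not the `_compute_permute_indices` implementation.

```python
def toy_permute_indices(tokens_per_expert_group, ep_size, experts_per_rank):
    """Indices that reorder received tokens from (source, expert) to (expert, source) order.

    tokens_per_expert_group[src * experts_per_rank + e] is the number of tokens
    received from rank src for local expert e.
    Returns (indices, tokens_per_local_expert).
    """
    # Start offset of every (source, expert) segment in the receive buffer.
    offsets = []
    running = 0
    for count in tokens_per_expert_group:
        offsets.append(running)
        running += count

    indices = []
    tokens_per_local_expert = []
    for expert in range(experts_per_rank):
        expert_total = 0
        for src in range(ep_size):
            segment = src * experts_per_rank + expert
            count = tokens_per_expert_group[segment]
            indices.extend(range(offsets[segment], offsets[segment] + count))
            expert_total += count
        tokens_per_local_expert.append(expert_total)
    return indices, tokens_per_local_expert


def invert_permutation(indices):
    """Indices that undo the reordering, used before sending outputs back."""
    inverse = [0] * len(indices)
    for position, index in enumerate(indices):
        inverse[index] = position
    return inverse
```

Applying `indices` before the expert computation and the inverse afterwards keeps the receive buffer in the order the combine step expects.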
Configuration in Training Arguments
To enable expert parallelism in your training config:
ep_degree: 8
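Before launching a job it can help to check that the chosen degree is compatible with the cluster size and the model. The sketch below encodes two constraints that are assumed here rather than documented by LMMs Engine: the degree should divide the world size, and it should divide the number of experts so that every rank holds the same number of experts. The function name is hypothetical.

```python
def validate_ep_degree(ep_degree, world_size, num_experts):
    """Check an expert parallel degree and return the number of experts per rank.

    Assumed constraints (not taken from the LMMs Engine source):
    - ep_degree is a positive integer
    - world_size is divisible by ep_degree
    - num_experts is divisible by ep_degree
    """
    if ep_degree < 1:
        raise ValueError(f"ep_degree must be >= 1, got {ep_degree}")
    if world_size % ep_degree != 0:
        raise ValueError(
            f"world_size ({world_size}) is not divisible by ep_degree ({ep_degree})"
        )
    if num_experts % ep_degree != 0:
        raise ValueError(
            f"num_experts ({num_experts}) is not divisible by ep_degree ({ep_degree})"
        )
    return num_experts // ep_degree


if __name__ == "__main__":
    experts_per_rank = validate_ep_degree(ep_degree=8, world_size=16, num_experts=128)
    print(f"each EP rank holds {experts_per_rank} experts")
```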
Summary
To add a custom parallel strategy:
✅ Create a folder under src/lmms_engine/parallel/ with your model name
✅ Define a ParallelStyle subclass that implements _partition_fn, _input_fn, _output_fn, and _apply
✅ Implement parallelization functions in parallelize.py
✅ Apply FSDP2 sharding using fully_shard with appropriate device meshes
✅ Register your parallelization function in src/lmms_engine/parallel/parallelize.py
Remember: Custom parallelization only works with the FSDP2 trainer. The parallelization is applied automatically when the training system detects your model type.
For questions about token routing, device meshes, or FSDP2 configuration, refer to the PyTorch Distributed documentation and the LMMs Engine API reference.