ViT Frame Parallel
For multimodal training with variable per-sample frame counts (typical of mixed image + video batches), the per-rank ViT forward is the dominant imbalance source: one rank may end up with a 64-frame video while another rank processes a single image, and FSDP stalls all ranks waiting on the slowest.
ViT frame parallel redistributes the frames themselves across the DP group before the ViT forward, then routes each frame’s patch embeddings back to the rank that owns the corresponding text sample.
Mental Model
For each forward step:
Every rank announces how many frames / patches it would normally process.
An LPT (longest-processing-time first) bin-packing pass assigns frames to ranks so that the per-rank patch count is balanced.
input_dispatch ships frames to their assigned compute rank via a single all_to_all_single (autograd-aware), so each rank now runs the ViT on a roughly equal number of patches.
The ViT forward runs unchanged.
output_dispatch ships the ViT outputs (last_hidden_state and pooler_output) back to the rank that owns the original sample, again via a single all_to_all_single.
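The balancing step can be illustrated with a minimal LPT sketch. This is not the planner shipped in lmms_engine; the function name `lpt_assign` and the example patch counts are assumptions made for illustration. It sorts frames by patch count, heaviest first, and gives each one to the rank with the smallest load so far.

```python
import heapq


def lpt_assign(patch_counts, num_ranks):
    """Hypothetical LPT planner: returns (assignment, loads).

    assignment[i] is the compute rank chosen for frame i;
    loads[r] is the total patch count assigned to rank r.
    """
    # Min-heap of (current_load, rank); ties go to the lower rank index.
    heap = [(0, rank) for rank in range(num_ranks)]
    heapq.heapify(heap)

    assignment = [0] * len(patch_counts)
    # Visit frames from most to fewest patches (stable for equal counts).
    order = sorted(range(len(patch_counts)), key=lambda i: -patch_counts[i])
    for i in order:
        load, rank = heapq.heappop(heap)
        assignment[i] = rank
        heapq.heappush(heap, (load + patch_counts[i], rank))

    loads = [0] * num_ranks
    for i, rank in enumerate(assignment):
        loads[rank] += patch_counts[i]
    return assignment, loads


# Assumed example: an 8-frame video (256 patches per frame) plus four images.
frames = [256] * 8 + [1024, 1024, 512, 512]
assignment, loads = lpt_assign(frames, num_ranks=4)
print(loads)
```

Without redistribution, the rank holding the video would process 2048 patches while a rank holding a single 512-patch image would sit mostly idle. LPT is a greedy heuristic, so the result is balanced but not guaranteed optimal in general.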
The wrapper is implemented as a thin plumbing layer
(lmms_engine.parallel.vit_parallel.frame_parallel.wrap_vit_forward) that
takes three callables (input_dispatch, orig_forward, output_dispatch),
so the model-specific dispatch logic stays in the corresponding model
directory.
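The three-callable structure can be sketched as follows. This is a single-process illustration under stated assumptions: `wrap_forward_sketch` and the toy callables are hypothetical names, and the real wrap_vit_forward signature may differ. The sketch assumes input_dispatch returns a dict of keyword arguments plus a context object, and it replaces the all-to-all with a local reorder.

```python
def wrap_forward_sketch(input_dispatch, orig_forward, output_dispatch):
    """Hypothetical stand-in for the plumbing layer (not the lmms_engine API)."""

    def wrapped(self, *args, **kwargs):
        # 1. Redistribute inputs; ctx records how to undo the redistribution.
        dispatched_inputs, ctx = input_dispatch(self, *args, **kwargs)
        # 2. Run the unmodified forward on the redistributed inputs.
        outputs = orig_forward(self, **dispatched_inputs)
        # 3. Route outputs back to where the original samples live.
        return output_dispatch(ctx, outputs)

    return wrapped


# Toy callables: a local sort stands in for the cross-rank all-to-all.
def toy_input_dispatch(self, frames):
    order = sorted(range(len(frames)), key=lambda i: frames[i])
    return {"frames": [frames[i] for i in order]}, {"order": order}


def toy_forward(self, frames):
    return [f * 10 for f in frames]


def toy_output_dispatch(ctx, outputs):
    restored = [None] * len(outputs)
    for pos, src in enumerate(ctx["order"]):
        restored[src] = outputs[pos]
    return restored


forward = wrap_forward_sketch(toy_input_dispatch, toy_forward, toy_output_dispatch)
result = forward(None, [3, 1, 2])
```

Because the wrapper only sequences the three callables, the caller sees outputs in the original sample order, exactly as if no redistribution had happened.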
When It Helps
| Scenario | Net effect |
|---|---|
| Single node, NVLink, video + image mix | Almost always a win: all-to-all is cheap on NVLink, idle time savings are large. |
| Multi-node, IB, large variance in frame counts | Usually a clear win. Communication is on the order of 100 ms per step; idle time saved is often 1–3 s. |
| Multi-node, very uniform frame counts (e.g. fixed 8 frames everywhere) | Marginal or net negative. Imbalance is small; the extra communication may outweigh the savings. |
| | No-op. The patch logs and returns immediately. |
If your batches are already balanced (every sample has roughly the same number of patches), the all-to-all overhead is pure cost. Inspect step-time variance across ranks before enabling on uniform workloads.
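One rough way to quantify the imbalance before enabling the patch is to compare the heaviest rank against the mean. The helper below is an illustrative sketch, not part of lmms_engine. Its input can be per-rank step times or per-rank patch counts, collected however your logging allows, and the example numbers are made up.

```python
def imbalance_ratio(per_rank_load):
    """Return max load / mean load; 1.0 means perfectly balanced."""
    mean = sum(per_rank_load) / len(per_rank_load)
    return max(per_rank_load) / mean if mean else 1.0


# Assumed example: one rank holds a long video, three hold single images.
skewed = imbalance_ratio([16384, 256, 256, 256])
# Assumed example: every rank holds the same number of patches.
uniform = imbalance_ratio([2048, 2048, 2048, 2048])
print(f"skewed={skewed:.2f} uniform={uniform:.2f}")
```

A ratio close to 1.0 suggests there is little idle time to recover, so the all-to-all is likely pure overhead. A ratio well above 1.0 suggests the slowest rank dominates the step.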
Enabling It
Add vit_frame_parallel to model_config.monkey_patch_kwargs.patch_type:
```yaml
model_config:
  load_from_pretrained_path: Qwen/Qwen3.5-4B
  attn_implementation: flash_attention_2
  monkey_patch_kwargs:
    patch_type: ["liger", "vit_frame_parallel"]
```
The patch must be registered for the target model. Currently registered:
qwen3_5 (Qwen3_5VisionModel.forward)
To add support for a new model, add a <model>_vit_ops.py that exports
input_dispatch(self, ...) -> (dispatched_inputs, ctx) and
output_dispatch(ctx, ...) -> outputs, then register a patch under
@MONKEY_PATCHER.register("<model>", "vit_frame_parallel") that calls
wrap_vit_forward. See src/lmms_engine/models/qwen3_5/qwen3_5_vit_ops.py
and src/lmms_engine/models/qwen3_5/monkey_patch.py for the reference
implementation.
Implementation Notes
The LPT planner lives in lmms_engine.parallel.vit_parallel.balance and is a pure algorithm (no rank / comm knowledge), making it easy to unit-test against synthetic load distributions.

Frames are physically reordered on the source rank (via an argsort-based permutation) before the all_to_all_single so that each destination’s contiguous slice corresponds to its assigned frames. The reverse permutation is applied in output_dispatch.

Both ViT outputs (last_hidden_state at patch scale and pooler_output at patch / merge² scale) are shipped back. The LLM consumes pooler_output; last_hidden_state is shipped for completeness (loss-free with respect to downstream usage).

The wrap reuses pgm.process_group_manager.dp_group so it composes naturally with FSDP2 sharding.
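The argsort-based reordering can be sketched in plain Python. This is a simplified illustration that works at frame granularity with hypothetical names (`plan_send`, the frame labels, the destination list). In the actual code path the buffers are tensors, and the split sizes would be expressed in rows for the all-to-all.

```python
def plan_send(dest_ranks, num_ranks):
    """Hypothetical helper: group local frames by destination rank.

    Returns (perm, split_sizes, inverse) where perm orders frames so each
    destination gets one contiguous slice, split_sizes[r] is the number of
    frames going to rank r, and inverse undoes perm.
    """
    # Stable argsort by destination keeps the original order within a slice.
    perm = sorted(range(len(dest_ranks)), key=lambda i: dest_ranks[i])
    split_sizes = [0] * num_ranks
    for dest in dest_ranks:
        split_sizes[dest] += 1
    # inverse[src] is the position of original frame src in the send buffer.
    inverse = [0] * len(perm)
    for pos, src in enumerate(perm):
        inverse[src] = pos
    return perm, split_sizes, inverse


frames = ["f0", "f1", "f2", "f3", "f4"]
dest = [2, 0, 1, 0, 2]  # assumed planner output for this rank's frames
perm, split_sizes, inverse = plan_send(dest, num_ranks=3)

send_buffer = [frames[i] for i in perm]
# Applying the inverse permutation restores the original frame order.
restored = [send_buffer[inverse[i]] for i in range(len(frames))]
```

Here the send buffer is laid out as two frames for rank 0, one for rank 1, and two for rank 2, which is the contiguous layout a split-size-based all-to-all expects.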