FSDP2 Mixed Precision: Reduce Dtype Considerations

Default Behavior

By default, the FSDP2 trainer sets the reduce dtype to bfloat16. This default is defined in the training arguments:

reduce_dtype: Optional[str] = "bfloat16"
output_dtype: Optional[str] = "bfloat16"

This setting controls the data type used for gradient reduction operations during distributed training.
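
As a rough sketch of how these two string arguments might reach PyTorch, assume the trainer builds on `fully_shard` and passes a `MixedPrecisionPolicy`. The source does not show this wiring, so the helper names (`ALLOWED_DTYPES`, `check_dtype_name`, `build_mp_policy`), the restriction to two dtype names, and the bfloat16 `param_dtype` are all illustrative assumptions.

```python
from typing import Optional

# Assumption: only the two dtypes discussed in this document are accepted.
ALLOWED_DTYPES = ("bfloat16", "float32")


def check_dtype_name(name: Optional[str]) -> Optional[str]:
    """Validate a dtype string from the training arguments (hypothetical helper)."""
    if name is not None and name not in ALLOWED_DTYPES:
        raise ValueError(f"unsupported dtype {name!r}; expected one of {ALLOWED_DTYPES}")
    return name


def build_mp_policy(reduce_dtype: Optional[str] = "bfloat16",
                    output_dtype: Optional[str] = "bfloat16"):
    """Turn the string arguments into a PyTorch FSDP2 mixed-precision policy."""
    # Imported lazily so the validation above works without PyTorch installed.
    import torch
    # Public import path in recent PyTorch releases; older releases expose the
    # same class under torch.distributed._composable.fsdp.
    from torch.distributed.fsdp import MixedPrecisionPolicy

    def to_torch(name: Optional[str]):
        name = check_dtype_name(name)
        return None if name is None else getattr(torch, name)

    return MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,          # assumption: bf16 compute
        reduce_dtype=to_torch(reduce_dtype),  # dtype used for gradient reduction
        output_dtype=to_torch(output_dtype),  # dtype of floating-point forward outputs
    )
```

The resulting policy would then be passed as `fully_shard(module, mp_policy=policy)`. Validating the string before touching `torch` keeps a typo in the config from surfacing as an opaque `AttributeError`.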

Benefits

Using bfloat16 for reduce operations provides:

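  • Memory Savings: Gradients are communicated at 2 bytes per element instead of 4, which shrinks the buffers used during gradient synchronization across devices

  • Faster Training: Speeds up the gradient reduction collective (reduce-scatter under FSDP) by transferring smaller tensors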
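
The size difference can be made concrete with a little arithmetic. The sketch below compares the size of the full gradient buffer in each reduce dtype for a hypothetical 7B-parameter model. The model size and the helper `grad_payload_bytes` are illustrative, and actual per-rank network traffic also depends on world size and the collective algorithm.

```python
# Bytes per element for the two reduce dtypes discussed here.
BYTES_PER_ELEMENT = {"bfloat16": 2, "float32": 4}


def grad_payload_bytes(num_params: int, reduce_dtype: str) -> int:
    """Size of the full gradient buffer when expressed in `reduce_dtype`."""
    return num_params * BYTES_PER_ELEMENT[reduce_dtype]


num_params = 7_000_000_000  # hypothetical 7B-parameter model
bf16_bytes = grad_payload_bytes(num_params, "bfloat16")
fp32_bytes = grad_payload_bytes(num_params, "float32")

# prints "bfloat16: 14 GB, float32: 28 GB"
print(f"bfloat16: {bf16_bytes / 1e9:.0f} GB, float32: {fp32_bytes / 1e9:.0f} GB")
```

Whatever the model size, the bfloat16 payload is exactly half of the float32 one, which is where both the memory and the bandwidth savings come from.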

In many projects, we have trained models with a bfloat16 reduce dtype and achieved acceptable performance.

Caution

While bfloat16 reduce dtype works well for most training scenarios, users should be cautious when:

  • Training very large models where gradient precision becomes critical

  • Requiring high numerical accuracy in the final model

  • Observing training instability or unexpected loss spikes

If you encounter issues related to numerical precision, consider switching to float32 for the reduce dtype:

# In your training config
reduce_dtype: float32
output_dtype: bfloat16
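
To illustrate why a float32 reduce dtype can help, the sketch below emulates bfloat16 rounding in pure Python and sums one large gradient contribution with many small ones. It is a simplified model under stated assumptions: the helpers `to_bf16`, `to_fp32`, and `reduce_sum` are hypothetical, the accumulation is strictly sequential, and real collectives may use a different reduction order or internal precision, so treat it as an intuition aid rather than a description of the actual reduction.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round a Python float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 bit pattern.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def reduce_sum(values, cast):
    """Sequentially sum `values`, rounding every intermediate result with `cast`."""
    acc = cast(0.0)
    for v in values:
        acc = cast(acc + cast(v))
    return acc


# One rank contributes a large gradient value, 63 ranks contribute small ones.
grads = [1.0] + [0.001] * 63
bf16_sum = reduce_sum(grads, to_bf16)
fp32_sum = reduce_sum(grads, to_fp32)
print(f"bfloat16 sum: {bf16_sum}, float32 sum: {fp32_sum:.4f}")
```

In this toy setup the bfloat16 sum stays at exactly 1.0, because the spacing between bfloat16 values near 1.0 is about 0.0078 and each 0.001 contribution is rounded away, while the float32 sum is close to 1.063. Lost small contributions of this kind are one plausible source of the instability described above.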

Configuration Example

To explicitly set the reduce dtype in your YAML configuration:

training_config:
  bf16: true
  fsdp2: true
  reduce_dtype: bfloat16  # or float32 for higher precision