FSDP2 Mixed Precision: Reduce Dtype Considerations
Default Behavior
By default, the FSDP2 trainer sets the reduce dtype to bfloat16. This default is defined in the training arguments:
reduce_dtype: Optional[str] = "bfloat16"
output_dtype: Optional[str] = "bfloat16"
The reduce_dtype setting controls the data type used for gradient reduction (reduce-scatter or all-reduce) during distributed training.
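As a rough illustration of where these strings could end up, the sketch below maps them onto PyTorch's MixedPrecisionPolicy. It assumes the trainer forwards the values to that class more or less directly, which this page does not confirm. The helper names check_dtype_name and build_mp_policy, the set of accepted dtype names, and the choice of bfloat16 for param_dtype are all hypothetical.

```python
from typing import Optional

# Dtype names accepted by this sketch; the trainer's real accepted set is an assumption.
_SUPPORTED = ("bfloat16", "float16", "float32")


def check_dtype_name(name: Optional[str]) -> Optional[str]:
    """Validate a dtype string from the config; None means 'leave unset'."""
    if name is None:
        return None
    if name not in _SUPPORTED:
        raise ValueError(f"unsupported dtype {name!r}; expected one of {_SUPPORTED}")
    return name


def build_mp_policy(reduce_dtype: Optional[str] = "bfloat16",
                    output_dtype: Optional[str] = "bfloat16"):
    """Build a MixedPrecisionPolicy from config strings (hypothetical helper)."""
    import torch  # imported lazily so the validation above works without PyTorch
    # Public path in recent PyTorch releases; older releases expose the class
    # under torch.distributed._composable.fsdp instead.
    from torch.distributed.fsdp import MixedPrecisionPolicy

    def to_dtype(name: Optional[str]):
        name = check_dtype_name(name)
        return None if name is None else getattr(torch, name)

    return MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,           # assumed to follow bf16: true
        reduce_dtype=to_dtype(reduce_dtype),  # dtype used for gradient reduction
        output_dtype=to_dtype(output_dtype),  # dtype forward outputs are cast to
    )
```

The resulting policy object would typically be passed as the mp_policy argument of fully_shard when wrapping each module.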
Benefits
Using bfloat16 for reduce operations provides:
Memory Savings: Reduces memory footprint during gradient synchronization across devices
Faster Training: Speeds up gradient communication (reduce-scatter or all-reduce) by transferring tensors half the size of their float32 equivalents
In many projects, we have successfully trained models with the bfloat16 reduce dtype and achieved acceptable performance.
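The size difference behind these benefits is simple arithmetic: a bfloat16 element takes 2 bytes and a float32 element takes 4. The sketch below computes the total gradient payload for a hypothetical 7-billion-parameter model. The model size and the helper name grad_reduce_bytes are illustrative assumptions, and the bytes each rank actually sends depend on the collective algorithm and the sharding layout.

```python
# Bytes per element for the two reduce dtypes discussed here.
BYTES_PER_ELEMENT = {"bfloat16": 2, "float32": 4}


def grad_reduce_bytes(num_params: int, reduce_dtype: str) -> int:
    """Total size in bytes of one full set of gradients in the given dtype."""
    return num_params * BYTES_PER_ELEMENT[reduce_dtype]


num_params = 7_000_000_000  # hypothetical 7B-parameter model
bf16_bytes = grad_reduce_bytes(num_params, "bfloat16")
fp32_bytes = grad_reduce_bytes(num_params, "float32")

print(f"bfloat16: {bf16_bytes / 1e9:.0f} GB, float32: {fp32_bytes / 1e9:.0f} GB")
# prints "bfloat16: 14 GB, float32: 28 GB"
```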
Caution
While bfloat16 reduce dtype works well for most training scenarios, users should be cautious when:
Training very large models where gradient precision becomes critical
Requiring high numerical accuracy in the final model
Observing training instability or unexpected loss spikes
If you encounter issues related to numerical precision, consider switching to float32 for the reduce dtype:
# In your training config
reduce_dtype: float32
output_dtype: bfloat16
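To see why reduction precision can matter, the sketch below emulates bfloat16 rounding in pure Python and sums one large gradient contribution with many small ones. It is a deliberately extreme illustration: the accumulation is sequential and rounds after every addition, whereas real collectives use different reduction orders and may accumulate internally at higher precision. The helper names to_bf16, to_f32, and reduce_sum are hypothetical.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round a finite Python float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a rounding bias so that dropping the low 16 bits rounds to nearest even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def reduce_sum(values, dtype: str) -> float:
    """Sequentially sum values, rounding to the given dtype after every step."""
    cast = to_bf16 if dtype == "bfloat16" else to_f32
    total = 0.0
    for v in values:
        total = cast(total + cast(v))
    return total


# One contribution of 1.0 followed by 64 contributions of 2**-9.
# Near 1.0 the spacing between bfloat16 values is 2**-7, so each small
# addend is rounded away, while float32 keeps every one of them.
contributions = [1.0] + [2.0 ** -9] * 64

bf16_sum = reduce_sum(contributions, "bfloat16")
fp32_sum = reduce_sum(contributions, "float32")
print(bf16_sum, fp32_sum)  # prints "1.0 1.125"
```

In this contrived case the bfloat16 reduction loses the entire 0.125 contributed by the small terms, which is the kind of error that can accumulate when many ranks contribute gradients of very different magnitudes.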
Configuration Example
To explicitly set the reduce dtype in your YAML configuration:
training_config:
  bf16: true
  fsdp2: true
  reduce_dtype: bfloat16  # or float32 for higher precision