Adding a new trainer
This guide explains how to define, register, and use a custom trainer via the centralized registry used by the TrainRunner.
How the registry works
Registration is decorator-driven: you attach a key to your class/function using @TRAINER_REGISTER.register(...).
Lookup happens in TrainRunner._build_trainer via TRAINER_REGISTER[self.config.trainer_type].
The trainer class is then instantiated with specific keyword arguments (see below).
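To make the mechanism concrete, here is a minimal sketch of a decorator-driven registry. It is not the actual implementation behind TRAINER_REGISTER: the Registry class, DEMO_REGISTER, and the use of warnings.warn for duplicate keys are assumptions chosen for illustration.

```python
import warnings


class Registry:
    """Hypothetical stand-in for the registry behind TRAINER_REGISTER."""

    def __init__(self):
        self._items = {}

    def register(self, target=None):
        # Bare usage (@registry.register): target is the decorated object.
        if callable(target):
            self._add(target.__name__, target)
            return target

        # Keyed usage (@registry.register("key")): target is the key, or None.
        def decorator(obj):
            self._add(target or obj.__name__, obj)
            return obj

        return decorator

    def _add(self, key, obj):
        # Assumption: duplicates overwrite and emit a warning.
        if key in self._items:
            warnings.warn(f"Key {key!r} is already registered; overwriting.")
        self._items[key] = obj

    def __getitem__(self, key):
        return self._items[key]


DEMO_REGISTER = Registry()


@DEMO_REGISTER.register("my_trainer")
class MyTrainer:
    pass


@DEMO_REGISTER.register
class OtherTrainer:
    pass


# Lookup by key, analogous to TRAINER_REGISTER[self.config.trainer_type].
trainer_cls = DEMO_REGISTER["my_trainer"]
```

The decorator returns the class unchanged, so registration is purely a side effect of the module being imported.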
Constructor requirements
TrainRunner will instantiate your trainer like this:
trainer_cls(
    model=self.model,
    args=self.config.trainer_args,
    data_collator=self.train_dataset.get_collator(),
    train_dataset=self.train_dataset,
    eval_dataset=self.eval_dataset,
    processing_class=self.train_dataset.processor,
)
Your trainer should accept these keyword arguments (unused ones can be ignored via **kwargs). If you subclass Hugging Face transformers.Trainer, note that it expects tokenizer=; in our stack, we pass processing_class=. You can forward processing_class to the appropriate place (e.g., tokenizer) inside your constructor.
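If you are unsure which keyword your parent class accepts, one option is to inspect its constructor and forward processing_class under whichever name it declares. The sketch below assumes nothing about transformers itself: LegacyBaseTrainer and processor_kwarg are hypothetical names, and the base class is a stand-in that only accepts tokenizer=.

```python
import inspect


def processor_kwarg(base_cls):
    """Return the keyword base_cls.__init__ declares for the processor, if any."""
    params = inspect.signature(base_cls.__init__).parameters
    if "processing_class" in params:
        return "processing_class"
    if "tokenizer" in params:
        return "tokenizer"
    return None


# Hypothetical stand-in for a base trainer that only accepts tokenizer=.
class LegacyBaseTrainer:
    def __init__(self, model=None, args=None, tokenizer=None):
        self.model = model
        self.args = args
        self.tokenizer = tokenizer


class MyTrainer(LegacyBaseTrainer):
    def __init__(self, *, model, args, processing_class=None, **kwargs):
        # **kwargs absorbs data_collator, train_dataset, and eval_dataset,
        # which this stand-in base class does not use.
        extra = {}
        name = processor_kwarg(LegacyBaseTrainer)
        if name is not None:
            extra[name] = processing_class
        super().__init__(model=model, args=args, **extra)


trainer = MyTrainer(
    model="model",
    args={},
    data_collator=None,
    train_dataset=[],
    eval_dataset=None,
    processing_class="processor",
)
```

Here the processor ends up on trainer.tokenizer because that is the only name the stand-in base class declares.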
Step 1: Implement your trainer
# src/lmms_engine/train/my_trainer.py
from transformers import Trainer as HFTrainer

from lmms_engine.train.registry import TRAINER_REGISTER


@TRAINER_REGISTER.register("my_trainer")  # or omit the string to use the class name as the key
class MyTrainer(HFTrainer):
    def __init__(
        self,
        *,
        model,
        args,
        data_collator,
        train_dataset,
        eval_dataset=None,
        processing_class=None,
        **kwargs,
    ):
        # If subclassing HF Trainer, you can map processing_class to tokenizer
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=processing_class,
        )
        # ... any custom init logic ...
Notes:
Using @TRAINER_REGISTER.register("my_trainer") registers the key as "my_trainer".
If you use @TRAINER_REGISTER.register without parentheses, the key defaults to the class name (MyTrainer).
Re-registering an existing key will overwrite the previous value and print a warning.
Step 2: Ensure registration is imported
Registration happens at import time. Make sure the module containing your decorator is imported before building the trainer.
Common options:
Import in src/lmms_engine/train/__init__.py:
# src/lmms_engine/train/__init__.py
from . import my_trainer # noqa: F401 ensures registration side-effect
Or import explicitly in your application/runner setup prior to calling TrainRunner.build().
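For the explicit-import option, a small helper built on importlib can import every trainer module by name before the runner is built. This is only a sketch: TRAINER_MODULES and import_trainer_modules are hypothetical names, and the list holds a standard-library placeholder so the snippet runs anywhere.

```python
import importlib

# Hypothetical list: replace the placeholder with the modules that contain
# your decorated trainers, e.g. "lmms_engine.train.my_trainer".
TRAINER_MODULES = ["json"]


def import_trainer_modules(module_names):
    """Import each module so that its registration decorators execute."""
    imported = []
    for name in module_names:
        imported.append(importlib.import_module(name))
    return imported


# Call this in your setup code before TrainRunner builds the trainer.
modules = import_trainer_modules(TRAINER_MODULES)
```

Importing a module twice is harmless, since Python caches modules after the first import and the decorators run only once.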
Step 3: Select your trainer in config
Set trainer_type in your config to the registry key you used in registration.
# examples/load_from_config_example.yaml (snippet)
trainer_type: my_trainer
trainer_args:
  output_dir: ./output/run
  bf16: true
  # ... other args ...
If you registered without a string, use the class name instead (e.g., trainer_type: MyTrainer).
Step 4: Run
TrainRunner will resolve the trainer class from the registry and instantiate it with the expected arguments.
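The resolve-then-instantiate flow can be sketched end to end with stand-ins. Nothing here is the real TrainRunner code: DEMO_REGISTER is a plain dict playing the role of the registry, and DemoTrainer and build_trainer are hypothetical names.

```python
# Hypothetical stand-in for a registered trainer class.
class DemoTrainer:
    def __init__(
        self,
        *,
        model,
        args,
        data_collator,
        train_dataset,
        eval_dataset=None,
        processing_class=None,
        **kwargs,
    ):
        self.model = model
        self.args = args
        self.train_dataset = train_dataset


# A plain dict plays the role of the registry in this sketch.
DEMO_REGISTER = {"my_trainer": DemoTrainer}


def build_trainer(config, model, train_dataset, eval_dataset=None):
    # Resolve the class from the configured key, then pass keyword arguments.
    trainer_cls = DEMO_REGISTER[config["trainer_type"]]
    return trainer_cls(
        model=model,
        args=config["trainer_args"],
        data_collator=None,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=None,
    )


config = {
    "trainer_type": "my_trainer",
    "trainer_args": {"output_dir": "./output/run", "bf16": True},
}
trainer = build_trainer(config, model="model", train_dataset=[1, 2, 3])
```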
Troubleshooting
KeyError: Ensure your module with the decorator ran before TrainRunner builds the trainer (import ordering).
Constructor errors: Ensure your __init__ accepts the arguments listed above; capture extras with **kwargs if needed.
Duplicate key warning: Another module registered the same key; either change the key or keep only one registration.
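The first two problems can be caught early with a couple of small checks. The following is a sketch using only the standard library; REQUIRED_KWARGS mirrors the keyword arguments listed under "Constructor requirements", while missing_kwargs, lookup, and the example classes are hypothetical helpers that assume a dict-like registry.

```python
import inspect

REQUIRED_KWARGS = (
    "model",
    "args",
    "data_collator",
    "train_dataset",
    "eval_dataset",
    "processing_class",
)


def missing_kwargs(trainer_cls):
    """Return the required keywords that trainer_cls.__init__ would reject."""
    params = inspect.signature(trainer_cls.__init__).parameters
    # A **kwargs parameter accepts any keyword, so nothing can be missing.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return [name for name in REQUIRED_KWARGS if name not in params]


def lookup(registry, key):
    """Look up a key and list the registered keys if it is absent."""
    try:
        return registry[key]
    except KeyError:
        raise KeyError(
            f"Unknown trainer_type {key!r}; registered keys: {sorted(registry)}"
        ) from None


class GoodTrainer:
    def __init__(self, *, model, args, **kwargs):
        pass


class BadTrainer:
    def __init__(self, model, args):
        pass


problems = missing_kwargs(BadTrainer)
```

An empty registered-keys list in the KeyError message usually points to the import-ordering issue, since no decorated module has run yet.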