Reference for ultralytics/engine/trainer.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/trainer.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.engine.trainer.BaseTrainer
BaseTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)
A base class for creating trainers.
This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing, and various training utilities. It supports both single-GPU and multi-GPU distributed training.
Attributes:
Name | Type | Description |
---|---|---|
args | SimpleNamespace | Configuration for the trainer. |
validator | BaseValidator | Validator instance. |
model | Module | Model instance. |
callbacks | defaultdict | Dictionary of callbacks. |
save_dir | Path | Directory to save results. |
wdir | Path | Directory to save weights. |
last | Path | Path to the last checkpoint. |
best | Path | Path to the best checkpoint. |
save_period | int | Save a checkpoint every x epochs (disabled if < 1). |
batch_size | int | Batch size for training. |
epochs | int | Number of epochs to train for. |
start_epoch | int | Starting epoch for training. |
device | device | Device to use for training. |
amp | bool | Flag to enable AMP (Automatic Mixed Precision). |
scaler | GradScaler | Gradient scaler for AMP. |
data | str | Path to data. |
ema | Module | EMA (Exponential Moving Average) of the model. |
resume | bool | Resume training from a checkpoint. |
lf | callable | Learning-rate schedule function used by the scheduler. |
scheduler | _LRScheduler | Learning rate scheduler. |
best_fitness | float | The best fitness value achieved. |
fitness | float | Current fitness value. |
loss | float | Current loss value. |
tloss | float | Total loss value. |
loss_names | list | List of loss names. |
csv | Path | Path to results CSV file. |
metrics | dict | Dictionary of metrics. |
plots | dict | Dictionary of plots. |
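The `save_period` rule can be treated as a filter over epoch indices. The following is a minimal sketch. It assumes 0-based epochs and the condition `save_period > 0 and epoch % save_period == 0`. The helper name `checkpoint_epochs` is hypothetical and is not part of the trainer's API.

```python
from __future__ import annotations


def checkpoint_epochs(epochs: int, save_period: int) -> list[int]:
    """Return the 0-based epochs at which a periodic checkpoint would be written.

    Hypothetical helper; assumes the rule `save_period > 0 and epoch % save_period == 0`.
    """
    if save_period < 1:
        return []  # periodic checkpoints are disabled
    return [epoch for epoch in range(epochs) if epoch % save_period == 0]


print(checkpoint_epochs(10, 3))  # [0, 3, 6, 9]
print(checkpoint_epochs(10, -1))  # []
```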
Methods:
Name | Description |
---|---|
train | Execute the training process. |
validate | Run validation on the test set. |
save_model | Save model training checkpoints. |
get_dataset | Get train and validation datasets. |
setup_model | Load, create, or download model. |
build_optimizer | Construct an optimizer for the model. |
Examples:
Initialize a trainer and start training
>>> trainer = BaseTrainer(cfg="config.yaml")
>>> trainer.train()
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | str | Path to a configuration file. | DEFAULT_CFG |
overrides | dict | Configuration overrides. | None |
_callbacks | list | List of callback functions. | None |
Source: ultralytics/engine/trainer.py, lines 110-176
add_callback
add_callback(event: str, callback)
Append the given callback to the event's callback list.
Source: ultralytics/engine/trainer.py, lines 178-180
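The three callback methods (`add_callback`, `set_callback`, `run_callbacks`) behave like a registry keyed by event name. The following minimal sketch makes some assumptions for illustration: the class name `CallbackRegistry` is hypothetical, and each callback is assumed to receive the owning object as its only argument. The event name `on_train_epoch_end` is used here only as a sample key.

```python
from collections import defaultdict


class CallbackRegistry:
    """Hypothetical stand-in for the trainer's callback handling."""

    def __init__(self):
        # Maps an event name to the list of callbacks registered for it.
        self.callbacks = defaultdict(list)

    def add_callback(self, event: str, callback):
        """Append a callback to the event's list, keeping earlier ones."""
        self.callbacks[event].append(callback)

    def set_callback(self, event: str, callback):
        """Replace every callback for the event with a single callback."""
        self.callbacks[event] = [callback]

    def run_callbacks(self, event: str):
        """Call each callback registered for the event, passing this object."""
        for callback in self.callbacks.get(event, []):
            callback(self)


calls = []
registry = CallbackRegistry()
registry.add_callback("on_train_epoch_end", lambda r: calls.append("a"))
registry.add_callback("on_train_epoch_end", lambda r: calls.append("b"))
registry.run_callbacks("on_train_epoch_end")  # calls == ["a", "b"]
registry.set_callback("on_train_epoch_end", lambda r: calls.append("c"))
registry.run_callbacks("on_train_epoch_end")  # calls == ["a", "b", "c"]
```

`set_callback` discards earlier registrations for the event, while `add_callback` keeps them. This is why the second run appends only `"c"`.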
auto_batch
auto_batch(max_num_obj=0)
Calculate optimal batch size based on model and device memory constraints.
Source: ultralytics/engine/trainer.py, lines 505-513
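One common way to turn memory profiling into a batch size works in three steps:

1. Measure memory use at a few small batch sizes.
2. Fit a straight line to those measurements.
3. Solve the line for a target fraction of total device memory.

The following sketch uses plain least squares. The profile numbers, the 60% target, and the function name `estimate_batch_size` are all assumptions for illustration. The sketch does not model the `max_num_obj` argument.

```python
def estimate_batch_size(batch_sizes, mem_used_gb, total_gb, fraction=0.60):
    """Estimate a batch size expected to use about `fraction` of device memory.

    Fits mem = slope * batch + intercept by ordinary least squares on profiled
    (batch size, memory) pairs, then solves the line for the target memory.
    """
    n = len(batch_sizes)
    mean_b = sum(batch_sizes) / n
    mean_m = sum(mem_used_gb) / n
    cov = sum((b - mean_b) * (m - mean_m) for b, m in zip(batch_sizes, mem_used_gb))
    var = sum((b - mean_b) ** 2 for b in batch_sizes)
    slope = cov / var
    if slope <= 0:
        return max(batch_sizes)  # memory did not grow with batch size; keep the largest profiled value
    intercept = mean_m - slope * mean_b
    target_gb = total_gb * fraction
    return max(1, int((target_gb - intercept) / slope))


# Hypothetical profile: 0.5 GB fixed cost plus 0.25 GB per image on a 16 GB device
sizes = [1, 2, 4, 8, 16]
mem = [0.5 + 0.25 * b for b in sizes]
print(estimate_batch_size(sizes, mem, total_gb=16))  # 36
```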
build_dataset
build_dataset(img_path, mode='train', batch=None)
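Build a dataset for the given image path and mode. BaseTrainer raises NotImplementedError; task-specific trainers override this method.
Source: ultralytics/engine/trainer.py, lines 674-676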
build_optimizer
build_optimizer(
model, name="auto", lr=0.001, momentum=0.9, decay=1e-05, iterations=100000.0
)
Construct an optimizer for the given model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model for which to build an optimizer. | required |
name | str | The name of the optimizer to use. If 'auto', the optimizer is selected based on the number of iterations. | 'auto' |
lr | float | The learning rate for the optimizer. | 0.001 |
momentum | float | The momentum factor for the optimizer. | 0.9 |
decay | float | The weight decay for the optimizer. | 1e-05 |
iterations | float | The number of iterations, which determines the optimizer if name is 'auto'. | 100000.0 |
Returns:
Type | Description |
---|---|
Optimizer | The constructed optimizer. |
Source: ultralytics/engine/trainer.py, lines 809-870
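The `name='auto'` behavior picks an optimizer from the length of the training schedule. The following sketch only illustrates the idea and returns a configuration tuple rather than a real optimizer. Several details are assumptions that should be checked against the linked source:

- the 10,000-iteration threshold,
- the fixed SGD settings,
- the AdamW learning-rate formula,
- the extra `nc` (number of classes) parameter.

The function name `resolve_optimizer` is hypothetical.

```python
def resolve_optimizer(name="auto", lr=0.001, momentum=0.9, iterations=1e5, nc=80):
    """Return (optimizer_name, lr, momentum), resolving name='auto' from the schedule length.

    The 10_000-iteration threshold and the lr formula are illustrative assumptions.
    """
    if name != "auto":
        return name, lr, momentum  # an explicit choice is used unchanged
    if iterations > 10_000:
        return "SGD", 0.01, 0.9  # long runs: SGD with a fixed learning rate
    lr_fit = round(0.002 * 5 / (4 + nc), 6)  # short runs: AdamW, lr shrinks as the class count grows
    return "AdamW", lr_fit, 0.9


print(resolve_optimizer(iterations=1e5))  # ('SGD', 0.01, 0.9)
print(resolve_optimizer(iterations=5_000, nc=80))  # ('AdamW', 0.000119, 0.9)
```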
build_targets
build_targets(preds, targets)
Build target tensors for training a YOLO model.
Source: ultralytics/engine/trainer.py, lines 691-693
check_resume
check_resume(overrides)
Check if resume checkpoint exists and update arguments accordingly.
Source: ultralytics/engine/trainer.py, lines 742-772
final_eval
final_eval()
Perform final evaluation and validation for an object detection YOLO model.
Source: ultralytics/engine/trainer.py, lines 726-740
get_dataloader
get_dataloader(dataset_path, batch_size=16, rank=0, mode='train')
Return a dataloader derived from torch.utils.data.DataLoader. BaseTrainer raises NotImplementedError; task-specific trainers override this method.
Source: ultralytics/engine/trainer.py, lines 670-672
get_dataset
get_dataset()
Get train and validation datasets from data dictionary.
Returns:
Type | Description |
---|---|
dict | A dictionary containing the training/validation/test dataset and category names. |
Source: ultralytics/engine/trainer.py, lines 587-612
get_model
get_model(cfg=None, weights=None, verbose=True)
Get the model. BaseTrainer raises NotImplementedError because it does not support loading cfg files.
Source: ultralytics/engine/trainer.py, lines 662-664
get_validator
get_validator()
Raise NotImplementedError when called; task-specific trainers override this method to return a validator.
Source: ultralytics/engine/trainer.py, lines 666-668
label_loss_items
label_loss_items(loss_items=None, prefix='train')
Return a loss dictionary with labeled training loss items.
Note
This is not needed for classification but is necessary for segmentation and detection.
Source: ultralytics/engine/trainer.py, lines 678-685
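For detection-style tasks, labeling loss items means pairing each loss value with a prefixed name. The following sketch assumes three loss names (`box_loss`, `cls_loss`, `dfl_loss`), rounding to 5 decimals, and a keys-only return value when no values are passed. All three are assumptions for illustration, not the base class's behavior.

```python
def label_loss_items(loss_items=None, prefix="train", loss_names=("box_loss", "cls_loss", "dfl_loss")):
    """Pair loss values with '<prefix>/<name>' keys, or return only the keys if no values are given."""
    keys = [f"{prefix}/{name}" for name in loss_names]
    if loss_items is None:
        return keys  # keys only, for example as column names
    return {key: round(float(value), 5) for key, value in zip(keys, loss_items)}


print(label_loss_items([1.234567, 0.5, 0.25]))
# {'train/box_loss': 1.23457, 'train/cls_loss': 0.5, 'train/dfl_loss': 0.25}
print(label_loss_items(prefix="val"))  # ['val/box_loss', 'val/cls_loss', 'val/dfl_loss']
```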
on_plot
on_plot(name, data=None)
Register plots (e.g. to be consumed in callbacks).
Source: ultralytics/engine/trainer.py, lines 721-724
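Registering a plot amounts to recording its path, plus optional data, in the `plots` dictionary so that callbacks can pick it up later. The following sketch makes some assumptions: the class `PlotHolder` is hypothetical, paths are used as keys, and each entry stores its data with a timestamp.

```python
import time
from pathlib import Path


class PlotHolder:
    """Hypothetical holder for a trainer-like `plots` dictionary."""

    def __init__(self):
        self.plots = {}

    def on_plot(self, name, data=None):
        """Record a plot path with optional data and a registration timestamp."""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}


plots_holder = PlotHolder()
plots_holder.on_plot("runs/train/results.png")
plots_holder.on_plot("runs/train/results.png", data={"epochs": 3})  # re-registering overwrites the entry
```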
optimizer_step
optimizer_step()
Perform a single step of the training optimizer with gradient clipping and EMA update.
Source: ultralytics/engine/trainer.py, lines 634-642
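An optimizer step with gradient clipping and an EMA update proceeds in three stages:

1. Rescale the gradients so their global L2 norm does not exceed a limit.
2. Apply the update and zero the gradients.
3. Blend the new weights into a running average.

The following sketch uses plain Python floats in place of tensors and plain SGD in place of the configured optimizer. It also assumes `max_norm=10.0` and a fixed EMA decay, and it omits AMP gradient unscaling. The function names are hypothetical.

```python
import math


def clip_grad_norm(grads, max_norm=10.0):
    """Scale gradients in place so their global L2 norm is at most max_norm; return the original norm."""
    total = math.sqrt(sum(g * g for g in grads))
    if total > max_norm:
        scale = max_norm / (total + 1e-6)  # small epsilon avoids division by zero
        for i, g in enumerate(grads):
            grads[i] = g * scale
    return total


def optimizer_step(params, grads, ema_params, lr=0.1, ema_decay=0.9, max_norm=10.0):
    """One clipped SGD step followed by an EMA update of the weights (sketch)."""
    clip_grad_norm(grads, max_norm)
    for i, g in enumerate(grads):
        params[i] -= lr * g  # plain SGD stands in for the configured optimizer
        grads[i] = 0.0  # zero gradients before the next backward pass
    for i, p in enumerate(params):
        ema_params[i] = ema_decay * ema_params[i] + (1 - ema_decay) * p  # EMA of the weights


params, grads, ema = [1.0, 1.0], [30.0, 40.0], [1.0, 1.0]
optimizer_step(params, grads, ema)  # gradient norm 50 is clipped to 10
# params is now about [0.4, 0.2] and ema about [0.94, 0.92]
```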
plot_metrics
plot_metrics()
Plot and display metrics visually.
Source: ultralytics/engine/trainer.py, lines 717-719
plot_training_labels
plot_training_labels()
Plot training labels for a YOLO model.
Source: ultralytics/engine/trainer.py, lines 704-706
plot_training_samples
plot_training_samples(batch, ni)
Plot training samples during YOLO training.
Source: ultralytics/engine/trainer.py, lines 700-702
preprocess_batch
preprocess_batch(batch)
Allow custom preprocessing of model inputs and ground truths depending on the task type.
Source: ultralytics/engine/trainer.py, lines 644-646
progress_string
progress_string()
Return a string describing training progress.
Source: ultralytics/engine/trainer.py, lines 695-697
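The base method is only a few lines long and is meant to be overridden by task trainers. A detection-style override might build a fixed-width header row like the following sketch. The column names and the 11-character width are assumptions for illustration.

```python
def progress_string(loss_names=("box_loss", "cls_loss", "dfl_loss")) -> str:
    """Build a fixed-width header row for per-epoch training progress (sketch)."""
    columns = ("Epoch", "GPU_mem", *loss_names, "Instances", "Size")
    return "\n" + ("%11s" * len(columns)) % columns  # each column right-aligned in 11 characters


print(progress_string())
```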
read_results_csv
read_results_csv()
Read results.csv into a dictionary using pandas.
Source: ultralytics/engine/trainer.py, lines 538-542
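The documented method uses pandas. The following stdlib-only sketch produces a similar column-to-list mapping. To stay self-contained, it parses CSV text rather than opening `self.csv`. Stripping padded header names and converting cells to float are assumptions about the file's contents.

```python
import csv
import io


def _to_number(value: str):
    """Convert a CSV cell to float when possible, otherwise keep the string."""
    try:
        return float(value)
    except ValueError:
        return value


def read_results_csv(text: str) -> dict:
    """Parse results.csv content into {column_name: [values, ...]} using only the stdlib."""
    reader = csv.DictReader(io.StringIO(text))
    columns = {name.strip(): [] for name in reader.fieldnames}
    for row in reader:
        for name, value in row.items():
            columns[name.strip()].append(_to_number(value.strip()))
    return columns


sample = "epoch,time,train/box_loss\n1,10.5,1.2\n2,21.0,1.1\n"
print(read_results_csv(sample))
# {'epoch': [1.0, 2.0], 'time': [10.5, 21.0], 'train/box_loss': [1.2, 1.1]}
```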
resume_training
resume_training(ckpt)
Resume YOLO training from the given epoch and best fitness.
Source: ultralytics/engine/trainer.py, lines 774-799
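Resuming comes down to continuing from the epoch after the last completed one and carrying over the best fitness. The following simplified sketch is not the library's exact rule. It assumes the hypothetical checkpoint keys `"epoch"` (last completed, 0-based) and `"best_fitness"`. It also treats a finished run as an error and skips restoring optimizer, EMA and scaler state.

```python
def resume_state(ckpt: dict, epochs: int):
    """Return (start_epoch, best_fitness) for continuing a run from a checkpoint dict.

    Assumes ckpt["epoch"] is the last completed 0-based epoch; optimizer, EMA and
    scaler state restoration is omitted from this sketch.
    """
    start_epoch = ckpt["epoch"] + 1  # continue with the next epoch
    if start_epoch >= epochs:
        raise ValueError(f"training to {epochs} epochs is already finished, nothing to resume")
    return start_epoch, ckpt.get("best_fitness", 0.0)


print(resume_state({"epoch": 4, "best_fitness": 0.52}, epochs=10))  # (5, 0.52)
```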
run_callbacks
run_callbacks(event: str)
Run all existing callbacks associated with a particular event.
Source: ultralytics/engine/trainer.py, lines 186-189
save_metrics
save_metrics(metrics)
Save training metrics to a CSV file.
Source: ultralytics/engine/trainer.py, lines 708-715
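Saving metrics to CSV typically appends one row per epoch and writes the header only when the file is first created. The following sketch has some assumptions for illustration: the `epoch` and `time` leading columns, the 1-based epoch stored in the file, and the sample metric names. The standalone function signature is hypothetical.

```python
import csv
import tempfile
import time
from pathlib import Path


def save_metrics(csv_path: Path, epoch: int, metrics: dict, start_time: float) -> None:
    """Append one row of metrics to csv_path, writing the header on first use."""
    write_header = not csv_path.exists()
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["epoch", "time", *metrics])  # header: epoch, time, metric names
        elapsed = round(time.time() - start_time, 3)
        writer.writerow([epoch + 1, elapsed, *metrics.values()])  # epoch stored 1-based


# Usage: two epochs written to a temporary results.csv
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "results.csv"
    t0 = time.time()
    save_metrics(path, 0, {"train/box_loss": 1.2, "metrics/mAP50(B)": 0.31}, t0)
    save_metrics(path, 1, {"train/box_loss": 1.1, "metrics/mAP50(B)": 0.35}, t0)
    with open(path, newline="", encoding="utf-8") as f:
        rows = list(csv.reader(f))
```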
save_model
save_model()
Save model training checkpoints with additional metadata.
Source: ultralytics/engine/trainer.py, lines 552-583
set_callback
set_callback(event: str, callback)
Override the existing callbacks with the given callback for the specified event.
Source: ultralytics/engine/trainer.py, lines 182-184
set_model_attributes
set_model_attributes()
Set or update model parameters before training.
Source: ultralytics/engine/trainer.py, lines 687-689
setup_model
setup_model()
Load, create, or download model for any task.
Returns:
Type | Description |
---|---|
dict | Optional checkpoint to resume training from. |
Source: ultralytics/engine/trainer.py, lines 614-632
train
train()
Execute the training process. On multi-GPU systems, device='' or device=None defaults to device=0.
Source: ultralytics/engine/trainer.py, lines 191-227
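The device defaulting in `train()` can be read as a mapping from the `device` argument to a number of GPUs (a world size). Multi-process distributed training would only be needed when that number exceeds 1. The following sketch is an approximation, not the library's exact code. The function name `infer_world_size` and the explicit `cuda_available` flag are hypothetical; the flag replaces a runtime CUDA check.

```python
def infer_world_size(device, cuda_available: bool) -> int:
    """Estimate how many GPUs a run would use from a `device` argument (sketch)."""
    if isinstance(device, str) and device and device not in {"cpu", "mps"}:
        return len(device.split(","))  # e.g. "0" -> 1, "0,1,2,3" -> 4
    if isinstance(device, (tuple, list)):
        return len(device)  # e.g. [0, 1] -> 2
    if device in {"cpu", "mps"}:
        return 0  # no CUDA devices requested
    if cuda_available:
        return 1  # device=None, device="" or a single int: one GPU, defaulting to device 0
    return 0


print(infer_world_size("0,1,2,3", cuda_available=True))  # 4
print(infer_world_size(None, cuda_available=True))  # 1
```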
validate
validate()
Run validation on the test set using self.validator.
Returns:
Name | Type | Description |
---|---|---|
metrics | dict | Dictionary of validation metrics. |
fitness | float | Fitness score for the validation. |
Source: ultralytics/engine/trainer.py, lines 648-660
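After each validation pass, the returned fitness is compared with the best value seen so far. The following sketch assumes that metrics may carry a `"fitness"` key and falls back to the negative training loss when it is absent. The function name `update_best_fitness` and this fallback are assumptions for illustration.

```python
def update_best_fitness(metrics: dict, loss: float, best_fitness=None):
    """Return (fitness, best_fitness) after one validation pass (sketch).

    Falls back to the negative training loss when metrics carry no "fitness" key.
    """
    metrics = dict(metrics)  # copy so the caller's dict is not mutated
    fitness = metrics.pop("fitness", -loss)
    if best_fitness is None or fitness > best_fitness:
        best_fitness = fitness  # new best result
    return fitness, best_fitness


print(update_best_fitness({"fitness": 0.4}, loss=1.0, best_fitness=0.3))  # (0.4, 0.4)
```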