Reference for ultralytics/models/yolo/detect/train.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/detect/train.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.yolo.detect.train.DetectionTrainer
DetectionTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)
Bases: BaseTrainer
A class extending the BaseTrainer class for training based on a detection model.
This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for object detection including dataset building, data loading, preprocessing, and model configuration.
Attributes:
Name | Type | Description |
---|---|---|
model | DetectionModel | The YOLO detection model being trained. |
data | Dict | Dictionary containing dataset information, including class names and number of classes. |
loss_names | tuple | Names of the loss components used in training (box_loss, cls_loss, dfl_loss). |
Methods:
Name | Description |
---|---|
build_dataset | Build a YOLO dataset for training or validation. |
get_dataloader | Construct and return a dataloader for the specified mode. |
preprocess_batch | Preprocess a batch of images by scaling and converting to float. |
set_model_attributes | Set model attributes based on dataset information. |
get_model | Return a YOLO detection model. |
get_validator | Return a validator for model evaluation. |
label_loss_items | Return a loss dictionary with labeled training loss items. |
progress_string | Return a formatted string of training progress. |
plot_training_samples | Plot training samples with their annotations. |
plot_metrics | Plot metrics from a CSV file. |
plot_training_labels | Create a labeled training plot of the YOLO model. |
auto_batch | Calculate the optimal batch size based on model memory requirements. |
Examples:
>>> from ultralytics.models.yolo.detect import DetectionTrainer
>>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
>>> trainer = DetectionTrainer(overrides=args)
>>> trainer.train()
Source code in ultralytics/engine/trainer.py
auto_batch
auto_batch()
Get the optimal batch size by estimating the model's memory usage.
Returns:
Type | Description |
---|---|
int | Optimal batch size. |
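Estimating a batch size from memory usage can be illustrated with a linear fit: measure memory at a few small batch sizes, fit memory ≈ slope × batch + intercept, and solve for the batch size that fills a target fraction of device memory. The stdlib-only sketch below assumes memory grows linearly with batch size; `estimate_batch_size`, its parameters (including the 0.6 default fraction), and the sample numbers are hypothetical, and the Ultralytics implementation profiles the real model on the real device and differs in its details.

```python
from typing import Sequence, Tuple


def fit_line(xs: Sequence[float], ys: Sequence[float]) -> Tuple[float, float]:
    """Ordinary least-squares fit of ys = slope * xs + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


def estimate_batch_size(
    batch_sizes: Sequence[int],
    mem_used_gb: Sequence[float],
    total_gb: float,
    fraction: float = 0.6,
    default: int = 16,
) -> int:
    """Hypothetical helper: batch size whose predicted memory use is `fraction` of `total_gb`."""
    slope, intercept = fit_line(batch_sizes, mem_used_gb)
    if slope <= 0:
        # Memory did not grow with batch size, so the fit is unusable; fall back to a default.
        return default
    target_gb = total_gb * fraction
    return max(1, int((target_gb - intercept) / slope))


# Synthetic measurements on a hypothetical 24 GB GPU: 1 GB overhead + 0.5 GB per image.
sizes = [1, 2, 4, 8, 16]
measured = [0.5 * b + 1.0 for b in sizes]
print(estimate_batch_size(sizes, measured, total_gb=24))  # prints 26
```

Rounding down with `int()` keeps the estimate on the safe side of the memory target.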
Source code in ultralytics/models/yolo/detect/train.py
build_dataset
build_dataset(img_path: str, mode: str = 'train', batch: Optional[int] = None)
Build YOLO Dataset for training or validation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
img_path | str | Path to the folder containing images. | required |
mode | str | Either 'train' or 'val'; users can customize different augmentations for each mode. | 'train' |
batch | int | Batch size, used for rectangular ('rect') mode. | None |
Returns:
Type | Description |
---|---|
Dataset | YOLO dataset object configured for the specified mode. |
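The `batch` argument matters in rectangular ('rect') mode because all images in one batch are padded to a single shared shape instead of a square. A minimal sketch of how such per-batch shapes can be computed: sort images by aspect ratio, group them into batches, and round the scaled shape up to a multiple of the model stride. `rect_batch_shapes` is a hypothetical helper; the logic is modeled on our reading of the Ultralytics dataset code but omits its padding term, so treat it as illustrative only.

```python
import math
from typing import List, Sequence, Tuple


def rect_batch_shapes(
    sizes: Sequence[Tuple[int, int]], batch_size: int, imgsz: int = 640, stride: int = 32
) -> Tuple[List[int], List[Tuple[int, ...]]]:
    """Return (image order, per-batch (height, width)) for rectangular batching.

    `sizes` holds (height, width) per image. Hypothetical helper for illustration.
    """
    ratios = [h / w for h, w in sizes]  # height / width aspect ratio per image
    order = sorted(range(len(sizes)), key=lambda i: ratios[i])  # similar shapes end up together
    shapes = []
    for start in range(0, len(order), batch_size):
        batch_ratios = [ratios[i] for i in order[start : start + batch_size]]
        lo, hi = min(batch_ratios), max(batch_ratios)
        if hi < 1:  # every image is landscape: shrink the height
            rel = (hi, 1.0)
        elif lo > 1:  # every image is portrait: shrink the width
            rel = (1.0, 1 / lo)
        else:  # mixed orientations: fall back to a square
            rel = (1.0, 1.0)
        # Scale to imgsz and round each side up to a multiple of the stride.
        shapes.append(tuple(math.ceil(r * imgsz / stride) * stride for r in rel))
    return order, shapes


order, shapes = rect_batch_shapes([(480, 640), (360, 640), (640, 320)], batch_size=2)
print(order, shapes)  # [1, 0, 2] [(480, 640), (640, 320)]
```

Grouping by aspect ratio is what makes the padding savings possible, which is also why rect batches depend on a fixed image order.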
Source code in ultralytics/models/yolo/detect/train.py
get_dataloader
get_dataloader(
dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"
)
Construct and return a dataloader for the specified mode.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_path | str | Path to the dataset. | required |
batch_size | int | Number of images per batch. | 16 |
rank | int | Process rank for distributed training. | 0 |
mode | str | 'train' for a training dataloader, 'val' for a validation dataloader. | 'train' |
Returns:
Type | Description |
---|---|
DataLoader | PyTorch dataloader object. |
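The `mode` argument drives a few dataloader settings. Based on our reading of the Ultralytics source at the time of writing (verify against your installed version), training loaders shuffle, shuffling is turned off when rectangular batching is enabled, and validation loaders use twice as many workers. The stdlib sketch below isolates only that decision logic; `dataloader_settings` is a hypothetical helper, not part of the library.

```python
from typing import Dict, Union


def dataloader_settings(mode: str, rect: bool, workers: int) -> Dict[str, Union[bool, int]]:
    """Hypothetical helper mirroring the mode-dependent dataloader choices described above."""
    if mode not in {"train", "val"}:
        raise ValueError(f"mode must be 'train' or 'val', not {mode!r}")
    shuffle = mode == "train"
    if rect and shuffle:
        # Rectangular batches rely on a fixed aspect-ratio ordering, so shuffling is disabled.
        shuffle = False
    n_workers = workers if mode == "train" else workers * 2
    return {"shuffle": shuffle, "workers": n_workers}


print(dataloader_settings("train", rect=False, workers=8))  # {'shuffle': True, 'workers': 8}
print(dataloader_settings("val", rect=True, workers=8))  # {'shuffle': False, 'workers': 16}
```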
Source code in ultralytics/models/yolo/detect/train.py
get_model
get_model(
cfg: Optional[str] = None,
weights: Optional[str] = None,
verbose: bool = True,
)
Return a YOLO detection model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | str | Path to model configuration file. | None |
weights | str | Path to model weights. | None |
verbose | bool | Whether to display model information. | True |
Returns:
Type | Description |
---|---|
DetectionModel | YOLO detection model. |
Source code in ultralytics/models/yolo/detect/train.py
get_validator
get_validator()
Return a DetectionValidator for YOLO model validation.
Source code in ultralytics/models/yolo/detect/train.py
label_loss_items
label_loss_items(
loss_items: Optional[List[float]] = None, prefix: str = "train"
)
Return a loss dictionary with labeled training loss items.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_items | List[float] | List of loss values. | None |
prefix | str | Prefix for keys in the returned dictionary. | 'train' |
Returns:
Type | Description |
---|---|
Dict or List | Dictionary of labeled loss items if loss_items is provided, otherwise a list of keys. |
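The returned keys combine `prefix` with each loss name. A minimal sketch of that behavior, assuming the loss names listed in the attributes table and five-decimal rounding as in the source at the time of writing; the standalone signature, with `loss_names` passed explicitly instead of read from the trainer, is hypothetical.

```python
from typing import Dict, List, Optional, Sequence, Union

LOSS_NAMES = ("box_loss", "cls_loss", "dfl_loss")  # assumed, from the attributes table


def label_loss_items(
    loss_items: Optional[Sequence[float]] = None,
    prefix: str = "train",
    loss_names: Sequence[str] = LOSS_NAMES,
) -> Union[Dict[str, float], List[str]]:
    """Standalone sketch: map loss values to '<prefix>/<name>' keys, or return only the keys."""
    keys = [f"{prefix}/{name}" for name in loss_names]
    if loss_items is None:
        return keys
    # float() also accepts single-element tensors; plain floats are used here.
    return dict(zip(keys, [round(float(x), 5) for x in loss_items]))


print(label_loss_items())  # ['train/box_loss', 'train/cls_loss', 'train/dfl_loss']
print(label_loss_items([1.234567, 0.5, 0.25], prefix="val"))
# {'val/box_loss': 1.23457, 'val/cls_loss': 0.5, 'val/dfl_loss': 0.25}
```

Returning only the keys when no values are given lets a caller build, for example, a CSV header before any losses exist.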
Source code in ultralytics/models/yolo/detect/train.py
plot_metrics
plot_metrics()
Plot metrics from a CSV file.
Source code in ultralytics/models/yolo/detect/train.py
plot_training_labels
plot_training_labels()
Create a labeled training plot of the YOLO model, summarizing the bounding boxes and classes in the training labels.
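Based on the source at the time of writing, this method gathers the bounding boxes and class indices from every training label and passes them to a plotting helper. The sketch below shows only the aggregation step, using plain lists and `collections.Counter` to summarize class frequencies. The per-image label layout (dicts with 'bboxes' and 'cls' keys) mirrors what the dataset exposes, but `aggregate_labels` is hypothetical, the box values are illustrative, and the plotting itself is omitted.

```python
from collections import Counter
from typing import Dict, List, Sequence, Tuple


def aggregate_labels(labels: Sequence[Dict[str, list]]) -> Tuple[List[list], Counter]:
    """Concatenate per-image boxes and count class occurrences (hypothetical helper)."""
    boxes: List[list] = []
    class_counts: Counter = Counter()
    for label in labels:
        boxes.extend(label["bboxes"])  # one entry per object in this image
        class_counts.update(int(c) for c in label["cls"])
    return boxes, class_counts


labels = [
    {"bboxes": [[0.5, 0.5, 0.2, 0.3]], "cls": [0]},  # illustrative values
    {"bboxes": [[0.1, 0.2, 0.1, 0.1], [0.7, 0.6, 0.3, 0.2]], "cls": [2, 0]},
]
boxes, counts = aggregate_labels(labels)
print(len(boxes), dict(counts))  # 3 {0: 2, 2: 1}
```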
Source code in ultralytics/models/yolo/detect/train.py
plot_training_samples
plot_training_samples(batch: Dict, ni: int)
Plot training samples with their annotations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Dict | Dictionary containing batch data. | required |
ni | int | Current iteration number, used in the file name of the saved plot. | required |
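The batch dictionary stores the boxes of the whole batch in flat sequences, with `batch_idx` recording which image each box belongs to, and `ni` distinguishes the saved plot files. A minimal stdlib sketch of regrouping boxes per image and building an output path: the key names ('batch_idx', 'cls', 'bboxes', 'im_file') follow our reading of the Ultralytics source, while the helper names and the `train_batch{ni}.jpg` naming scheme should be treated as assumptions.

```python
from collections import defaultdict
from pathlib import Path
from typing import Dict, List


def group_boxes_by_image(batch: Dict[str, list]) -> Dict[str, List[tuple]]:
    """Hypothetical helper: regroup flat (cls, box) annotations under each image file."""
    per_image: Dict[str, List[tuple]] = defaultdict(list)
    for idx, cls, box in zip(batch["batch_idx"], batch["cls"], batch["bboxes"]):
        per_image[batch["im_file"][int(idx)]].append((int(cls), box))
    return dict(per_image)


def sample_plot_path(save_dir: str, ni: int) -> Path:
    """Output file for the plot of iteration `ni` (assumed naming scheme)."""
    return Path(save_dir) / f"train_batch{ni}.jpg"


batch = {
    "im_file": ["a.jpg", "b.jpg"],
    "batch_idx": [0, 1, 1],  # box 0 -> image 0, boxes 1 and 2 -> image 1
    "cls": [3, 0, 5],
    "bboxes": [[0.5, 0.5, 0.2, 0.2], [0.3, 0.3, 0.1, 0.1], [0.8, 0.2, 0.1, 0.3]],
}
print(group_boxes_by_image(batch))
print(sample_plot_path("runs/detect/train", ni=0))
```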
Source code in ultralytics/models/yolo/detect/train.py
preprocess_batch
preprocess_batch(batch: Dict) -> Dict
Preprocess a batch of images by scaling and converting to float.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Dict | Dictionary containing batch data with an 'img' tensor. | required |
Returns:
Type | Description |
---|---|
Dict | Preprocessed batch with normalized images. |
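Based on the source at the time of writing, two steps are applied to the image tensor. First, 8-bit pixel values are divided by 255 so they lie in [0, 1]. Second, when the `multi_scale` training argument is enabled, the batch is resized to a random size that is a multiple of the stride. The stdlib sketch below shows both calculations on plain numbers. The helper names are hypothetical, and the size range (roughly 0.5× to 1.5× `imgsz`) reflects our reading of the source rather than a documented guarantee. The real method operates on a PyTorch tensor and performs the resize with interpolation.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple


def normalize(pixels: Sequence[int]) -> List[float]:
    """Scale uint8 pixel values from [0, 255] to floats in [0, 1]."""
    return [p / 255 for p in pixels]


def multi_scale_shape(
    height: int, width: int, imgsz: int = 640, stride: int = 32, rng: Optional[random.Random] = None
) -> Tuple[int, int]:
    """Pick a random target size and return the new (height, width), each a multiple of `stride`."""
    rng = rng if rng is not None else random.Random()
    # Random size between about 0.5x and 1.5x imgsz, snapped down to a stride multiple.
    size = rng.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + stride)) // stride * stride
    scale = size / max(height, width)
    if scale == 1:
        return height, width
    # Scale both sides and round each up to a stride multiple so the network can downsample cleanly.
    return math.ceil(height * scale / stride) * stride, math.ceil(width * scale / stride) * stride


print(normalize([0, 51, 255]))  # [0.0, 0.2, 1.0]
print(multi_scale_shape(480, 640, rng=random.Random(0)))
```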
Source code in ultralytics/models/yolo/detect/train.py
progress_string
progress_string()
Return a formatted string of training progress with epoch, GPU memory, loss, instances and size.
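The header is built with printf-style formatting: one right-aligned, 11-character column per field. A standalone sketch, assuming the loss names from the attributes table; the function signature is hypothetical, while the column layout follows our reading of the source.

```python
from typing import Sequence

LOSS_NAMES = ("box_loss", "cls_loss", "dfl_loss")  # assumed, from the attributes table


def progress_string(loss_names: Sequence[str] = LOSS_NAMES) -> str:
    """Return a newline followed by one 11-character, right-aligned column per header field."""
    fields = ("Epoch", "GPU_mem", *loss_names, "Instances", "Size")
    return ("\n" + "%11s" * len(fields)) % fields


print(progress_string())
```

Fixed-width columns keep the header aligned with the per-batch progress values printed beneath it.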
Source code in ultralytics/models/yolo/detect/train.py
set_model_attributes
set_model_attributes()
Set model attributes based on dataset information.
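Concretely, the dataset's class count and class names are copied onto the model; the source at the time of writing also attaches the training arguments. A minimal sketch with a stand-in model class: `ToyModel` and the sample data dictionary are hypothetical, while the 'nc' and 'names' keys match those used in Ultralytics dataset YAML files.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class ToyModel:
    """Hypothetical stand-in for a detection model."""

    nc: int = 80
    names: Dict[int, str] = field(default_factory=dict)
    args: Optional[Any] = None


def set_model_attributes(model: ToyModel, data: Dict[str, Any], args: Any = None) -> None:
    """Copy dataset metadata onto the model so downstream code sees the dataset's classes."""
    model.nc = data["nc"]  # number of classes
    model.names = data["names"]  # class index -> class name
    model.args = args  # training hyperparameters travel with the model


data = {"nc": 2, "names": {0: "cat", 1: "dog"}}
model = ToyModel()
set_model_attributes(model, data)
print(model.nc, model.names)  # 2 {0: 'cat', 1: 'dog'}
```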
Source code in ultralytics/models/yolo/detect/train.py