Reference for ultralytics/models/yolo/classify/train.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/classify/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.yolo.classify.train.ClassificationTrainer
ClassificationTrainer(
    cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None
)
Bases: BaseTrainer
A trainer class extending BaseTrainer for training image classification models.
This trainer handles the training process for image classification tasks. It supports both YOLO classification models and torchvision models, and it takes care of dataset handling and validation.
Attributes:
Name | Type | Description |
---|---|---|
model | ClassificationModel | The classification model to be trained. |
data | Dict[str, Any] | Dictionary containing dataset information including class names and number of classes. |
loss_names | List[str] | Names of the loss functions used during training. |
validator | ClassificationValidator | Validator instance for model evaluation. |
Methods:
Name | Description |
---|---|
set_model_attributes | Set the model's class names from the loaded dataset. |
get_model | Return a modified PyTorch model configured for training. |
setup_model | Load, create or download model for classification. |
build_dataset | Create a ClassificationDataset instance. |
get_dataloader | Return PyTorch DataLoader with transforms for image preprocessing. |
preprocess_batch | Preprocess a batch of images and classes. |
progress_string | Return a formatted string showing training progress. |
get_validator | Return an instance of ClassificationValidator. |
label_loss_items | Return a loss dict with labelled training loss items. |
plot_metrics | Plot metrics from a CSV file. |
final_eval | Evaluate trained model and save validation results. |
plot_training_samples | Plot training samples with their annotations. |
Examples:
Initialize and train a classification model
>>> from ultralytics.models.yolo.classify import ClassificationTrainer
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
>>> trainer = ClassificationTrainer(overrides=args)
>>> trainer.train()
The constructor sets up a trainer for image classification tasks: it sets the task type and applies a default image size when none is specified.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | Dict[str, Any] | Default configuration dictionary containing training parameters. | DEFAULT_CFG |
overrides | Dict[str, Any] | Dictionary of parameter overrides for the default configuration. | None |
_callbacks | List[Any] | List of callback functions to be executed during training. | None |
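The way the constructor fills in the task type and image size can be illustrated with a small, library-free sketch. The helper name `apply_classify_defaults`, the `"classify"` task string, and the default size of 224 are assumptions made for illustration; they are not taken from this page.

```python
from typing import Any, Dict, Optional


def apply_classify_defaults(
    overrides: Optional[Dict[str, Any]] = None, default_imgsz: int = 224
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the defaulting described for the constructor."""
    merged = dict(overrides or {})  # copy so the caller's dict is not mutated
    merged["task"] = "classify"  # assumed task identifier
    if merged.get("imgsz") is None:  # only fill in a size when none was given
        merged["imgsz"] = default_imgsz  # assumed default image size
    return merged


args = apply_classify_defaults({"model": "yolo11n-cls.pt", "epochs": 3})
print(args)
```

An explicit `imgsz` in the overrides is left untouched; only a missing or `None` value is replaced.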
Source code in ultralytics/models/yolo/classify/train.py
build_dataset
build_dataset(img_path: str, mode: str = 'train', batch=None)
Create a ClassificationDataset instance given an image path and mode.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
img_path | str | Path to the dataset images. | required |
mode | str | Dataset mode ('train', 'val', or 'test'). | 'train' |
batch | Any | Batch information (unused in this implementation). | None |
Returns:
Type | Description |
---|---|
ClassificationDataset | Dataset for the specified mode. |
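As a rough illustration of how the mode could drive dataset construction, the sketch below derives the settings for a dataset from `img_path` and `mode`. The helper `dataset_kwargs`, the keyword names `root`, `augment`, and `prefix`, the train-only augmentation, and the mode check are all assumptions for illustration, not the documented ClassificationDataset interface.

```python
from typing import Any, Dict

VALID_MODES = {"train", "val", "test"}


def dataset_kwargs(img_path: str, mode: str = "train") -> Dict[str, Any]:
    """Hypothetical helper: choose dataset settings from the requested mode."""
    if mode not in VALID_MODES:  # validation added for this sketch only
        raise ValueError(f"mode must be one of {sorted(VALID_MODES)}, got {mode!r}")
    return {
        "root": img_path,  # assumed name for the image directory argument
        "augment": mode == "train",  # assumption: augment only while training
        "prefix": mode,  # assumed label used in log messages
    }


print(dataset_kwargs("path/to/images", mode="val"))
```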
Source code in ultralytics/models/yolo/classify/train.py
final_eval
final_eval()
Evaluate trained model and save validation results.
Source code in ultralytics/models/yolo/classify/train.py
get_dataloader
get_dataloader(
    dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"
)
Return PyTorch DataLoader with transforms to preprocess images.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_path | str | Path to the dataset. | required |
batch_size | int | Number of images per batch. | 16 |
rank | int | Process rank for distributed training. | 0 |
mode | str | 'train', 'val', or 'test' mode. | 'train' |
Returns:
Type | Description |
---|---|
DataLoader | DataLoader for the specified dataset and mode. |
Source code in ultralytics/models/yolo/classify/train.py
get_model
get_model(cfg=None, weights=None, verbose: bool = True)
Return a modified PyTorch model configured for training YOLO classification.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | Any | Model configuration. | None |
weights | Any | Pre-trained model weights. | None |
verbose | bool | Whether to display model information. | True |
Returns:
Type | Description |
---|---|
ClassificationModel | Configured PyTorch model for classification. |
Source code in ultralytics/models/yolo/classify/train.py
get_validator
get_validator()
Return an instance of ClassificationValidator for validation.
Source code in ultralytics/models/yolo/classify/train.py
label_loss_items
label_loss_items(loss_items: Optional[Tensor] = None, prefix: str = 'train')
Return a dict of labelled training loss items, or only the list of loss keys when loss_items is None.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_items | Tensor | Loss tensor items. | None |
prefix | str | Prefix to prepend to loss names. | 'train' |
Returns:
Name | Type | Description |
---|---|---|
keys | List[str] | List of loss keys if loss_items is None. |
loss_dict | Dict[str, float] | Dictionary of loss items if loss_items is provided. |
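The two return shapes can be sketched without any deep-learning dependency. The function below is a hypothetical stand-in, not the library method: the `prefix/name` key format, the single scalar loss, and the rounding to five decimals are assumptions for illustration.

```python
from typing import Dict, List, Optional, Sequence, Union


def label_loss_items_sketch(
    loss_names: Sequence[str],
    loss_items: Optional[float] = None,
    prefix: str = "train",
) -> Union[List[str], Dict[str, float]]:
    """Hypothetical stand-in: return keys only, or keys paired with loss values."""
    keys = [f"{prefix}/{name}" for name in loss_names]  # assumed key format
    if loss_items is None:
        return keys  # no values supplied, so only the labels are returned
    values = [round(float(loss_items), 5)]  # assumption: one scalar loss value
    return dict(zip(keys, values))


print(label_loss_items_sketch(["loss"]))
print(label_loss_items_sketch(["loss"], 0.25, prefix="val"))
```

Returning the bare key list when no values are given lets a caller build column headers before any loss has been computed.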
Source code in ultralytics/models/yolo/classify/train.py
plot_metrics
plot_metrics()
Plot metrics from a CSV file.
Source code in ultralytics/models/yolo/classify/train.py
plot_training_samples
plot_training_samples(batch: Dict[str, Tensor], ni: int)
Plot training samples with their annotations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Dict[str, Tensor] | Batch containing images and class labels. | required |
ni | int | Number of iterations. | required |
Source code in ultralytics/models/yolo/classify/train.py
preprocess_batch
preprocess_batch(batch: Dict[str, Tensor]) -> Dict[str, Tensor]
Preprocess a batch of images and classes.
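One plausible form of this preprocessing is moving the batch contents to the training device. The sketch below is duck-typed so it runs without PyTorch; the function name `move_batch_to_device` and the `"img"` and `"cls"` keys are assumptions for illustration, and any object with a `.to(device)` method (such as a `torch.Tensor`) would work as a value.

```python
from typing import Any, Dict


def move_batch_to_device(batch: Dict[str, Any], device: Any) -> Dict[str, Any]:
    """Hypothetical stand-in: move the image and label entries to a device."""
    # "img" and "cls" are assumed key names for the images and class labels.
    for key in ("img", "cls"):
        batch[key] = batch[key].to(device)  # relies only on a .to() method
    return batch
```

With PyTorch installed, a call might look like `move_batch_to_device({"img": images, "cls": labels}, "cpu")`, where `images` and `labels` are tensors.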
Source code in ultralytics/models/yolo/classify/train.py
progress_string
progress_string() -> str
Return a formatted string showing training progress.
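A progress header of this kind is typically a row of fixed-width, right-aligned column titles. The sketch below shows one way to build such a row; the function name `progress_header`, the column titles, and the width of 11 characters are assumptions for illustration, not values confirmed by this page.

```python
from typing import Sequence


def progress_header(loss_names: Sequence[str], width: int = 11) -> str:
    """Hypothetical stand-in: build a right-aligned header row for progress logs."""
    # Assumed column titles; the loss names sit between the fixed columns.
    columns = ("Epoch", "GPU_mem", *loss_names, "Instances", "Size")
    return "\n" + "".join(f"{name:>{width}}" for name in columns)


print(progress_header(["loss"]))
```

Fixed-width columns keep the header aligned with per-batch values that are printed using the same width.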
Source code in ultralytics/models/yolo/classify/train.py
set_model_attributes
set_model_attributes()
Set the YOLO model's class names from the loaded dataset.
Source code in ultralytics/models/yolo/classify/train.py
setup_model
setup_model()
Load, create or download model for classification tasks.
Returns:
Type | Description |
---|---|
Any | Model checkpoint if applicable, otherwise None. |
Source code in ultralytics/models/yolo/classify/train.py