Skip to content

Reference for ultralytics/models/yolo/classify/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/classify/train.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.classify.train.ClassificationTrainer

ClassificationTrainer(
    cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None
)

Bases: BaseTrainer

A trainer class extending BaseTrainer for training image classification models.

This trainer handles the training process for image classification tasks, supporting both YOLO classification models and torchvision models with comprehensive dataset handling and validation.

Attributes:

Name Type Description
model ClassificationModel

The classification model to be trained.

data Dict[str, Any]

Dictionary containing dataset information including class names and number of classes.

loss_names List[str]

Names of the loss functions used during training.

validator ClassificationValidator

Validator instance for model evaluation.

Methods:

Name Description
set_model_attributes

Set the model's class names from the loaded dataset.

get_model

Return a modified PyTorch model configured for training.

setup_model

Load, create or download model for classification.

build_dataset

Create a ClassificationDataset instance.

get_dataloader

Return PyTorch DataLoader with transforms for image preprocessing.

preprocess_batch

Preprocess a batch of images and classes.

progress_string

Return a formatted string showing training progress.

get_validator

Return an instance of ClassificationValidator.

label_loss_items

Return a loss dict with labelled training loss items.

plot_metrics

Plot metrics from a CSV file.

final_eval

Evaluate trained model and save validation results.

plot_training_samples

Plot training samples with their annotations.

Examples:

Initialize and train a classification model

>>> from ultralytics.models.yolo.classify import ClassificationTrainer
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
>>> trainer = ClassificationTrainer(overrides=args)
>>> trainer.train()

This constructor sets up a trainer for image classification tasks, setting the task type to "classify" and defaulting the image size to 224 if it is not specified.

Parameters:

Name Type Description Default
cfg Dict[str, Any]

Default configuration dictionary containing training parameters.

DEFAULT_CFG
overrides Dict[str, Any]

Dictionary of parameter overrides for the default configuration.

None
_callbacks List[Any]

List of callback functions to be executed during training.

None

Examples:

Create a trainer with custom configuration

>>> from ultralytics.models.yolo.classify import ClassificationTrainer
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
>>> trainer = ClassificationTrainer(overrides=args)
>>> trainer.train()
Source code in ultralytics/models/yolo/classify/train.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):
    """
    Initialize a ClassificationTrainer object.

    This constructor sets up a trainer for image classification tasks. It forces the task type to "classify" and
    defaults the image size to 224 when "imgsz" is missing or None.

    The caller's ``overrides`` dictionary is shallow-copied before these defaults are applied. The dict passed in is
    therefore left unchanged and can safely be reused, e.g. to build a trainer for a different task.

    Args:
        cfg (Dict[str, Any], optional): Default configuration dictionary containing training parameters.
        overrides (Dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
        _callbacks (List[Any], optional): List of callback functions to be executed during training.

    Examples:
        Create a trainer with custom configuration
        >>> from ultralytics.models.yolo.classify import ClassificationTrainer
        >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
        >>> trainer = ClassificationTrainer(overrides=args)
        >>> trainer.train()
    """
    # Copy so the task/imgsz defaults below do not leak back into the caller's dict.
    overrides = {} if overrides is None else dict(overrides)
    overrides["task"] = "classify"
    if overrides.get("imgsz") is None:
        overrides["imgsz"] = 224
    super().__init__(cfg, overrides, _callbacks)

build_dataset

build_dataset(img_path: str, mode: str = 'train', batch=None)

Create a ClassificationDataset instance given an image path and mode.

Parameters:

Name Type Description Default
img_path str

Path to the dataset images.

required
mode str

Dataset mode ('train', 'val', or 'test').

'train'
batch Any

Batch information (unused in this implementation).

None

Returns:

Type Description
ClassificationDataset

Dataset for the specified mode.

Source code in ultralytics/models/yolo/classify/train.py
126
127
128
129
130
131
132
133
134
135
136
137
138
def build_dataset(self, img_path: str, mode: str = "train", batch=None):
    """
    Build a ClassificationDataset rooted at ``img_path`` for the given mode.

    Augmentation is enabled only when ``mode`` is "train". The mode string is also passed through as the dataset
    prefix.

    Args:
        img_path (str): Path to the dataset images.
        mode (str, optional): Dataset mode ('train', 'val', or 'test').
        batch (Any, optional): Batch information (unused in this implementation).

    Returns:
        (ClassificationDataset): Dataset for the specified mode.
    """
    is_training = mode == "train"
    return ClassificationDataset(root=img_path, args=self.args, augment=is_training, prefix=mode)

final_eval

final_eval()

Evaluate trained model and save validation results.

Source code in ultralytics/models/yolo/classify/train.py
210
211
212
213
214
215
216
217
218
219
220
221
def final_eval(self):
    """
    Strip optimizers from the saved checkpoints and re-validate the best one.

    Each of ``self.last`` and ``self.best`` that exists on disk is passed to ``strip_optimizer``. The best checkpoint
    is then validated using the trainer's current ``data`` and ``plots`` arguments. The "fitness" entry is dropped
    from the resulting metrics, and the ``on_fit_epoch_end`` callbacks are run.
    """
    for ckpt_path in (self.last, self.best):
        if not ckpt_path.exists():
            continue
        strip_optimizer(ckpt_path)  # strip optimizers
        if ckpt_path is not self.best:
            continue
        LOGGER.info(f"\nValidating {ckpt_path}...")
        val_args = self.validator.args
        val_args.data = self.args.data
        val_args.plots = self.args.plots
        self.metrics = self.validator(model=ckpt_path)
        self.metrics.pop("fitness", None)
        self.run_callbacks("on_fit_epoch_end")

get_dataloader

get_dataloader(
    dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"
)

Return PyTorch DataLoader with transforms to preprocess images.

Parameters:

Name Type Description Default
dataset_path str

Path to the dataset.

required
batch_size int

Number of images per batch.

16
rank int

Process rank for distributed training.

0
mode str

Dataset mode: 'train', 'val', or 'test'.

'train'

Returns:

Type Description
DataLoader

DataLoader for the specified dataset and mode.

Source code in ultralytics/models/yolo/classify/train.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
    """
    Build a PyTorch DataLoader for the dataset at ``dataset_path``.

    For any mode other than "train", the dataset's ``torch_transforms`` are also attached to the model as its
    ``transforms`` attribute. The model is unwrapped via ``.module`` first when it is parallel-wrapped.

    Args:
        dataset_path (str): Path to the dataset.
        batch_size (int, optional): Number of images per batch.
        rank (int, optional): Process rank for distributed training.
        mode (str, optional): 'train', 'val', or 'test' mode.

    Returns:
        (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
    """
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = self.build_dataset(dataset_path, mode)

    loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
    if mode == "train":
        return loader

    # Attach inference transforms to the underlying (unwrapped) model
    target_model = self.model.module if is_parallel(self.model) else self.model
    target_model.transforms = loader.dataset.torch_transforms
    return loader

get_model

get_model(cfg=None, weights=None, verbose: bool = True)

Return a modified PyTorch model configured for training YOLO classification.

Parameters:

Name Type Description Default
cfg Any

Model configuration.

None
weights Any

Pre-trained model weights.

None
verbose bool

Whether to display model information.

True

Returns:

Type Description
ClassificationModel

Configured PyTorch model for classification.

Source code in ultralytics/models/yolo/classify/train.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def get_model(self, cfg=None, weights=None, verbose: bool = True):
    """
    Construct a ClassificationModel configured for training.

    The model is sized from the dataset's ``nc`` and ``channels`` entries, and optionally loaded from ``weights``.
    When ``args.pretrained`` is falsy, every submodule exposing ``reset_parameters`` is re-initialized. Dropout
    layers get ``args.dropout`` as their probability when that value is truthy. All parameters are marked as
    requiring gradients.

    Args:
        cfg (Any, optional): Model configuration.
        weights (Any, optional): Pre-trained model weights.
        verbose (bool, optional): Whether to display model information (only honoured when RANK == -1).

    Returns:
        (ClassificationModel): Configured PyTorch model for classification.
    """
    model = ClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
    if weights:
        model.load(weights)

    reinitialize = not self.args.pretrained
    for module in model.modules():
        if reinitialize and hasattr(module, "reset_parameters"):
            module.reset_parameters()
        if isinstance(module, torch.nn.Dropout) and self.args.dropout:
            module.p = self.args.dropout  # set dropout
    model.requires_grad_(True)  # for training
    return model

get_validator

get_validator()

Return an instance of ClassificationValidator for validation.

Source code in ultralytics/models/yolo/classify/train.py
181
182
183
184
185
186
def get_validator(self):
    """
    Create the ClassificationValidator used for validation.

    Side effect: sets ``self.loss_names`` to ``["loss"]``.

    Returns:
        (ClassificationValidator): Validator built from the test loader, the save directory, a copy of the trainer
            arguments, and the trainer's callbacks.
    """
    self.loss_names = ["loss"]
    validator_cls = yolo.classify.ClassificationValidator
    return validator_cls(self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks)

label_loss_items

label_loss_items(loss_items: Optional[Tensor] = None, prefix: str = 'train')

Return a loss dict with labelled training loss items tensor.

Parameters:

Name Type Description Default
loss_items Tensor

Loss tensor items.

None
prefix str

Prefix to prepend to loss names.

'train'

Returns:

Name Type Description
keys List[str]

List of loss keys if loss_items is None.

loss_dict Dict[str, float]

Dictionary of loss items if loss_items is provided.

Source code in ultralytics/models/yolo/classify/train.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def label_loss_items(self, loss_items: Optional[torch.Tensor] = None, prefix: str = "train"):
    """
    Label the training loss with prefixed key names.

    Args:
        loss_items (torch.Tensor, optional): Loss value. It is converted with ``float()`` and rounded to 5 decimals.
        prefix (str, optional): Prefix to prepend to loss names.

    Returns:
        keys (List[str]): List of loss keys if loss_items is None.
        loss_dict (Dict[str, float]): Dictionary pairing the first key with the rounded loss if loss_items is
            provided.
    """
    labelled = [f"{prefix}/{name}" for name in self.loss_names]
    if loss_items is not None:
        value = round(float(loss_items), 5)
        return dict(zip(labelled, (value,)))
    return labelled

plot_metrics

plot_metrics()

Plot metrics from a CSV file.

Source code in ultralytics/models/yolo/classify/train.py
206
207
208
def plot_metrics(self):
    """Plot the classification metrics stored in the trainer's results CSV file."""
    csv_file, callback = self.csv, self.on_plot
    plot_results(file=csv_file, classify=True, on_plot=callback)  # save results.png

plot_training_samples

plot_training_samples(batch: Dict[str, Tensor], ni: int)

Plot training samples with their annotations.

Parameters:

Name Type Description Default
batch Dict[str, Tensor]

Batch containing images and class labels.

required
ni int

Current iteration index, used to name the saved image (train_batch{ni}.jpg).

required
Source code in ultralytics/models/yolo/classify/train.py
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def plot_training_samples(self, batch: Dict[str, torch.Tensor], ni: int):
    """
    Plot a batch of training images with their class labels.

    Args:
        batch (Dict[str, torch.Tensor]): Batch containing images ("img") and class labels ("cls").
        ni (int): Current iteration index, used to name the output file ``train_batch{ni}.jpg`` in ``save_dir``.
    """
    images = batch["img"]
    plot_images(
        images=images,
        batch_idx=torch.arange(len(images)),  # each image gets its own index 0..N-1
        cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
        fname=self.save_dir / f"train_batch{ni}.jpg",
        on_plot=self.on_plot,
    )

preprocess_batch

preprocess_batch(batch: Dict[str, Tensor]) -> Dict[str, torch.Tensor]

Preprocess a batch of images and classes.

Source code in ultralytics/models/yolo/classify/train.py
165
166
167
168
169
def preprocess_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Move the batch's "img" and "cls" tensors to ``self.device`` in place and return the same dict."""
    for key in ("img", "cls"):
        batch[key] = batch[key].to(self.device)
    return batch

progress_string

progress_string() -> str

Return a formatted string showing training progress.

Source code in ultralytics/models/yolo/classify/train.py
171
172
173
174
175
176
177
178
179
def progress_string(self) -> str:
    """Return the progress-table header: a newline followed by each column name right-aligned to width 11."""
    columns = ("Epoch", "GPU_mem", *self.loss_names, "Instances", "Size")
    return "\n" + "".join(str(column).rjust(11) for column in columns)

set_model_attributes

set_model_attributes()

Set the YOLO model's class names from the loaded dataset.

Source code in ultralytics/models/yolo/classify/train.py
78
79
80
def set_model_attributes(self):
    """Copy the dataset's class names onto the model as ``model.names``."""
    class_names = self.data["names"]
    self.model.names = class_names

setup_model

setup_model()

Load, create or download model for classification tasks.

Returns:

Type Description
Any

Model checkpoint if applicable, otherwise None.

Source code in ultralytics/models/yolo/classify/train.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def setup_model(self):
    """
    Load, create or download the model for classification tasks.

    If ``str(self.model)`` is a name found in ``torchvision.models``, that entry is called to build the model. It is
    built with ``weights="IMAGENET1K_V1"`` when ``args.pretrained`` is truthy and ``weights=None`` otherwise, and no
    checkpoint is produced. Otherwise loading is delegated to the base trainer's ``setup_model``. In both cases,
    ``ClassificationModel.reshape_outputs`` is then applied with the dataset's ``nc``.

    Returns:
        (Any): Model checkpoint if applicable, otherwise None.
    """
    import torchvision  # scope for faster 'import ultralytics'

    tv_models = vars(torchvision.models)
    if str(self.model) not in tv_models:
        ckpt = super().setup_model()
    else:
        weights = "IMAGENET1K_V1" if self.args.pretrained else None
        self.model = tv_models[self.model](weights=weights)
        ckpt = None
    ClassificationModel.reshape_outputs(self.model, self.data["nc"])
    return ckpt





📅 Created 1 year ago ✏️ Updated 8 months ago