Skip to content

Reference for ultralytics/models/yolo/detect/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/detect/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.detect.train.DetectionTrainer

DetectionTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: BaseTrainer

A class extending the BaseTrainer class for training based on a detection model.

This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for object detection including dataset building, data loading, preprocessing, and model configuration.

Attributes:

Name Type Description
model DetectionModel

The YOLO detection model being trained.

data Dict

Dictionary containing dataset information including class names and number of classes.

loss_names tuple

Names of the loss components used in training (box_loss, cls_loss, dfl_loss).

Methods:

Name Description
build_dataset

Build YOLO dataset for training or validation.

get_dataloader

Construct and return dataloader for the specified mode.

preprocess_batch

Preprocess a batch of images: move them to the training device, convert to float, divide pixel values by 255, and optionally apply multi-scale resizing.

set_model_attributes

Set model attributes based on dataset information.

get_model

Return a YOLO detection model.

get_validator

Return a validator for model evaluation.

label_loss_items

Return a loss dictionary with labeled training loss items.

progress_string

Return a formatted string of training progress.

plot_training_samples

Plot training samples with their annotations.

plot_metrics

Plot metrics from a CSV file.

plot_training_labels

Plot the distribution of the training dataset's labels (bounding boxes and classes).

auto_batch

Calculate optimal batch size based on model memory requirements.

Examples:

>>> from ultralytics.models.yolo.detect import DetectionTrainer
>>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
>>> trainer = DetectionTrainer(overrides=args)
>>> trainer.train()
Source code in ultralytics/engine/trainer.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize the BaseTrainer class.

    Resolves the run arguments, selects the device, seeds the RNGs, prepares the save/weights directories,
    loads the dataset description and registers callbacks. Training state such as ``ema``, ``lf``,
    ``scheduler`` and the epoch-level metrics starts as None; ``self.model`` initially holds the model file
    name returned by ``check_model_file_from_stem``.

    Args:
        cfg (optional): Base configuration passed to ``get_cfg`` together with ``overrides``; defaults to
            ``DEFAULT_CFG``. NOTE(review): presumably a path, dict or namespace — confirm against ``get_cfg``.
        overrides (dict, optional): Configuration overrides, passed to ``get_cfg`` and to ``check_resume``.
        _callbacks (optional): Callback collection to use; when falsy, ``callbacks.get_default_callbacks()``
            is used instead.
    """
    self.args = get_cfg(cfg, overrides)
    # Resume handling sees the raw overrides (may adjust self.args when resuming — see check_resume)
    self.check_resume(overrides)
    self.device = select_device(self.args.device, self.args.batch)
    # Update "-1" devices so post-training val does not repeat search.
    # For CUDA the device list is read back from CUDA_VISIBLE_DEVICES; NOTE(review): assumes select_device
    # has populated that variable, otherwise args.device becomes None — confirm.
    self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
    self.validator = None
    self.metrics = None
    self.plots = {}
    # Offsetting the seed by RANK gives each process a distinct seed
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    # Only rank -1 or 0 touches the filesystem (presumably single-process / DDP main process)
    if RANK in {-1, 0}:
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        YAML.save(self.save_dir / "args.yaml", vars(self.args))  # save run args
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs or 100  # in case users accidentally pass epochs=None with timed training
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in {"cpu", "mps"}:
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, e.g. yolo11n -> yolo11n.pt
    with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
        self.data = self.get_dataset()

    self.ema = None

    # Optimization utils init
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]  # generic default; subclasses may replace it with task-specific names
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]

    # HUB
    self.hub_session = None

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in {-1, 0}:
        callbacks.add_integration_callbacks(self)

auto_batch

auto_batch()

Get optimal batch size by calculating memory occupation of model.

Returns:

Type Description
int

Optimal batch size.

Source code in ultralytics/models/yolo/detect/train.py
209
210
211
212
213
214
215
216
217
218
def auto_batch(self):
    """
    Get optimal batch size by calculating memory occupation of model.

    A training dataset is built (batch=16) to find the largest number of labels in any single image. That count
    is scaled by 4 for mosaic augmentation and forwarded to the parent ``auto_batch`` as ``max_num_obj``.
    A dataset with no label entries yields ``max_num_obj == 0`` instead of raising.

    Returns:
        (int): Optimal batch size, as computed by the parent class.
    """
    train_dataset = self.build_dataset(self.data["train"], mode="train", batch=16)
    # default=0 prevents max() from raising ValueError when the dataset has no label entries
    max_labels = max((len(label["cls"]) for label in train_dataset.labels), default=0)
    max_num_obj = max_labels * 4  # 4 for mosaic augmentation
    return super().auto_batch(max_num_obj)

build_dataset

build_dataset(img_path: str, mode: str = 'train', batch: Optional[int] = None)

Build YOLO Dataset for training or validation.

Parameters:

Name Type Description Default
img_path str

Path to the folder containing images.

required
mode str

Either 'train' or 'val'; augmentations can be customized separately for each mode.

'train'
batch int

Batch size, used for rectangular ('rect') batching.

None

Returns:

Type Description
Dataset

YOLO dataset object configured for the specified mode.

Source code in ultralytics/models/yolo/detect/train.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def build_dataset(self, img_path: str, mode: str = "train", batch: Optional[int] = None):
    """
    Build a YOLO dataset for training or validation.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): Either 'train' or 'val'; augmentations can be customized per mode, and 'val' enables
            rectangular batching.
        batch (int, optional): Batch size, used for rectangular ('rect') batching.

    Returns:
        (Dataset): YOLO dataset configured for the requested mode.
    """
    # Grid size is the model's largest stride, floored at 32 (and 32 when no model is set yet)
    model_stride = int(de_parallel(self.model).stride.max()) if self.model else 0
    grid_size = max(model_stride, 32)
    use_rect = mode == "val"
    return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=use_rect, stride=grid_size)

get_dataloader

get_dataloader(
    dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"
)

Construct and return dataloader for the specified mode.

Parameters:

Name Type Description Default
dataset_path str

Path to the dataset.

required
batch_size int

Number of images per batch.

16
rank int

Process rank for distributed training.

0
mode str

'train' for training dataloader, 'val' for validation dataloader.

'train'

Returns:

Type Description
DataLoader

PyTorch dataloader object.

Source code in ultralytics/models/yolo/detect/train.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
    """
    Construct and return a dataloader for the specified mode.

    Training loaders shuffle unless the dataset uses rectangular batching. Validation loaders never shuffle and
    use twice the configured number of workers.

    Args:
        dataset_path (str): Path to the dataset.
        batch_size (int): Number of images per batch.
        rank (int): Process rank for distributed training.
        mode (str): 'train' for a training dataloader, 'val' for a validation dataloader.

    Returns:
        (DataLoader): PyTorch dataloader object.
    """
    assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
    is_train = mode == "train"

    # Build inside the zero-first context so the dataset *.cache is created only once under DDP
    with torch_distributed_zero_first(rank):
        dataset = self.build_dataset(dataset_path, mode, batch_size)

    rect_enabled = getattr(dataset, "rect", False)
    shuffle = is_train
    if rect_enabled and shuffle:
        LOGGER.warning("'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
        shuffle = False

    num_workers = self.args.workers * (1 if is_train else 2)
    return build_dataloader(dataset, batch_size, num_workers, shuffle, rank)

get_model

get_model(
    cfg: Optional[str] = None,
    weights: Optional[str] = None,
    verbose: bool = True,
)

Return a YOLO detection model.

Parameters:

Name Type Description Default
cfg str

Path to model configuration file.

None
weights str

Path to model weights.

None
verbose bool

Whether to display model information.

True

Returns:

Type Description
DetectionModel

YOLO detection model.

Source code in ultralytics/models/yolo/detect/train.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def get_model(self, cfg: Optional[str] = None, weights: Optional[str] = None, verbose: bool = True):
    """
    Return a YOLO detection model sized for the current dataset.

    The number of classes and input channels come from ``self.data``. Verbose output is only enabled when
    ``verbose`` is true and ``RANK == -1``.

    Args:
        cfg (str, optional): Path to model configuration file.
        weights (str, optional): Path to model weights; loaded into the model when truthy.
        verbose (bool): Whether to display model information.

    Returns:
        (DetectionModel): YOLO detection model.
    """
    dataset_info = self.data
    detector = DetectionModel(
        cfg,
        nc=dataset_info["nc"],
        ch=dataset_info["channels"],
        verbose=verbose and RANK == -1,
    )
    if weights:
        detector.load(weights)
    return detector

get_validator

get_validator()

Return a DetectionValidator for YOLO model validation.

Source code in ultralytics/models/yolo/detect/train.py
146
147
148
149
150
151
def get_validator(self):
    """
    Return a DetectionValidator for YOLO model validation.

    Also sets ``self.loss_names`` to the three detection loss components, which ``label_loss_items`` and
    ``progress_string`` use for labeling.
    """
    self.loss_names = ("box_loss", "cls_loss", "dfl_loss")
    validator_args = copy(self.args)  # shallow copy so the validator can't mutate the trainer's args object
    return yolo.detect.DetectionValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=validator_args,
        _callbacks=self.callbacks,
    )

label_loss_items

label_loss_items(
    loss_items: Optional[List[float]] = None, prefix: str = "train"
)

Return a dictionary mapping labeled loss names to their values, or just the list of labeled names when no values are given.

Parameters:

Name Type Description Default
loss_items List[float]

List of loss values.

None
prefix str

Prefix for keys in the returned dictionary.

'train'

Returns:

Type Description
Dict | List

Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.

Source code in ultralytics/models/yolo/detect/train.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def label_loss_items(self, loss_items: Optional[List[float]] = None, prefix: str = "train"):
    """
    Label loss values with prefixed loss names.

    Args:
        loss_items (List[float], optional): Loss values; each must be convertible with ``float()``
            (e.g. numbers or scalar tensors).
        prefix (str): Prefix for keys in the returned dictionary, joined to each name with '/'.

    Returns:
        (Dict | List): ``{"<prefix>/<name>": value}`` with values rounded to 5 decimals when ``loss_items`` is
            given, otherwise the list of ``"<prefix>/<name>"`` keys.
    """
    labeled_keys = [f"{prefix}/{name}" for name in self.loss_names]
    if loss_items is None:
        return labeled_keys
    # Convert every value (tensors included) to a plain float rounded to 5 decimal places
    rounded_values = list(map(lambda value: round(float(value), 5), loss_items))
    return dict(zip(labeled_keys, rounded_values))

plot_metrics

plot_metrics()

Plot metrics from a CSV file.

Source code in ultralytics/models/yolo/detect/train.py
199
200
201
def plot_metrics(self):
    """Plot the training metrics recorded in ``self.csv`` (written out as results.png)."""
    plot_results(on_plot=self.on_plot, file=self.csv)

plot_training_labels

plot_training_labels()

Plot the distribution of the training dataset's labels (bounding boxes and classes).

Source code in ultralytics/models/yolo/detect/train.py
203
204
205
206
207
def plot_training_labels(self):
    """Plot the distribution of the training dataset's labels (bounding boxes and classes)."""
    dataset_labels = self.train_loader.dataset.labels
    # Stack per-image arrays into single arrays covering the whole training set
    all_boxes = np.concatenate([entry["bboxes"] for entry in dataset_labels], axis=0)
    all_classes = np.concatenate([entry["cls"] for entry in dataset_labels], axis=0)
    plot_labels(
        all_boxes,
        all_classes.squeeze(),
        names=self.data["names"],
        save_dir=self.save_dir,
        on_plot=self.on_plot,
    )

plot_training_samples

plot_training_samples(batch: Dict, ni: int)

Plot training samples with their annotations.

Parameters:

Name Type Description Default
batch Dict

Dictionary containing batch data.

required
ni int

Current iteration index, used in the output filename (train_batch{ni}.jpg).

required
Source code in ultralytics/models/yolo/detect/train.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def plot_training_samples(self, batch: Dict, ni: int):
    """
    Plot a batch of training images with their box annotations.

    Args:
        batch (Dict): Batch dictionary; the 'img', 'batch_idx', 'cls', 'bboxes' and 'im_file' entries are read.
        ni (int): Current iteration index, used in the output filename ``train_batch{ni}.jpg``.
    """
    output_file = self.save_dir / f"train_batch{ni}.jpg"
    class_ids = batch["cls"].squeeze(-1)  # drop the trailing singleton dimension
    plot_images(
        images=batch["img"],
        batch_idx=batch["batch_idx"],
        cls=class_ids,
        bboxes=batch["bboxes"],
        paths=batch["im_file"],
        fname=output_file,
        on_plot=self.on_plot,
    )

preprocess_batch

preprocess_batch(batch: Dict) -> Dict

Preprocess a batch of images: move them to the training device, convert to float, divide pixel values by 255, and optionally apply multi-scale resizing.

Parameters:

Name Type Description Default
batch Dict

Dictionary containing batch data with 'img' tensor.

required

Returns:

Type Description
Dict

Preprocessed batch with normalized images.

Source code in ultralytics/models/yolo/detect/train.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def preprocess_batch(self, batch: Dict) -> Dict:
    """
    Move images to the training device, convert them to float and divide by 255, with optional multi-scale.

    When ``self.args.multi_scale`` is set, a target size is drawn at random from
    ``[int(imgsz * 0.5), int(imgsz * 1.5 + stride))`` and rounded down to a multiple of ``self.stride``. It is
    never smaller than one stride. The images are then resized by ``size / longest_side``, with each side
    rounded up to a stride multiple.

    Args:
        batch (Dict): Dictionary containing batch data with an 'img' tensor (assumed 8-bit pixel values, given
            the division by 255 — confirm against the dataloader).

    Returns:
        (Dict): The same batch dictionary with 'img' replaced by the preprocessed tensor.
    """
    batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
    if self.args.multi_scale:
        imgs = batch["img"]
        sz = (
            random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
            // self.stride
            * self.stride
        )  # size
        # Rounding down yields 0 whenever the draw is below one stride (possible when imgsz * 0.5 < stride);
        # clamp so interpolate never receives a zero-size target.
        sz = max(sz, self.stride)
        sf = sz / max(imgs.shape[2:])  # scale factor
        if sf != 1:
            ns = [
                math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
            ]  # new shape (stretched to gs-multiple)
            imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
        batch["img"] = imgs
    return batch

progress_string

progress_string()

Return a formatted string of training progress with epoch, GPU memory, loss, instances and size.

Source code in ultralytics/models/yolo/detect/train.py
171
172
173
174
175
176
177
178
179
def progress_string(self):
    """Return the training-progress header row: epoch, GPU memory, each loss name, instances and image size."""
    headers = ("Epoch", "GPU_mem", *self.loss_names, "Instances", "Size")
    # Each column is right-aligned in an 11-character field, preceded by a newline
    return "\n" + "".join(str(header).rjust(11) for header in headers)

set_model_attributes

set_model_attributes()

Set model attributes based on dataset information.

Source code in ultralytics/models/yolo/detect/train.py
118
119
120
121
122
123
124
125
126
def set_model_attributes(self):
    """Attach the dataset's class count and class names, plus the training arguments, to ``self.model``."""
    model, dataset_info = self.model, self.data
    model.nc = dataset_info["nc"]  # number of classes
    model.names = dataset_info["names"]  # class names
    model.args = self.args  # training hyperparameters





📅 Created 1 year ago ✏️ Updated 8 months ago