Reference for ultralytics/engine/trainer.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/trainer.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.engine.trainer.BaseTrainer

BaseTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

A base class for creating trainers.

This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing, and various training utilities. It supports both single-GPU and multi-GPU distributed training.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| args | SimpleNamespace | Configuration for the trainer. |
| validator | BaseValidator | Validator instance. |
| model | Module | Model instance. |
| callbacks | defaultdict | Dictionary of callbacks. |
| save_dir | Path | Directory to save results. |
| wdir | Path | Directory to save weights. |
| last | Path | Path to the last checkpoint. |
| best | Path | Path to the best checkpoint. |
| save_period | int | Save checkpoint every x epochs (disabled if < 1). |
| batch_size | int | Batch size for training. |
| epochs | int | Number of epochs to train for. |
| start_epoch | int | Starting epoch for training. |
| device | device | Device to use for training. |
| amp | bool | Flag to enable AMP (Automatic Mixed Precision). |
| scaler | GradScaler | Gradient scaler for AMP. |
| data | str | Path to data. |
| ema | Module | EMA (Exponential Moving Average) of the model. |
| resume | bool | Resume training from a checkpoint. |
| lf | Module | Loss function. |
| scheduler | _LRScheduler | Learning rate scheduler. |
| best_fitness | float | The best fitness value achieved. |
| fitness | float | Current fitness value. |
| loss | float | Current loss value. |
| tloss | float | Total loss value. |
| loss_names | list | List of loss names. |
| csv | Path | Path to results CSV file. |
| metrics | dict | Dictionary of metrics. |
| plots | dict | Dictionary of plots. |

Methods:

| Name | Description |
| --- | --- |
| train | Execute the training process. |
| validate | Run validation on the test set. |
| save_model | Save model training checkpoints. |
| get_dataset | Get train and validation datasets. |
| setup_model | Load, create, or download model. |
| build_optimizer | Construct an optimizer for the model. |

Examples:

Initialize a trainer and start training

>>> trainer = BaseTrainer(cfg="config.yaml")
>>> trainer.train()

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | str | Path to a configuration file. | DEFAULT_CFG |
| overrides | dict | Configuration overrides. | None |
| _callbacks | list | List of callback functions. | None |
Source code in ultralytics/engine/trainer.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize the BaseTrainer class.

    Args:
        cfg (str, optional): Path to a configuration file.
        overrides (dict, optional): Configuration overrides.
        _callbacks (list, optional): List of callback functions.
    """
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)
    self.device = select_device(self.args.device, self.args.batch)
    # Update "-1" devices so post-training val does not repeat search
    self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
    self.validator = None
    self.metrics = None
    self.plots = {}
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    if RANK in {-1, 0}:
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        YAML.save(self.save_dir / "args.yaml", vars(self.args))  # save run args
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs or 100  # in case users accidentally pass epochs=None with timed training
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in {"cpu", "mps"}:
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolo11n -> yolo11n.pt
    with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
        self.data = self.get_dataset()

    self.ema = None

    # Optimization utils init
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]

    # HUB
    self.hub_session = None

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in {-1, 0}:
        callbacks.add_integration_callbacks(self)
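
In practice BaseTrainer is subclassed per task; the NotImplementedError hooks documented below (build_dataset, get_dataloader, get_model, get_validator) are the minimum a subclass must provide. A minimal sketch, using a hypothetical CustomTrainer for illustration only:

class CustomTrainer(BaseTrainer):
    """Hypothetical subclass sketching the hooks a task trainer must override."""

    def get_model(self, cfg=None, weights=None, verbose=True):
        ...  # build and return a torch.nn.Module for the task

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
        ...  # return a torch.utils.data.DataLoader over the dataset

    def get_validator(self):
        ...  # return a validator instance used by validate()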

add_callback

add_callback(event: str, callback)

Append the given callback to the event's callback list.

Source code in ultralytics/engine/trainer.py
def add_callback(self, event: str, callback):
    """Append the given callback to the event's callback list."""
    self.callbacks[event].append(callback)
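
A callback is any callable that accepts the trainer instance. For example, registering one under a known event name (here on_train_epoch_end) runs it at that point in the training loop:

>>> def log_epoch(trainer):
...     print(f"Finished epoch {trainer.epoch}")
>>> trainer.add_callback("on_train_epoch_end", log_epoch)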

auto_batch

auto_batch(max_num_obj=0)

Calculate optimal batch size based on model and device memory constraints.

Source code in ultralytics/engine/trainer.py
def auto_batch(self, max_num_obj=0):
    """Calculate optimal batch size based on model and device memory constraints."""
    return check_train_batch_size(
        model=self.model,
        imgsz=self.args.imgsz,
        amp=self.amp,
        batch=self.batch_size,
        max_num_obj=max_num_obj,
    )  # returns batch size
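
In typical use this is not called directly; passing batch=-1 (or a GPU-memory fraction such as batch=0.7) in the train arguments lets AutoBatch pick the largest batch size that fits in memory:

>>> from ultralytics import YOLO
>>> YOLO("yolo11n.pt").train(data="coco8.yaml", batch=-1)  # AutoBatch selects the batch size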

build_dataset

build_dataset(img_path, mode='train', batch=None)

Build dataset.

Source code in ultralytics/engine/trainer.py
def build_dataset(self, img_path, mode="train", batch=None):
    """Build dataset."""
    raise NotImplementedError("build_dataset function not implemented in trainer")

build_optimizer

build_optimizer(
    model, name="auto", lr=0.001, momentum=0.9, decay=1e-05, iterations=100000.0
)

Construct an optimizer for the given model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | The model for which to build an optimizer. | required |
| name | str | The name of the optimizer to use. If 'auto', the optimizer is selected based on the number of iterations. | 'auto' |
| lr | float | The learning rate for the optimizer. | 0.001 |
| momentum | float | The momentum factor for the optimizer. | 0.9 |
| decay | float | The weight decay for the optimizer. | 1e-05 |
| iterations | float | The number of iterations, which determines the optimizer if name is 'auto'. | 100000.0 |

Returns:

| Type | Description |
| --- | --- |
| Optimizer | The constructed optimizer. |

Source code in ultralytics/engine/trainer.py
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    """
    Construct an optimizer for the given model.

    Args:
        model (torch.nn.Module): The model for which to build an optimizer.
        name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
            based on the number of iterations.
        lr (float, optional): The learning rate for the optimizer.
        momentum (float, optional): The momentum factor for the optimizer.
        decay (float, optional): The weight decay for the optimizer.
        iterations (float, optional): The number of iterations, which determines the optimizer if
            name is 'auto'.

    Returns:
        (torch.optim.Optimizer): The constructed optimizer.
    """
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = self.data.get("nc", 10)  # number of classes
        lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
        name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
                # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)

    optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
    name = {x.lower(): x for x in optimizers}.get(name.lower())
    if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
        optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == "RMSProp":
        optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == "SGD":
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(
            f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
            "Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
        )

    optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
        f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
    )
    return optimizer
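
As a worked example of the 'auto' branch above: a short run (iterations <= 10000) selects AdamW with a learning rate fitted to the class count, so for an 80-class dataset:

>>> nc = 80
>>> round(0.002 * 5 / (4 + nc), 6)  # fitted AdamW lr0
0.000119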

build_targets

build_targets(preds, targets)

Build target tensors for training YOLO model.

Source code in ultralytics/engine/trainer.py
def build_targets(self, preds, targets):
    """Build target tensors for training YOLO model."""
    pass

check_resume

check_resume(overrides)

Check if resume checkpoint exists and update arguments accordingly.

Source code in ultralytics/engine/trainer.py
def check_resume(self, overrides):
    """Check if resume checkpoint exists and update arguments accordingly."""
    resume = self.args.resume
    if resume:
        try:
            exists = isinstance(resume, (str, Path)) and Path(resume).exists()
            last = Path(check_file(resume) if exists else get_latest_run())

            # Check that resume data YAML exists, otherwise strip to force re-download of dataset
            ckpt_args = attempt_load_weights(last).args
            if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
                ckpt_args["data"] = self.args.data

            resume = True
            self.args = get_cfg(ckpt_args)
            self.args.model = self.args.resume = str(last)  # reinstate model
            for k in (
                "imgsz",
                "batch",
                "device",
                "close_mosaic",
            ):  # allow arg updates to reduce memory or update device on resume
                if k in overrides:
                    setattr(self.args, k, overrides[k])

        except Exception as e:
            raise FileNotFoundError(
                "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
                "i.e. 'yolo train resume model=path/to/last.pt'"
            ) from e
    self.resume = resume
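
At the user-facing level, resuming an interrupted run typically looks like this (checkpoint path illustrative):

>>> from ultralytics import YOLO
>>> model = YOLO("runs/detect/train/weights/last.pt")  # illustrative path to last.pt
>>> model.train(resume=True)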

final_eval

final_eval()

Perform final evaluation and validation for object detection YOLO model.

Source code in ultralytics/engine/trainer.py
def final_eval(self):
    """Perform final evaluation and validation for object detection YOLO model."""
    ckpt = {}
    for f in self.last, self.best:
        if f.exists():
            if f is self.last:
                ckpt = strip_optimizer(f)
            elif f is self.best:
                k = "train_results"  # update best.pt train_metrics from last.pt
                strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
                LOGGER.info(f"\nValidating {f}...")
                self.validator.args.plots = self.args.plots
                self.metrics = self.validator(model=f)
                self.metrics.pop("fitness", None)
                self.run_callbacks("on_fit_epoch_end")

get_dataloader

get_dataloader(dataset_path, batch_size=16, rank=0, mode='train')

Return a dataloader derived from torch.utils.data.DataLoader.

Source code in ultralytics/engine/trainer.py
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """Return dataloader derived from torch.data.Dataloader."""
    raise NotImplementedError("get_dataloader function not implemented in trainer")

get_dataset

get_dataset()

Get train and validation datasets from data dictionary.

Returns:

| Type | Description |
| --- | --- |
| dict | A dictionary containing the training/validation/test dataset and category names. |

Source code in ultralytics/engine/trainer.py
def get_dataset(self):
    """
    Get train and validation datasets from data dictionary.

    Returns:
        (dict): A dictionary containing the training/validation/test dataset and category names.
    """
    try:
        if self.args.task == "classify":
            data = check_cls_dataset(self.args.data)
        elif self.args.data.rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
            "detect",
            "segment",
            "pose",
            "obb",
        }:
            data = check_det_dataset(self.args.data)
            if "yaml_file" in data:
                self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
    except Exception as e:
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
    if self.args.single_cls:
        LOGGER.info("Overriding class names with single class.")
        data["names"] = {0: "item"}
        data["nc"] = 1
    return data
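
The returned dictionary mirrors the dataset YAML; for a detection dataset it typically contains at least these keys (values illustrative):

>>> data = trainer.get_dataset()
>>> data["train"]  # path to training images
>>> data["val"]  # path to validation images
>>> data["nc"]  # number of classes, e.g. 80
>>> data["names"]  # class names, e.g. {0: "person", 1: "bicycle", ...}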

get_model

get_model(cfg=None, weights=None, verbose=True)

Get model and raise NotImplementedError for loading cfg files.

Source code in ultralytics/engine/trainer.py
def get_model(self, cfg=None, weights=None, verbose=True):
    """Get model and raise NotImplementedError for loading cfg files."""
    raise NotImplementedError("This task trainer doesn't support loading cfg files")

get_validator

get_validator()

Return a NotImplementedError when the get_validator function is called.

Source code in ultralytics/engine/trainer.py
def get_validator(self):
    """Return a NotImplementedError when the get_validator function is called."""
    raise NotImplementedError("get_validator function not implemented in trainer")

label_loss_items

label_loss_items(loss_items=None, prefix='train')

Return a loss dict with labelled training loss items tensor.

Note

This is not needed for classification but necessary for segmentation & detection

Source code in ultralytics/engine/trainer.py
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Return a loss dict with labelled training loss items tensor.

    Note:
        This is not needed for classification but necessary for segmentation & detection
    """
    return {"loss": loss_items} if loss_items is not None else ["loss"]

on_plot

on_plot(name, data=None)

Register plots (e.g. to be consumed in callbacks).

Source code in ultralytics/engine/trainer.py
def on_plot(self, name, data=None):
    """Register plots (e.g. to be consumed in callbacks)."""
    path = Path(name)
    self.plots[path] = {"data": data, "timestamp": time.time()}

optimizer_step

optimizer_step()

Perform a single step of the training optimizer with gradient clipping and EMA update.

Source code in ultralytics/engine/trainer.py
def optimizer_step(self):
    """Perform a single step of the training optimizer with gradient clipping and EMA update."""
    self.scaler.unscale_(self.optimizer)  # unscale gradients
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
    self.scaler.step(self.optimizer)
    self.scaler.update()
    self.optimizer.zero_grad()
    if self.ema:
        self.ema.update(self.model)
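
This method implements the second half of the standard torch.amp recipe; the scaled backward pass happens earlier in the training loop, roughly:

self.scaler.scale(self.loss).backward()  # scaled backward pass in the training loop
self.optimizer_step()  # then: unscale, clip, step, update scaler, zero grads, EMA update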

plot_metrics

plot_metrics()

Plot and display metrics visually.

Source code in ultralytics/engine/trainer.py
def plot_metrics(self):
    """Plot and display metrics visually."""
    pass

plot_training_labels

plot_training_labels()

Plot training labels for YOLO model.

Source code in ultralytics/engine/trainer.py
def plot_training_labels(self):
    """Plot training labels for YOLO model."""
    pass

plot_training_samples

plot_training_samples(batch, ni)

Plot training samples during YOLO training.

Source code in ultralytics/engine/trainer.py
def plot_training_samples(self, batch, ni):
    """Plot training samples during YOLO training."""
    pass

preprocess_batch

preprocess_batch(batch)

Allow custom preprocessing of model inputs and ground truths depending on the task type.

Source code in ultralytics/engine/trainer.py
def preprocess_batch(self, batch):
    """Allow custom preprocessing model inputs and ground truths depending on task type."""
    return batch

progress_string

progress_string()

Return a string describing training progress.

Source code in ultralytics/engine/trainer.py
def progress_string(self):
    """Return a string describing training progress."""
    return ""

read_results_csv

read_results_csv()

Read results.csv into a dictionary using pandas.

Source code in ultralytics/engine/trainer.py
def read_results_csv(self):
    """Read results.csv into a dictionary using pandas."""
    import pandas as pd  # scope for faster 'import ultralytics'

    return pd.read_csv(self.csv).to_dict(orient="list")

resume_training

resume_training(ckpt)

Resume YOLO training from given epoch and best fitness.

Source code in ultralytics/engine/trainer.py
def resume_training(self, ckpt):
    """Resume YOLO training from given epoch and best fitness."""
    if ckpt is None or not self.resume:
        return
    best_fitness = 0.0
    start_epoch = ckpt.get("epoch", -1) + 1
    if ckpt.get("optimizer", None) is not None:
        self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
        best_fitness = ckpt["best_fitness"]
    if self.ema and ckpt.get("ema"):
        self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
        self.ema.updates = ckpt["updates"]
    assert start_epoch > 0, (
        f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
        f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
    )
    LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
    if self.epochs < start_epoch:
        LOGGER.info(
            f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
        )
        self.epochs += ckpt["epoch"]  # finetune additional epochs
    self.best_fitness = best_fitness
    self.start_epoch = start_epoch
    if start_epoch > (self.epochs - self.args.close_mosaic):
        self._close_dataloader_mosaic()

run_callbacks

run_callbacks(event: str)

Run all existing callbacks associated with a particular event.

Source code in ultralytics/engine/trainer.py
def run_callbacks(self, event: str):
    """Run all existing callbacks associated with a particular event."""
    for callback in self.callbacks.get(event, []):
        callback(self)

save_metrics

save_metrics(metrics)

Save training metrics to a CSV file.

Source code in ultralytics/engine/trainer.py
def save_metrics(self, metrics):
    """Save training metrics to a CSV file."""
    keys, vals = list(metrics.keys()), list(metrics.values())
    n = len(metrics) + 2  # number of cols
    s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
    t = time.time() - self.train_time_start
    with open(self.csv, "a", encoding="utf-8") as f:
        f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")

save_model

save_model()

Save model training checkpoints with additional metadata.

Source code in ultralytics/engine/trainer.py
def save_model(self):
    """Save model training checkpoints with additional metadata."""
    import io

    # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
    buffer = io.BytesIO()
    torch.save(
        {
            "epoch": self.epoch,
            "best_fitness": self.best_fitness,
            "model": None,  # resume and final checkpoints derive from EMA
            "ema": deepcopy(self.ema.ema).half(),
            "updates": self.ema.updates,
            "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
            "train_args": vars(self.args),  # save as dict
            "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
            "train_results": self.read_results_csv(),
            "date": datetime.now().isoformat(),
            "version": __version__,
            "license": "AGPL-3.0 (https://ultralytics.com/license)",
            "docs": "https://docs.ultralytics.com",
        },
        buffer,
    )
    serialized_ckpt = buffer.getvalue()  # get the serialized content to save

    # Save checkpoints
    self.last.write_bytes(serialized_ckpt)  # save last.pt
    if self.best_fitness == self.fitness:
        self.best.write_bytes(serialized_ckpt)  # save best.pt
    if (self.save_period > 0) and (self.epoch % self.save_period == 0):
        (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'

set_callback

set_callback(event: str, callback)

Override the existing callbacks with the given callback for the specified event.

Source code in ultralytics/engine/trainer.py
def set_callback(self, event: str, callback):
    """Override the existing callbacks with the given callback for the specified event."""
    self.callbacks[event] = [callback]

set_model_attributes

set_model_attributes()

Set or update model parameters before training.

Source code in ultralytics/engine/trainer.py
def set_model_attributes(self):
    """Set or update model parameters before training."""
    self.model.names = self.data["names"]

setup_model

setup_model()

Load, create, or download model for any task.

Returns:

Type Description
dict

Optional checkpoint to resume training from.

Source code in ultralytics/engine/trainer.py
def setup_model(self):
    """
    Load, create, or download model for any task.

    Returns:
        (dict): Optional checkpoint to resume training from.
    """
    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return

    cfg, weights = self.model, None
    ckpt = None
    if str(self.model).endswith(".pt"):
        weights, ckpt = attempt_load_one_weight(self.model)
        cfg = weights.yaml
    elif isinstance(self.args.pretrained, (str, Path)):
        weights, _ = attempt_load_one_weight(self.args.pretrained)
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
    return ckpt

train

train()

Allow device='', device=None on Multi-GPU systems to default to device=0.

Source code in ultralytics/engine/trainer.py
def train(self):
    """Allow device='', device=None on Multi-GPU systems to default to device=0."""
    if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
        world_size = len(self.args.device.split(","))
    elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
        world_size = len(self.args.device)
    elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
        world_size = 0
    elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
        world_size = 1  # default to device 0
    else:  # i.e. device=None or device=''
        world_size = 0

    # Run subprocess if DDP training, else train normally
    if world_size > 1 and "LOCAL_RANK" not in os.environ:
        # Argument checks
        if self.args.rect:
            LOGGER.warning("'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch < 1.0:
            LOGGER.warning(
                "'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting default 'batch=16'"
            )
            self.args.batch = 16

        # Command
        cmd, file = generate_ddp_command(world_size, self)
        try:
            LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
        except Exception as e:
            raise e
        finally:
            ddp_cleanup(self, str(file))

    else:
        self._do_train(world_size)
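
For example, a comma-separated device string (or a device list) takes the DDP subprocess path above, while a single device trains in-process:

>>> from ultralytics import YOLO
>>> YOLO("yolo11n.pt").train(data="coco8.yaml", epochs=3, device="0,1")  # DDP across two GPUs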

validate

validate()

Run validation on the test set using self.validator.

Returns:

| Name | Type | Description |
| --- | --- | --- |
| metrics | dict | Dictionary of validation metrics. |
| fitness | float | Fitness score for the validation. |

Source code in ultralytics/engine/trainer.py
def validate(self):
    """
    Run validation on the test set using self.validator.

    Returns:
        metrics (dict): Dictionary of validation metrics.
        fitness (float): Fitness score for the validation.
    """
    metrics = self.validator(self)
    fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
    if not self.best_fitness or self.best_fitness < fitness:
        self.best_fitness = fitness
    return metrics, fitness




