Skip to content

Reference for ultralytics/models/yolo/world/train.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.world.train.WorldTrainer

WorldTrainer(
    cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None
)

Bases: DetectionTrainer

A trainer class for fine-tuning YOLO World models on closed-set datasets.

This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual features for improved object detection and understanding. It handles text embedding generation and caching to accelerate training with multi-modal data.

Attributes:

Name Type Description
text_embeddings Dict[str, Tensor] | None

Cached text embeddings for category names to accelerate training.

model WorldModel

The YOLO World model being trained.

data Dict[str, Any]

Dataset configuration containing class information.

args Any

Training arguments and configuration.

Methods:

Name Description
get_model

Return WorldModel initialized with specified config and weights.

build_dataset

Build YOLO Dataset for training or validation.

set_text_embeddings

Set text embeddings for datasets to accelerate training.

generate_text_embeddings

Generate text embeddings for a list of text samples.

preprocess_batch

Preprocess a batch of images and text for YOLOWorld training.

Examples:

Initialize and train a YOLO World model

>>> from ultralytics.models.yolo.world import WorldTrainer
>>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
>>> trainer = WorldTrainer(overrides=args)
>>> trainer.train()

Parameters:

Name Type Description Default
cfg Dict[str, Any]

Configuration for the trainer.

DEFAULT_CFG
overrides Dict[str, Any]

Configuration overrides.

None
_callbacks List[Any]

List of callback functions.

None
Source code in ultralytics/models/yolo/world/train.py
54
55
56
57
58
59
60
61
62
63
64
65
66
def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):
    """
    Set up a WorldTrainer and start with an empty text-embedding cache.

    Args:
        cfg (Dict[str, Any]): Configuration for the trainer.
        overrides (Dict[str, Any], optional): Configuration overrides; a missing value is replaced by an empty dict
            before it is handed to the parent initializer.
        _callbacks (List[Any], optional): List of callback functions.
    """
    super().__init__(cfg, {} if overrides is None else overrides, _callbacks)
    # No embeddings are available yet; `set_text_embeddings` assigns this attribute later.
    self.text_embeddings = None

build_dataset

build_dataset(img_path: str, mode: str = 'train', batch: Optional[int] = None)

Build YOLO Dataset for training or validation.

Parameters:

Name Type Description Default
img_path str

Path to the folder containing images.

required
mode str

train mode or val mode; users can customize different augmentations for each mode.

'train'
batch int

Size of batches, this is for rect.

None

Returns:

Type Description
Any

YOLO dataset configured for training or validation.

Source code in ultralytics/models/yolo/world/train.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def build_dataset(self, img_path: str, mode: str = "train", batch: Optional[int] = None):
    """
    Create the YOLO dataset for one split and, for the training split, prepare its text embeddings.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): `train` mode or `val` mode; `val` enables `rect`, `train` enables `multi_modal`.
        batch (int, optional): Size of batches, this is for `rect`.

    Returns:
        (Any): YOLO dataset configured for training or validation.
    """
    # Use the model's largest stride when a model is available, but never go below 32.
    model_stride = int(de_parallel(self.model).stride.max()) if self.model else 0
    gs = max(model_stride, 32)

    is_train = mode == "train"
    dataset = build_yolo_dataset(
        self.args,
        img_path,
        batch,
        self.data,
        mode=mode,
        rect=mode == "val",
        stride=gs,
        multi_modal=is_train,
    )
    if is_train:
        # Cache text embeddings up front to accelerate training.
        self.set_text_embeddings([dataset], batch)
    return dataset

generate_text_embeddings

generate_text_embeddings(
    texts: List[str], batch: int, cache_dir: Path
) -> Dict[str, torch.Tensor]

Generate text embeddings for a list of text samples.

Parameters:

Name Type Description Default
texts List[str]

List of text samples to encode.

required
batch int

Batch size for processing.

required
cache_dir Path

Directory to save/load cached embeddings.

required

Returns:

Type Description
Dict[str, Tensor]

Dictionary mapping text samples to their embeddings.

Source code in ultralytics/models/yolo/world/train.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def generate_text_embeddings(self, texts: List[str], batch: int, cache_dir: Path) -> Dict[str, torch.Tensor]:
    """
    Generate text embeddings for a list of text samples, reusing an on-disk cache when it matches.

    The cache file lives in `cache_dir` and is named after the text model identifier. It is reused only when its
    keys are exactly the requested `texts` (order-insensitive); otherwise the embeddings are recomputed and the
    file is overwritten.

    Args:
        texts (List[str]): List of text samples to encode.
        batch (int): Batch size for processing.
        cache_dir (Path): Directory to save/load cached embeddings.

    Returns:
        (Dict[str, torch.Tensor]): Dictionary mapping text samples to their embeddings.
    """
    model = "clip:ViT-B/32"
    cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
    if cache_path.exists():
        LOGGER.info(f"Reading existed cache from '{cache_path}'")
        # Remap stored tensors onto this trainer's device. Without `map_location`, a cache written on a device
        # that is not available in the current process (e.g. a GPU cache read on a CPU-only host) fails to load,
        # and a cache written by another rank would be restored onto that rank's GPU.
        txt_map = torch.load(cache_path, map_location=self.device)
        if sorted(txt_map.keys()) == sorted(texts):
            return txt_map
    LOGGER.info(f"Caching text embeddings to '{cache_path}'")
    assert self.model is not None
    txt_feats = self.model.get_text_pe(texts, batch, cache_clip_model=False)
    # Drop the leading dimension so that each text is paired with its own embedding row.
    txt_map = dict(zip(texts, txt_feats.squeeze(0)))
    torch.save(txt_map, cache_path)
    return txt_map

get_model

get_model(
    cfg=None, weights: Optional[str] = None, verbose: bool = True
) -> WorldModel

Return WorldModel initialized with specified config and weights.

Parameters:

Name Type Description Default
cfg Dict[str, Any] | str

Model configuration.

None
weights str

Path to pretrained weights.

None
verbose bool

Whether to display model info.

True

Returns:

Type Description
WorldModel

Initialized WorldModel.

Source code in ultralytics/models/yolo/world/train.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def get_model(self, cfg=None, weights: Optional[str] = None, verbose: bool = True) -> WorldModel:
    """
    Build a WorldModel from the given config, optionally load weights, and register the pretrain callback.

    Args:
        cfg (Dict[str, Any] | str, optional): Model configuration; for a dict, its `yaml_file` entry is used.
        weights (str, optional): Path to pretrained weights.
        verbose (bool): Whether to display model info (only honoured when `RANK == -1`).

    Returns:
        (WorldModel): Initialized WorldModel.
    """
    yaml_cfg = cfg["yaml_file"] if isinstance(cfg, dict) else cfg
    channels = self.data["channels"]
    # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
    # NOTE: Following the official config, nc hard-coded to 80 for now.
    max_texts = min(self.data["nc"], 80)
    show_info = verbose and RANK == -1

    world_model = WorldModel(yaml_cfg, ch=channels, nc=max_texts, verbose=show_info)
    if weights:
        world_model.load(weights)

    self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
    return world_model

preprocess_batch

preprocess_batch(batch: Dict[str, Any]) -> Dict[str, Any]

Preprocess a batch of images and text for YOLOWorld training.

Source code in ultralytics/models/yolo/world/train.py
166
167
168
169
170
171
172
173
174
175
def preprocess_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
    """
    Preprocess a batch of images and text for YOLOWorld training.

    After the standard detection preprocessing, the cached embedding of every text in `batch["texts"]` is looked
    up, L2-normalized along the last dimension, and stored under `batch["txt_feats"]` with shape
    `(len(batch["texts"]), -1, embedding_dim)`.
    """
    batch = DetectionTrainer.preprocess_batch(self, batch)

    # Flatten the per-entry text lists into a single sequence of texts.
    grouped_texts = batch["texts"]
    flat_texts = [text for group in grouped_texts for text in group]

    # Gather cached embeddings, move them to the training device, and L2-normalize each one.
    embeds = torch.stack([self.text_embeddings[text] for text in flat_texts]).to(self.device)
    embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True)

    # Restore the per-entry grouping.
    batch["txt_feats"] = embeds.reshape(len(grouped_texts), -1, embeds.shape[-1])
    return batch

set_text_embeddings

set_text_embeddings(datasets: List[Any], batch: Optional[int]) -> None

Set text embeddings for datasets to accelerate training by caching category names.

This method collects unique category names from all datasets, then generates and caches text embeddings for these categories to improve training efficiency.

Parameters:

Name Type Description Default
datasets List[Any]

List of datasets from which to extract category names.

required
batch int | None

Batch size used for processing.

required
Notes

This method collects category names from datasets that have the 'category_names' attribute. For each such dataset, the generated text embeddings are cached in the parent directory of that dataset's image path.

Source code in ultralytics/models/yolo/world/train.py
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def set_text_embeddings(self, datasets: List[Any], batch: Optional[int]) -> None:
    """
    Collect text embeddings for the category names of the given datasets and store them on the trainer.

    Embeddings from all datasets are merged into one dictionary; when the same text occurs in several datasets,
    the entry from the later dataset replaces the earlier one.

    Args:
        datasets (List[Any]): List of datasets from which to extract category names.
        batch (int | None): Batch size used for processing.

    Notes:
        Datasets without a 'category_names' attribute are skipped. For every other dataset the cache directory
        passed to `generate_text_embeddings` is the parent directory of that dataset's own image path.
    """
    merged = {}
    named_datasets = (ds for ds in datasets if hasattr(ds, "category_names"))
    for ds in named_datasets:
        names = list(ds.category_names)
        cache_root = Path(ds.img_path).parent
        merged.update(self.generate_text_embeddings(names, batch, cache_dir=cache_root))
    self.text_embeddings = merged





ultralytics.models.yolo.world.train.on_pretrain_routine_end

on_pretrain_routine_end(trainer) -> None

Set up model classes and text encoder at the end of the pretrain routine.

Source code in ultralytics/models/yolo/world/train.py
16
17
18
19
20
21
def on_pretrain_routine_end(trainer) -> None:
    """
    Set up model classes and text encoder at the end of the pretrain routine.

    Only runs when `RANK` is -1 or 0. Each class name from the test dataset is truncated at its first "/" before
    being passed to the EMA model's `set_classes`.
    """
    if RANK not in {-1, 0}:
        return
    # Set class names for evaluation
    raw_names = trainer.test_loader.dataset.data["names"].values()
    names = [raw.split("/", 1)[0] for raw in raw_names]
    de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)





📅 Created 1 year ago ✏️ Updated 8 months ago