Reference for ultralytics/models/yolo/world/train.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!
ultralytics.models.yolo.world.train.WorldTrainer
WorldTrainer(
cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None
)
Bases: DetectionTrainer
A trainer class for fine-tuning YOLO World models on closed-set datasets.
This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual features for improved object detection and understanding. It handles text embedding generation and caching to accelerate training with multi-modal data.
Attributes:
Name | Type | Description |
---|---|---|
text_embeddings | Dict[str, Tensor] \| None | Cached text embeddings for category names to accelerate training. |
model | WorldModel | The YOLO World model being trained. |
data | Dict[str, Any] | Dataset configuration containing class information. |
args | Any | Training arguments and configuration. |
Methods:
Name | Description |
---|---|
get_model | Return WorldModel initialized with specified config and weights. |
build_dataset | Build YOLO Dataset for training or validation. |
set_text_embeddings | Set text embeddings for datasets to accelerate training. |
generate_text_embeddings | Generate text embeddings for a list of text samples. |
preprocess_batch | Preprocess a batch of images and text for YOLOWorld training. |
Examples:
Initialize and train a YOLO World model
>>> from ultralytics.models.yolo.world import WorldTrainer
>>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
>>> trainer = WorldTrainer(overrides=args)
>>> trainer.train()
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | Dict[str, Any] | Configuration for the trainer. | DEFAULT_CFG |
overrides | Dict[str, Any] | Configuration overrides. | None |
_callbacks | List[Any] | List of callback functions. | None |
Source code in ultralytics/models/yolo/world/train.py, lines 54-66
build_dataset
build_dataset(img_path: str, mode: str = 'train', batch: Optional[int] = None)
Build YOLO Dataset for training or validation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
img_path | str | Path to the folder containing images. | required |
mode | str | Dataset mode: 'train' for training or 'val' for validation. | 'train' |
batch | int | Size of batches, this is for | None |
Returns:
Type | Description |
---|---|
Any | YOLO dataset configured for training or validation. |
Source code in ultralytics/models/yolo/world/train.py, lines 94-112
generate_text_embeddings
generate_text_embeddings(
texts: List[str], batch: int, cache_dir: Path
) -> Dict[str, torch.Tensor]
Generate text embeddings for a list of text samples.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
texts | List[str] | List of text samples to encode. | required |
batch | int | Batch size for processing. | required |
cache_dir | Path | Directory to save/load cached embeddings. | required |
Returns:
Type | Description |
---|---|
Dict[str, Tensor] | Dictionary mapping text samples to their embeddings. |
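The docstring does not spell out the cache format, so the following is only a minimal plain-Python sketch of the idea: reuse an on-disk cache when it already covers every requested text, otherwise encode the texts in chunks of `batch` and save the result. The `encode` callable, the pickle format, and the file name `text_embeddings_cache.pkl` are hypothetical; the real method produces `torch.Tensor` embeddings from the model's text encoder rather than lists of floats.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Callable, Dict, List, Sequence


def generate_text_embeddings(
    texts: List[str],
    batch: int,
    cache_dir: Path,
    encode: Callable[[Sequence[str]], List[List[float]]],  # hypothetical text encoder
) -> Dict[str, List[float]]:
    """Sketch: return {text: embedding}, reusing an on-disk cache when it covers every text."""
    cache_path = Path(cache_dir) / "text_embeddings_cache.pkl"  # hypothetical file name
    if cache_path.exists():
        with cache_path.open("rb") as f:
            cached = pickle.load(f)
        if set(texts) <= set(cached):  # every requested text is already cached
            return {t: cached[t] for t in texts}

    # Encode in chunks of `batch` texts to bound memory use per encoder call.
    embeddings: Dict[str, List[float]] = {}
    for start in range(0, len(texts), batch):
        chunk = texts[start : start + batch]
        for text, vector in zip(chunk, encode(chunk)):
            embeddings[text] = vector

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with cache_path.open("wb") as f:
        pickle.dump(embeddings, f)
    return embeddings


# Usage with a toy encoder that records each chunk it receives.
calls: List[List[str]] = []


def toy_encode(chunk: Sequence[str]) -> List[List[float]]:
    calls.append(list(chunk))
    return [[float(len(t))] for t in chunk]


with tempfile.TemporaryDirectory() as tmp:
    first = generate_text_embeddings(["cat", "dog", "bird"], batch=2, cache_dir=Path(tmp), encode=toy_encode)
    second = generate_text_embeddings(["dog", "cat"], batch=2, cache_dir=Path(tmp), encode=toy_encode)

print(calls)  # [['cat', 'dog'], ['bird']]
```

The second call is served entirely from the cache and never reaches `encode`, which is the kind of saving the caching is meant to provide.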
Source code in ultralytics/models/yolo/world/train.py, lines 140-164
get_model
get_model(
cfg=None, weights: Optional[str] = None, verbose: bool = True
) -> WorldModel
Return WorldModel initialized with specified config and weights.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | Dict[str, Any] \| str | Model configuration. | None |
weights | str | Path to pretrained weights. | None |
verbose | bool | Whether to display model info. | True |
|
Returns:
Type | Description |
---|---|
WorldModel | Initialized WorldModel. |
Source code in ultralytics/models/yolo/world/train.py, lines 68-92
preprocess_batch
preprocess_batch(batch: Dict[str, Any]) -> Dict[str, Any]
Preprocess a batch of images and text for YOLOWorld training.
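The reference gives no detail on how text is handled in a batch. As a hedged sketch, assume each image in the batch carries a list of category strings under a `texts` key, and that preprocessing looks those strings up in the cached text embeddings and L2-normalizes them. The function name `attach_text_features`, the `txt_feats` output key, and the normalization step are assumptions for illustration, and plain lists stand in for tensors.

```python
import math
from typing import Any, Dict, List


def attach_text_features(batch: Dict[str, Any], text_embeddings: Dict[str, List[float]]) -> Dict[str, Any]:
    """Hypothetical sketch: add L2-normalized text features per image under 'txt_feats'."""
    per_image = []
    for texts in batch["texts"]:  # assumed: one list of category strings per image
        feats = []
        for text in texts:
            vec = text_embeddings[text]  # look up the cached embedding
            norm = math.sqrt(sum(v * v for v in vec)) or 1.0  # avoid dividing by zero
            feats.append([v / norm for v in vec])
        per_image.append(feats)
    batch["txt_feats"] = per_image  # images x texts-per-image x embedding-dim
    return batch


cache = {"cat": [3.0, 4.0], "dog": [0.0, 2.0]}
out = attach_text_features({"texts": [["cat", "dog"], ["dog", "cat"]]}, cache)
print(out["txt_feats"][0])  # [[0.6, 0.8], [0.0, 1.0]]
```

Looking features up from a precomputed cache keeps the text encoder out of the per-batch training loop, which matches the stated goal of caching embeddings to accelerate training.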
Source code in ultralytics/models/yolo/world/train.py, lines 166-175
set_text_embeddings
set_text_embeddings(datasets: List[Any], batch: Optional[int]) -> None
Set text embeddings for datasets to accelerate training by caching category names.
This method collects unique category names from all datasets, then generates and caches text embeddings for these categories to improve training efficiency.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
datasets | List[Any] | List of datasets from which to extract category names. | required |
batch | int \| None | Batch size used for processing. | required |
Notes
This method collects category names from datasets that have the 'category_names' attribute, then uses the first dataset's image path to determine where to cache the generated text embeddings.
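The collection step from the Notes can be sketched as follows. `SimpleNamespace` objects stand in for real datasets; sorting the names (for a deterministic order) and using the parent directory of the first dataset's `img_path` as the cache directory are assumptions, since the Notes say only that the first dataset's image path determines the cache location.

```python
from pathlib import Path
from types import SimpleNamespace
from typing import Any, List, Set, Tuple


def collect_category_names(datasets: List[Any]) -> Tuple[List[str], Path]:
    """Sketch: gather unique category names and choose a cache directory."""
    names: Set[str] = set()
    for dataset in datasets:
        # Per the Notes, only datasets exposing `category_names` contribute.
        if not hasattr(dataset, "category_names"):
            continue
        names |= set(dataset.category_names)
    # Assumption: cache next to the first dataset's image folder.
    cache_dir = Path(datasets[0].img_path).parent
    return sorted(names), cache_dir


datasets = [
    SimpleNamespace(img_path="data/coco/images/train", category_names={"person", "car"}),
    SimpleNamespace(img_path="data/extra/images"),  # no category_names: skipped
    SimpleNamespace(img_path="data/lvis/images", category_names={"car", "zebra"}),
]
names, cache_dir = collect_category_names(datasets)
print(names)  # ['car', 'person', 'zebra']
print(cache_dir.as_posix())  # data/coco/images
```

The deduplicated names and the chosen directory would then presumably be passed, together with `batch`, to `generate_text_embeddings`.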
Source code in ultralytics/models/yolo/world/train.py, lines 114-138
ultralytics.models.yolo.world.train.on_pretrain_routine_end
on_pretrain_routine_end(trainer) -> None
Set up model classes and text encoder at the end of the pretrain routine.
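The callback receives the trainer as its only argument. The sketch below shows the general shape of such a hook using hypothetical stub classes: `StubTrainer`, `StubModel`, its `set_classes` method, and the `names` mapping are stand-ins chosen for illustration, not the Ultralytics objects, and the text-encoder setup mentioned above is not modeled.

```python
from typing import Dict, List


class StubModel:
    """Hypothetical stand-in for a model that accepts class names."""

    def __init__(self) -> None:
        self.classes: List[str] = []

    def set_classes(self, names: List[str]) -> None:
        self.classes = list(names)


class StubTrainer:
    """Hypothetical stand-in exposing a model and an index-to-name class mapping."""

    def __init__(self, names: Dict[int, str]) -> None:
        self.model = StubModel()
        self.names = names


def on_pretrain_routine_end_sketch(trainer: StubTrainer) -> None:
    # Assumed behavior: hand the dataset's class names to the model, in index order.
    trainer.model.set_classes([trainer.names[i] for i in sorted(trainer.names)])


trainer = StubTrainer({0: "person", 1: "bicycle"})
on_pretrain_routine_end_sketch(trainer)
print(trainer.model.classes)  # ['person', 'bicycle']
```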
Source code in ultralytics/models/yolo/world/train.py, lines 16-21