Reference for ultralytics/models/yolo/yoloe/val.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/yoloe/val.py. If you spot a problem, please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.yoloe.val.YOLOEDetectValidator

YOLOEDetectValidator(
    dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None
)

Bases: DetectionValidator

A validator class for YOLOE detection models that handles both text and visual prompt embeddings.

This class extends DetectionValidator to provide specialized validation functionality for YOLOE models. It supports validation using either text prompts or visual prompt embeddings extracted from training samples, enabling flexible evaluation strategies for prompt-based object detection.

Attributes:

    device (device): The device on which validation is performed.
    args (namespace): Configuration arguments for validation.
    dataloader (DataLoader): DataLoader for validation data.

Methods:

    get_visual_pe: Extract visual prompt embeddings from training samples.
    preprocess: Preprocess batch data, ensuring visuals are on the same device as images.
    get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.
    __call__: Run validation using either text or visual prompt embeddings.

Examples:

Validate with text prompts

>>> validator = YOLOEDetectValidator()
>>> stats = validator(model=model, load_vp=False)

Validate with visual prompts

>>> stats = validator(model=model, refer_data="path/to/data.yaml", load_vp=True)
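
Validate a standalone checkpoint (a minimal sketch; the checkpoint name and dataset YAML are illustrative, and passing an overrides dict as args assumes the standard Ultralytics config handling)

>>> validator = YOLOEDetectValidator(args={"data": "coco128.yaml", "batch": 16})
>>> stats = validator(model="yoloe-11s.pt", load_vp=False)
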
Source code in ultralytics/models/yolo/detect/val.py
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
    """
    Initialize detection validator with necessary variables and settings.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
        save_dir (Path, optional): Directory to save results.
        pbar (Any, optional): Progress bar for displaying progress.
        args (Dict[str, Any], optional): Arguments for the validator.
        _callbacks (List[Any], optional): List of callback functions.
    """
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    self.nt_per_class = None
    self.nt_per_image = None
    self.is_coco = False
    self.is_lvis = False
    self.class_map = None
    self.args.task = "detect"
    self.metrics = DetMetrics(save_dir=self.save_dir)
    self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
    self.niou = self.iouv.numel()

__call__

__call__(
    trainer: Optional[Any] = None,
    model: Optional[Union[YOLOEModel, str]] = None,
    refer_data: Optional[str] = None,
    load_vp: bool = False,
) -> Dict[str, Any]

Run validation on the model using either text or visual prompt embeddings.

This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It supports validation during training (using a trainer object) or standalone validation with a provided model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.

Parameters:

    trainer (object, optional): Trainer object containing the model and device. Default: None.
    model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided. Default: None.
    refer_data (str, optional): Path to reference data for visual prompts. Default: None.
    load_vp (bool): Whether to load visual prompts. If False, text prompts are used. Default: False.

Returns:

    (dict): Validation statistics containing metrics computed during validation.

Source code in ultralytics/models/yolo/yoloe/val.py
@smart_inference_mode()
def __call__(
    self,
    trainer: Optional[Any] = None,
    model: Optional[Union[YOLOEModel, str]] = None,
    refer_data: Optional[str] = None,
    load_vp: bool = False,
) -> Dict[str, Any]:
    """
    Run validation on the model using either text or visual prompt embeddings.

    This method validates the model using either text prompts or visual prompts, depending on the load_vp flag.
    It supports validation during training (using a trainer object) or standalone validation with a provided
    model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.

    Args:
        trainer (object, optional): Trainer object containing the model and device.
        model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided.
        refer_data (str, optional): Path to reference data for visual prompts.
        load_vp (bool): Whether to load visual prompts. If False, text prompts are used.

    Returns:
        (dict): Validation statistics containing metrics computed during validation.
    """
    if trainer is not None:
        self.device = trainer.device
        model = trainer.ema.ema
        names = [name.split("/", 1)[0] for name in list(self.dataloader.dataset.data["names"].values())]

        if load_vp:
            LOGGER.info("Validate using the visual prompt.")
            self.args.half = False
            # Directly use the same dataloader for visual embeddings extracted during training
            vpe = self.get_visual_pe(self.dataloader, model)
            model.set_classes(names, vpe)
        else:
            LOGGER.info("Validate using the text prompt.")
            tpe = model.get_text_pe(names)
            model.set_classes(names, tpe)
        stats = super().__call__(trainer, model)
    else:
        if refer_data is not None:
            assert load_vp, "Refer data is only used for visual prompt validation."
        self.device = select_device(self.args.device)

        if isinstance(model, str):
            from ultralytics.nn.tasks import attempt_load_weights

            model = attempt_load_weights(model, device=self.device, inplace=True)
        model.eval().to(self.device)
        data = check_det_dataset(refer_data or self.args.data)
        names = [name.split("/", 1)[0] for name in list(data["names"].values())]

        if load_vp:
            LOGGER.info("Validate using the visual prompt.")
            self.args.half = False
            # TODO: need to check if the names from refer data is consistent with the evaluated dataset
            # could use same dataset or refer to extract visual prompt embeddings
            dataloader = self.get_vpe_dataloader(data)
            vpe = self.get_visual_pe(dataloader, model)
            model.set_classes(names, vpe)
            stats = super().__call__(model=deepcopy(model))
        elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], "lrpc"):  # prompt-free
            return super().__call__(trainer, model)
        else:
            LOGGER.info("Validate using the text prompt.")
            tpe = model.get_text_pe(names)
            model.set_classes(names, tpe)
            stats = super().__call__(model=deepcopy(model))
    return stats
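
Usage sketch for the standalone visual-prompt branch (hedged; the checkpoint and YAML names are illustrative, and the args overrides dict assumes the standard Ultralytics config handling). Per the assertion above, refer_data is only accepted together with load_vp=True.

>>> validator = YOLOEDetectValidator(args={"data": "coco128.yaml"})
>>> stats = validator(model="yoloe-11s.pt", refer_data="lvis.yaml", load_vp=True)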

get_visual_pe

get_visual_pe(dataloader: DataLoader, model: YOLOEModel) -> torch.Tensor

Extract visual prompt embeddings from training samples.

This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to zero.

Parameters:

    dataloader (DataLoader, required): The dataloader providing training samples.
    model (YOLOEModel, required): The YOLOE model from which to extract visual prompt embeddings.

Returns:

    (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).

Source code in ultralytics/models/yolo/yoloe/val.py
@smart_inference_mode()
def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
    """
    Extract visual prompt embeddings from training samples.

    This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model.
    It normalizes the embeddings and handles cases where no samples exist for a class by setting their
    embeddings to zero.

    Args:
        dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
        model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.

    Returns:
        (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
    """
    assert isinstance(model, YOLOEModel)
    names = [name.split("/", 1)[0] for name in list(dataloader.dataset.data["names"].values())]
    visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
    cls_visual_num = torch.zeros(len(names))

    desc = "Get visual prompt embeddings from samples"

    # Count samples per class
    for batch in dataloader:
        cls = batch["cls"].squeeze(-1).to(torch.int).unique()
        count = torch.bincount(cls, minlength=len(names))
        cls_visual_num += count

    cls_visual_num = cls_visual_num.to(self.device)

    # Extract visual prompt embeddings
    pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
    for batch in pbar:
        batch = self.preprocess(batch)
        preds = model.get_visual_pe(batch["img"], visual=batch["visuals"])  # (B, max_n, embed_dim)

        batch_idx = batch["batch_idx"]
        for i in range(preds.shape[0]):
            cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
            pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
            pad_cls[: len(cls)] = cls
            for c in cls:
                visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]

    # Normalize embeddings for classes with samples, set others to zero
    visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
    visual_pe[cls_visual_num == 0] = 0
    return visual_pe.unsqueeze(0)
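
Usage sketch (hedged; assumes a loaded YOLOEModel and a dataloader whose batches carry "visuals", e.g. one built by get_vpe_dataloader below):

>>> names = [n.split("/", 1)[0] for n in dataloader.dataset.data["names"].values()]
>>> vpe = validator.get_visual_pe(dataloader, model)  # (1, num_classes, embed_dim)
>>> model.set_classes(names, vpe)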

get_vpe_dataloader

get_vpe_dataloader(data: Dict[str, Any]) -> torch.utils.data.DataLoader

Create a dataloader for LVIS training visual prompt samples.

This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset. It applies necessary transformations including LoadVisualPrompt and configurations to the dataset for validation purposes.

Parameters:

    data (dict, required): Dataset configuration dictionary containing paths and settings.

Returns:

    (torch.utils.data.DataLoader): The dataloader for visual prompt samples.

Source code in ultralytics/models/yolo/yoloe/val.py
def get_vpe_dataloader(self, data: Dict[str, Any]) -> torch.utils.data.DataLoader:
    """
    Create a dataloader for LVIS training visual prompt samples.

    This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset.
    It applies necessary transformations including LoadVisualPrompt and configurations to the dataset
    for validation purposes.

    Args:
        data (dict): Dataset configuration dictionary containing paths and settings.

    Returns:
        (torch.utils.data.DataLoader): The dataloader for visual prompt samples.
    """
    dataset = build_yolo_dataset(
        self.args,
        data.get(self.args.split, data.get("val")),
        self.args.batch,
        data,
        mode="val",
        rect=False,
    )
    if isinstance(dataset, YOLOConcatDataset):
        for d in dataset.datasets:
            d.transforms.append(LoadVisualPrompt())
    else:
        dataset.transforms.append(LoadVisualPrompt())
    return build_dataloader(
        dataset,
        self.args.batch,
        self.args.workers,
        shuffle=False,
        rank=-1,
    )
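
Usage sketch (hedged; "lvis.yaml" is illustrative and must be resolvable by check_det_dataset):

>>> from ultralytics.data.utils import check_det_dataset
>>> data = check_det_dataset("lvis.yaml")
>>> vpe_loader = validator.get_vpe_dataloader(data)  # batches include "visuals" via LoadVisualPrompt
>>> vpe = validator.get_visual_pe(vpe_loader, model)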

preprocess

preprocess(batch: Dict[str, Any]) -> Dict[str, Any]

Preprocess batch data, ensuring visuals are on the same device as images.

Source code in ultralytics/models/yolo/yoloe/val.py
def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
    """Preprocess batch data, ensuring visuals are on the same device as images."""
    batch = super().preprocess(batch)
    if "visuals" in batch:
        batch["visuals"] = batch["visuals"].to(batch["img"].device)
    return batch





ultralytics.models.yolo.yoloe.val.YOLOESegValidator

YOLOESegValidator(
    dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None
)

Bases: YOLOEDetectValidator, SegmentationValidator

YOLOE segmentation validator that supports both text and visual prompt embeddings.
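
Example (a minimal sketch; the checkpoint and YAML names are illustrative, and the import path assumes the class is exported from ultralytics.models.yolo.yoloe):

>>> from ultralytics.models.yolo.yoloe import YOLOESegValidator
>>> validator = YOLOESegValidator(args={"data": "coco128-seg.yaml"})
>>> stats = validator(model="yoloe-11s-seg.pt", load_vp=False)  # text-prompt validation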

Source code in ultralytics/models/yolo/segment/val.py
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
    """
    Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
        save_dir (Path, optional): Directory to save results.
        pbar (Any, optional): Progress bar for displaying progress.
        args (namespace, optional): Arguments for the validator.
        _callbacks (list, optional): List of callback functions.
    """
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    self.plot_masks = None
    self.process = None
    self.args.task = "segment"
    self.metrics = SegmentMetrics(save_dir=self.save_dir)




