Reference for ultralytics/models/yolo/yoloe/predict.py

Note

This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/yoloe/predict.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


ultralytics.models.yolo.yoloe.predict.YOLOEVPDetectPredictor

YOLOEVPDetectPredictor(
    cfg=DEFAULT_CFG,
    overrides: Optional[Dict[str, Any]] = None,
    _callbacks: Optional[Dict[str, List[callable]]] = None,
)

Bases: DetectionPredictor

A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.

This mixin provides common functionality for YOLO models that use visual prompting, including model setup, prompt handling, and preprocessing transformations.

Attributes:

model (Module): The YOLO model for inference.
device (device): Device to run the model on (CPU or CUDA).
prompts (dict | Tensor): Visual prompts containing class indices and bounding boxes or masks.

Methods:

setup_model: Initialize the YOLO model and set it to evaluation mode.
set_prompts: Set the visual prompts for the model.
pre_transform: Preprocess images and prompts before inference.
inference: Run inference with visual prompts.
get_vpe: Process source to get visual prompt embeddings.

Source code in ultralytics/engine/predictor.py
def __init__(
    self,
    cfg=DEFAULT_CFG,
    overrides: Optional[Dict[str, Any]] = None,
    _callbacks: Optional[Dict[str, List[callable]]] = None,
):
    """
    Initialize the BasePredictor class.

    Args:
        cfg (str | dict): Path to a configuration file or a configuration dictionary.
        overrides (dict, optional): Configuration overrides.
        _callbacks (dict, optional): Dictionary of callback functions.
    """
    self.args = get_cfg(cfg, overrides)
    self.save_dir = get_save_dir(self.args)
    if self.args.conf is None:
        self.args.conf = 0.25  # default conf=0.25
    self.done_warmup = False
    if self.args.show:
        self.args.show = check_imshow(warn=True)

    # Usable if setup is done
    self.model = None
    self.data = self.args.data  # data_dict
    self.imgsz = None
    self.device = None
    self.dataset = None
    self.vid_writer = {}  # dict of {save_path: video_writer, ...}
    self.plotted_img = None
    self.source_type = None
    self.seen = 0
    self.windows = []
    self.batch = None
    self.results = None
    self.transforms = None
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    self.txt_path = None
    self._lock = threading.Lock()  # for automatic thread-safe inference
    callbacks.add_integration_callbacks(self)

get_vpe

get_vpe(source)

Process the source to get the visual prompt embeddings (VPE).

Parameters:

source (str | Path | int | Image | ndarray | Tensor | List | Tuple, required): The source of the image to make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and torch tensors.

Returns:

Tensor: The visual prompt embeddings (VPE) from the model.

Source code in ultralytics/models/yolo/yoloe/predict.py
def get_vpe(self, source):
    """
    Process the source to get the visual prompt embeddings (VPE).

    Args:
        source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
            of the image to make predictions on. Accepts various types including file paths, URLs, PIL
            images, numpy arrays, and torch tensors.

    Returns:
        (torch.Tensor): The visual prompt embeddings (VPE) from the model.
    """
    self.setup_source(source)
    assert len(self.dataset) == 1, "get_vpe only supports one image!"
    for _, im0s, _ in self.dataset:
        im = self.preprocess(im0s)
        return self.model(im, vpe=self.prompts, return_vpe=True)

inference

inference(im, *args, **kwargs)

Run inference with visual prompts.

Parameters:

im (Tensor, required): Input image tensor.
*args (Any, default ()): Variable length argument list.
**kwargs (Any, default {}): Arbitrary keyword arguments.

Returns:

Tensor: Model prediction results.

Source code in ultralytics/models/yolo/yoloe/predict.py
def inference(self, im, *args, **kwargs):
    """
    Run inference with visual prompts.

    Args:
        im (torch.Tensor): Input image tensor.
        *args (Any): Variable length argument list.
        **kwargs (Any): Arbitrary keyword arguments.

    Returns:
        (torch.Tensor): Model prediction results.
    """
    return super().inference(im, vpe=self.prompts, *args, **kwargs)
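The override above does one thing: it injects the stored prompts as a `vpe` keyword and delegates everything else to the parent class. A minimal sketch of that forwarding pattern follows. `ParentPredictor` and `PromptedPredictor` are hypothetical stand-ins, not the Ultralytics classes, and the keyword is placed after `*args` here, which is equivalent to the order used in the source.

```python
class ParentPredictor:
    """Hypothetical stand-in for the parent predictor; not the Ultralytics class."""

    def inference(self, im, *args, **kwargs):
        # Echo what was received so the forwarding is visible.
        return {"im": im, "args": args, "kwargs": kwargs}


class PromptedPredictor(ParentPredictor):
    """Hypothetical stand-in that injects stored prompts on every call."""

    def __init__(self, prompts):
        self.prompts = prompts

    def inference(self, im, *args, **kwargs):
        # Positional extras first, then the injected keyword, then caller keywords.
        return super().inference(im, *args, vpe=self.prompts, **kwargs)


out = PromptedPredictor(prompts="visual-prompts").inference("image", augment=False)
print(out["kwargs"])
```

Because the prompts travel as a keyword, callers of `inference` do not need to know about them; they only need to have called `set_prompts` beforehand.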

pre_transform

pre_transform(im)

Preprocess images and prompts before inference.

This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks) accordingly.

Parameters:

im (list, required): List containing a single input image.

Returns:

list: Preprocessed image ready for model inference.

Raises:

ValueError: If neither valid bounding boxes nor masks are provided in the prompts.
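Letterboxing resizes the image with a single scale factor and pads the remainder, so box prompts given in original pixel coordinates have to be scaled and shifted by the same amounts. The sketch below illustrates that geometry under stated assumptions: the resize preserves aspect ratio and the padding is centered. `letterbox_box` is a hypothetical helper written for this page; it is not the predictor's internal `_process_single_image`, whose body is not shown here.

```python
def letterbox_box(box, src_shape, dst_shape):
    """Map an (x1, y1, x2, y2) box from the source image into a letterboxed image.

    Shapes are (height, width). Assumes an aspect-preserving resize with centered padding.
    """
    src_h, src_w = src_shape
    dst_h, dst_w = dst_shape
    gain = min(dst_h / src_h, dst_w / src_w)  # uniform scale factor
    pad_x = (dst_w - src_w * gain) / 2  # horizontal padding on each side
    pad_y = (dst_h - src_h * gain) / 2  # vertical padding on each side
    x1, y1, x2, y2 = box
    return (x1 * gain + pad_x, y1 * gain + pad_y, x2 * gain + pad_x, y2 * gain + pad_y)


# A 480x640 (height x width) image letterboxed to 640x640:
# no scaling is needed, and 80 px of padding is added above and below.
moved = letterbox_box((100, 50, 200, 150), src_shape=(480, 640), dst_shape=(640, 640))
print(moved)
```

In this example only the y coordinates change, because the padding is purely vertical.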

Source code in ultralytics/models/yolo/yoloe/predict.py
def pre_transform(self, im):
    """
    Preprocess images and prompts before inference.

    This method applies letterboxing to the input image and transforms the visual prompts
    (bounding boxes or masks) accordingly.

    Args:
        im (list): List containing a single input image.

    Returns:
        (list): Preprocessed image ready for model inference.

    Raises:
        ValueError: If neither valid bounding boxes nor masks are provided in the prompts.
    """
    img = super().pre_transform(im)
    bboxes = self.prompts.pop("bboxes", None)
    masks = self.prompts.pop("masks", None)
    category = self.prompts["cls"]
    if len(img) == 1:
        visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
        self.prompts = visuals.unsqueeze(0).to(self.device)  # (1, N, H, W)
    else:
        # NOTE: only supports bboxes as prompts for now
        assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
        # NOTE: needs List[np.ndarray]
        assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (
            f"Expected List[np.ndarray], but got {bboxes}!"
        )
        assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (
            f"Expected List[np.ndarray], but got {category}!"
        )
        assert len(im) == len(category) == len(bboxes), (
            f"Expected same length for all inputs, but got {len(im)}vs{len(category)}vs{len(bboxes)}!"
        )
        visuals = [
            self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
            for i in range(len(img))
        ]
        self.prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)

    return img
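In the multi-image branch, each image can carry a different number of prompts, so the per-image results are padded to a common length before being stacked into one batch. `torch.nn.utils.rnn.pad_sequence` with `batch_first=True` does this with zero padding by default. A plain-Python analogue of that padding step, using a hypothetical `pad_prompts` helper and nested lists in place of tensors:

```python
def pad_prompts(per_image, fill=0):
    """Pad each image's prompt list to the longest length in the batch."""
    longest = max(len(p) for p in per_image)
    return [list(p) + [fill] * (longest - len(p)) for p in per_image]


# Three prompts for the first image, one for the second.
batch = pad_prompts([[11, 12, 13], [21]])
print(batch)
```

The padded entries are all zeros, which corresponds to empty prompt slots for images with fewer prompts.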

set_prompts

set_prompts(prompts)

Set the visual prompts for the model.

Parameters:

prompts (dict, required): Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key with class indices.

Source code in ultralytics/models/yolo/yoloe/predict.py
def set_prompts(self, prompts):
    """
    Set the visual prompts for the model.

    Args:
        prompts (dict): Dictionary containing class indices and bounding boxes or masks.
            Must include a 'cls' key with class indices.
    """
    self.prompts = prompts
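`set_prompts` stores the dictionary without checking it, so malformed prompts only surface later in `pre_transform`. The sketch below shows the documented shape of a prompts dictionary together with a hypothetical `check_prompts` validator. The validator is not part of Ultralytics, and the "one class index per box or mask" rule is an assumption made for illustration. Plain lists are used here to keep the example dependency-free; the multi-image branch of `pre_transform` expects NumPy arrays.

```python
def check_prompts(prompts):
    """Hypothetical validator mirroring the documented requirements for visual prompts."""
    if "cls" not in prompts:
        raise KeyError("prompts must include a 'cls' key with class indices")
    bboxes = prompts.get("bboxes")
    masks = prompts.get("masks")
    if bboxes is None and masks is None:
        raise ValueError("prompts must include 'bboxes' or 'masks'")
    regions = bboxes if bboxes is not None else masks
    # Assumption for this sketch: one class index per box or mask.
    if len(regions) != len(prompts["cls"]):
        raise ValueError("expected one class index per box or mask")
    return prompts


prompts = check_prompts(
    {
        "bboxes": [[221.5, 405.8, 345.0, 857.5], [120.0, 425.0, 160.0, 445.0]],  # x1, y1, x2, y2
        "cls": [0, 1],
    }
)
print(sorted(prompts))
```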

setup_model

setup_model(model, verbose: bool = True)

Set up the model for prediction.

Parameters:

model (Module, required): Model to load or use.
verbose (bool, default True): If True, provides detailed logging.

Source code in ultralytics/models/yolo/yoloe/predict.py
def setup_model(self, model, verbose: bool = True):
    """
    Set up the model for prediction.

    Args:
        model (torch.nn.Module): Model to load or use.
        verbose (bool, optional): If True, provides detailed logging.
    """
    super().setup_model(model, verbose=verbose)
    self.done_warmup = True





ultralytics.models.yolo.yoloe.predict.YOLOEVPSegPredictor

YOLOEVPSegPredictor(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)

Bases: YOLOEVPDetectPredictor, SegmentationPredictor

Predictor for YOLO-EVP segmentation tasks combining detection and segmentation capabilities.

Source code in ultralytics/models/yolo/segment/predict.py
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize the SegmentationPredictor with configuration, overrides, and callbacks.

    This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
    prediction results.

    Args:
        cfg (dict): Configuration for the predictor.
        overrides (dict, optional): Configuration overrides that take precedence over cfg.
        _callbacks (list, optional): List of callback functions to be invoked during prediction.
    """
    super().__init__(cfg, overrides, _callbacks)
    self.args.task = "segment"
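The `__init__` shown for this class lives in `segment/predict.py` because `YOLOEVPSegPredictor` defines no constructor of its own and Python resolves the call through the method resolution order. The sketch below reproduces the hierarchy with empty stand-in classes to make that order visible. It assumes `SegmentationPredictor` derives from `DetectionPredictor`, which this page does not state, and the stand-ins carry none of the real behavior.

```python
class BasePredictor:
    def __init__(self):
        self.task = None


class DetectionPredictor(BasePredictor):
    pass


class SegmentationPredictor(DetectionPredictor):
    def __init__(self):
        super().__init__()
        self.task = "segment"  # mirrors self.args.task = "segment"


class YOLOEVPDetectPredictor(DetectionPredictor):
    pass


class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):
    pass


mro = [c.__name__ for c in YOLOEVPSegPredictor.__mro__]
print(mro)
print(YOLOEVPSegPredictor().task)
```

The visual-prompt methods come first in the lookup order, so `pre_transform` and `inference` resolve to the YOLOE versions, while the constructor comes from the segmentation predictor.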




