Reference for ultralytics/models/rtdetr/train.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/rtdetr/train.py.
ultralytics.models.rtdetr.train.RTDETRTrainer
RTDETRTrainer(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)
Bases: DetectionTrainer
Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.
Attributes:
Name | Type | Description |
---|---|---|
loss_names | tuple | Names of the loss components used for training. |
data | dict | Dataset configuration containing class count and other parameters. |
args | dict | Training arguments and hyperparameters. |
save_dir | Path | Directory to save training results. |
test_loader | DataLoader | DataLoader for validation/testing data. |
Methods:
Name | Description |
---|---|
get_model | Initialize and return an RT-DETR model for object detection tasks. |
build_dataset | Build and return an RT-DETR dataset for training or validation. |
get_validator | Return a DetectionValidator suitable for RT-DETR model validation. |
Notes
- F.grid_sample used in RT-DETR does not support the deterministic=True argument.
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
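Both caveats map onto two standard Ultralytics training arguments, `amp` and `deterministic`. The sketch below is a minimal, stdlib-only illustration that switches both off in an overrides dict like the one used in the Examples section; the helper name `rtdetr_safe_overrides` is hypothetical, and whether you actually need to disable each option depends on your hardware and PyTorch version.

```python
from typing import Any, Dict, Optional


def rtdetr_safe_overrides(base: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Return a copy of `base` with AMP and deterministic mode disabled.

    Hypothetical helper: `amp` and `deterministic` are regular Ultralytics
    training arguments; every other key is passed through untouched.
    """
    overrides = dict(base or {})  # copy so the caller's dict is not mutated
    overrides["amp"] = False  # mixed precision can yield NaNs during bipartite matching
    overrides["deterministic"] = False  # per the note, F.grid_sample lacks deterministic=True support
    return overrides


args = rtdetr_safe_overrides(dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3))
print(args)
# With ultralytics installed, the dict is used exactly like `args` in the Examples:
# RTDETRTrainer(overrides=args).train()
```

Because the helper copies its input, the same base configuration can be reused for other trainers without inheriting these two overrides.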
Examples:
>>> from ultralytics.models.rtdetr.train import RTDETRTrainer
>>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
>>> trainer = RTDETRTrainer(overrides=args)
>>> trainer.train()
Source code in ultralytics/engine/trainer.py
build_dataset
build_dataset(img_path: str, mode: str = 'val', batch: Optional[int] = None)
Build and return an RT-DETR dataset for training or validation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
img_path | str | Path to the folder containing images. | required |
mode | str | Dataset mode, either 'train' or 'val'. | 'val' |
batch | int | Batch size for rectangle training. | None |
Returns:
Type | Description |
---|---|
RTDETRDataset | Dataset object for the specific mode. |
Source code in ultralytics/models/rtdetr/train.py
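The sketch below illustrates, in plain Python, how the `mode` and `batch` parameters could shape the dataset settings. It assumes that augmentation is enabled only in `'train'` mode and that `batch` is forwarded unchanged as the batch size; the function `dataset_settings` and its returned keys are hypothetical and do not reproduce the `RTDETRDataset` constructor.

```python
from typing import Any, Dict, Optional

VALID_MODES = ("train", "val")


def dataset_settings(img_path: str, mode: str = "val", batch: Optional[int] = None) -> Dict[str, Any]:
    """Collect dataset settings from the build_dataset arguments (hypothetical helper)."""
    if mode not in VALID_MODES:
        raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")
    return {
        "img_path": img_path,
        "batch_size": batch,  # forwarded unchanged; None means "not specified"
        "augment": mode == "train",  # assumption: augmentation only while training
    }


train_cfg = dataset_settings("path/to/images/train", mode="train", batch=16)
val_cfg = dataset_settings("path/to/images/val")  # mode defaults to 'val'
print(train_cfg["augment"], val_cfg["augment"])  # True False
```

On a real trainer the call has the same shape, for example `trainer.build_dataset("path/to/images", mode="train", batch=16)`, and returns an `RTDETRDataset`.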
get_model
get_model(
cfg: Optional[dict] = None,
weights: Optional[str] = None,
verbose: bool = True,
)
Initialize and return an RT-DETR model for object detection tasks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | dict | Model configuration. | None |
weights | str | Path to pre-trained model weights. | None |
verbose | bool | Verbose logging if True. | True |
Returns:
Type | Description |
---|---|
RTDETRDetectionModel | Initialized model. |
Source code in ultralytics/models/rtdetr/train.py
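A minimal sketch of the control flow the parameters above suggest: build the model from `cfg`, then load `weights` only when a path is given. `StubDetectionModel` and the extra `nc` argument (number of classes, which a trainer would take from its dataset configuration) are hypothetical stand-ins; the real method returns an `RTDETRDetectionModel`.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class StubDetectionModel:
    """Hypothetical stand-in for RTDETRDetectionModel."""

    cfg: Optional[Dict[str, Any]]
    nc: int
    loaded_weights: List[str] = field(default_factory=list)

    def load(self, weights: str) -> None:
        # A real model would read the checkpoint here; the stub only records the path.
        self.loaded_weights.append(weights)


def get_model(
    cfg: Optional[Dict[str, Any]] = None,
    weights: Optional[str] = None,
    verbose: bool = True,
    nc: int = 80,  # hypothetical: a trainer would take this from its data config
) -> StubDetectionModel:
    """Build a model from cfg, then load pre-trained weights only if a path is given."""
    model = StubDetectionModel(cfg=cfg, nc=nc)
    if verbose:
        print(f"Built model with nc={nc}")
    if weights:
        model.load(weights)
    return model


scratch = get_model(cfg={"depth": 1.0}, verbose=False)  # no weights: start from scratch
pretrained = get_model(weights="rtdetr-l.pt", nc=3)  # prints "Built model with nc=3"
```

Keeping weight loading optional lets the same method serve both training from a YAML config and fine-tuning from a checkpoint.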
get_validator
get_validator()
Return a DetectionValidator suitable for RT-DETR model validation.
Source code in ultralytics/models/rtdetr/train.py
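The sketch below shows the hand-off this method plausibly performs, assuming the validator receives the trainer's `test_loader`, `save_dir`, and a shallow copy of `args` (attributes listed for this class). `StubTrainer`, `StubValidator`, and the `conf` setting are hypothetical stand-ins, not Ultralytics classes; the point is that copying `args` keeps validation-time tweaks from leaking back into the training configuration.

```python
from copy import copy
from types import SimpleNamespace
from typing import Any


class StubValidator:
    """Hypothetical stand-in for the validator returned by get_validator."""

    def __init__(self, dataloader: Any, save_dir: str, args: SimpleNamespace) -> None:
        self.dataloader = dataloader
        self.save_dir = save_dir
        self.args = args


class StubTrainer:
    """Hypothetical trainer holding only the attributes the validator needs."""

    def __init__(self) -> None:
        self.test_loader = ["batch0", "batch1"]  # placeholder for a DataLoader
        self.save_dir = "runs/detect/train"
        self.args = SimpleNamespace(imgsz=640, conf=None)

    def get_validator(self) -> StubValidator:
        # Shallow-copy args so the validator can change settings without
        # mutating the trainer's own configuration.
        return StubValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))


trainer = StubTrainer()
validator = trainer.get_validator()
validator.args.conf = 0.001  # validation-only change
print(trainer.args.conf)  # None
```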