PyTorch RetinaNet train model inputs

Time:10-19

I have model = torchvision.models.detection.retinanet_resnet50_fpn_v2(progress=True) and would like to train it on custom data. To get the loss, I have to execute

classification_loss, regression_loss = model(images, targets)

I have created a batch tensor for images, but for the life of me I cannot find how I am supposed to format targets for object detection. Each target has a bounding box and a class label.

CodePudding user response:

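Check the official tutorial: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

In general, targets is a list of dicts, one dict per image. Each dict holds a "boxes" tensor of shape [N, 4] in (xmin, ymin, xmax, ymax) pixel coordinates and a "labels" tensor of shape [N], for example: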

targets = [
    {
        "boxes": torch.as_tensor([[xmin, ymin, xmax, ymax]], dtype=torch.float32),
        "labels": torch.as_tensor([1,], dtype=torch.int64)
    }
]
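
To make the format above concrete, here is a minimal sketch that builds targets for a batch of two images and shows how they would be passed to the model. The helper names make_target and training_step, the image sizes, and the box and label values are all made up for illustration; they are not part of torchvision. As far as I know, in training mode the model returns a dict of losses rather than a tuple, so the sketch sums the dict values instead of unpacking two variables.

```python
import torch


def make_target(boxes, labels):
    """Build one per-image target dict (hypothetical helper, not a torchvision API)."""
    boxes_t = torch.as_tensor(boxes, dtype=torch.float32)   # expected shape [N, 4]
    labels_t = torch.as_tensor(labels, dtype=torch.int64)   # expected shape [N]
    if boxes_t.ndim != 2 or boxes_t.shape[1] != 4:
        raise ValueError("boxes must have shape [N, 4] as (xmin, ymin, xmax, ymax)")
    if labels_t.ndim != 1 or labels_t.shape[0] != boxes_t.shape[0]:
        raise ValueError("labels must have shape [N], one label per box")
    # Degenerate boxes (zero or negative width/height) are rejected up front.
    valid = (boxes_t[:, 2] > boxes_t[:, 0]) & (boxes_t[:, 3] > boxes_t[:, 1])
    if not bool(valid.all()):
        raise ValueError("each box needs xmax > xmin and ymax > ymin")
    return {"boxes": boxes_t, "labels": labels_t}


# Two images of different sizes, passed as a list of [C, H, W] tensors (made-up data).
images = [torch.rand(3, 224, 224), torch.rand(3, 256, 320)]

# One dict per image, in the same order as `images`.
targets = [
    make_target([[10, 20, 100, 150]], [1]),
    make_target([[5, 5, 50, 60], [30, 40, 200, 220]], [2, 1]),
]


def training_step(model, images, targets):
    """Sketch of a single loss computation; `model` is assumed to be a torchvision RetinaNet."""
    model.train()
    loss_dict = model(images, targets)  # assumed: dict of named loss tensors
    return sum(loss_dict.values())      # total loss to call .backward() on
```

With model = torchvision.models.detection.retinanet_resnet50_fpn_v2(weights=None, weights_backbone=None, num_classes=3), calling training_step(model, images, targets) should return a scalar loss tensor. The label values must be smaller than num_classes.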