I trained a model, and now I would like to use it to detect objects in images. With the DefaultPredictor I only get the bounding boxes back, but I also need the masks. I saw that you can also run inference like this:
model.eval()
with torch.no_grad():
    outputs = model(inputs)
I think that's what I should use. The problem is that I don't know how to build the inputs, starting from images.
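import glob
import os

import cv2
import torch
from detectron2 import model_zoo
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.structures import ImageList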
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/"
"mask_rcnn_R_101_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.SOLVER.IMS_PER_BATCH = 1
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1 # only has one class
cfg.INPUT.FORMAT = "BGR"
# Just run these lines if you have the trained model in memory
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set the testing threshold for this model
# Build the model
model = build_model(cfg)
DetectionCheckpointer(model).load("output/model_final.pth")
model.eval()  # make sure it's in eval mode
image = cv2.imread("/kaggle/working/detectron2/images/73-ab1.jpg")
height, width = image.shape[:2]
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
image = ImageList.from_tensors([image])
with torch.no_grad():
    inputs = image
    outputs = model(inputs)
Unfortunately, I think this is wrong. Can someone enlighten me?
CodePudding user response:
See the Model Input Format for the builtin models.
Basically, the model in your code is not expecting an ImageList object, but a list of dicts, where each dict needs to provide specific information about one image, as explained in the documentation linked above.
So, your inference code needs to be corrected to the following.
image = cv2.imread("/kaggle/working/detectron2/images/73-ab1.jpg")
height, width = image.shape[:2]
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
inputs = [{"image": image, "height": height, "width": width}]
with torch.no_grad():
    outputs = model(inputs)
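If you have several images, the same list-of-dicts format extends to a batch. The helper below is a minimal sketch, not part of detectron2: `build_inputs` is a hypothetical name, and it assumes each image is an H x W x 3 array already in the channel order given by cfg.INPUT.FORMAT (BGR here, which is what cv2.imread returns). It does no resizing; the "height" and "width" entries only tell the model the resolution at which to report its outputs.

```python
import numpy as np
import torch


def build_inputs(images):
    """Turn H x W x C image arrays into a list of detectron2-style input dicts.

    `build_inputs` is a hypothetical helper, not a detectron2 function.
    """
    inputs = []
    for img in images:
        height, width = img.shape[:2]
        # HWC -> CHW float32 tensor, matching the single-image snippet above
        chw = np.ascontiguousarray(img.astype("float32").transpose(2, 0, 1))
        inputs.append({
            "image": torch.as_tensor(chw),
            "height": height,  # desired output height
            "width": width,    # desired output width
        })
    return inputs
```

With the model from the question, this would be called as `outputs = model(build_inputs([cv2.imread(p) for p in paths]))` inside the `torch.no_grad()` block, where `paths` is your own list of image files.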
You can also see this in the code: the forward method of the GeneralizedRCNN class.
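As for the masks you asked about: the model returns a list with one dict per input image, and for instance segmentation models the "instances" entry of each dict carries pred_boxes, scores, pred_classes and pred_masks. The following is a rough sketch of how they could be collected; `extract_masks` is a hypothetical helper, and it assumes the outputs come from a Mask R-CNN model such as the one configured in the question.

```python
def extract_masks(outputs, device="cpu"):
    """Collect masks, scores and boxes for each image from the model outputs.

    `extract_masks` is a hypothetical helper, not a detectron2 function.
    It assumes every output dict has an "instances" entry that exposes
    `pred_masks`, `scores` and `pred_boxes.tensor`.
    """
    results = []
    for out in outputs:
        # Move the predictions off the GPU before further processing
        instances = out["instances"].to(device)
        results.append({
            "masks": instances.pred_masks,         # one (H, W) mask per detection
            "scores": instances.scores,            # one confidence score per detection
            "boxes": instances.pred_boxes.tensor,  # one XYXY box per detection
        })
    return results
```

For the single image above, `extract_masks(outputs)[0]["masks"]` would then hold one mask per detected object, at the height and width given in the input dict.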