Pytorch: '<=' not supported between instances of 'float' and 'function'

Time:10-06

I am trying to calculate the Intersection over Union (IoU) score. Here's my implementation, which runs without errors on its own:

def IoU(predict: torch.Tensor, target: torch.Tensor):

    i = (predict & target).float().sum()
    u = (predict | target).float().sum()
    x = i/u
    IOU = x.item()

    return IoU

But when I run my unit test:

def test_IoU1():
    pred = torch.tensor([[1, 0], [1, 0]])
    target = torch.tensor([[1, 0], [1, 1]])
    
    iou = IoU(pred,target)
    
    assert 0.66 <= iou
    assert iou <= 2/3

I get:

 TypeError: '<=' not supported between instances of 'float' and 'function'

How do I fix this without changing anything in the unit test? Thank you.

Answer:

In this function:

def IoU(predict: torch.Tensor, target: torch.Tensor):

    i = (predict & target).float().sum()
    u = (predict | target).float().sum()
    x = i/u
    IOU = x.item()
    
    return IoU

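You are returning IoU, which is the name of the function itself, so the caller receives the function object and the test's 0.66 <= iou ends up comparing a float with a function. Python names are case-sensitive, so IoU and IOU are different names; I suppose you meant to return the local variable IOU. So the correct way would be: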

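import torch

def IoU(predict: torch.Tensor, target: torch.Tensor):

    i = (predict & target).float().sum()  # intersection: elements set in both
    u = (predict | target).float().sum()  # union: elements set in either
    x = i/u
    IOU = x.item()  # 0-dim tensor -> Python float
    
    return IOU  # the score, not the function object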
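Even with the return statement fixed, the second assert can still fail. For the test tensors the intersection is 2 and the union is 3, so IoU = 2/3. With .float() that division happens in float32, where 2/3 rounds up to 0.6666666865348816. That is slightly larger than Python's 2/3 (0.6666666666666666), so iou <= 2/3 evaluates to False. Below is a sketch that leaves the unit test untouched. It assumes you are fine with counting in float64 via .double() instead of .float(); that swap is my suggestion, not something from the question:

```python
import torch

def IoU(predict: torch.Tensor, target: torch.Tensor) -> float:
    # Assumption: count in float64 (.double()) rather than float32 (.float())
    # so the ratio is the same double-precision value Python computes for 2/3.
    i = (predict & target).double().sum()  # intersection
    u = (predict | target).double().sum()  # union
    return (i / u).item()                  # Python float

# The original unit test, unchanged.
def test_IoU1():
    pred = torch.tensor([[1, 0], [1, 0]])
    target = torch.tensor([[1, 0], [1, 1]])

    iou = IoU(pred, target)

    assert 0.66 <= iou
    assert iou <= 2/3

test_IoU1()  # passes with the float64 version

# For comparison: the same ratio computed in float32 lands just above 2/3.
pred = torch.tensor([[1, 0], [1, 0]])
target = torch.tensor([[1, 0], [1, 1]])
iou32 = ((pred & target).float().sum() / (pred | target).float().sum()).item()
print(iou32 > 2/3)  # True
```

In float64, dividing 2 by 3 gives the same correctly rounded double that Python produces for 2/3, so the two values compare equal and the upper-bound assert holds. If the test were not fixed, comparing with a tolerance (for example math.isclose) would be the more usual approach.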