Result type cast error when doing calculations with Pytorch model parameters

Time:03-12

When I ran the code below:

import torchvision

model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
    params[var] *= 0.1

a RuntimeError was reported:

RuntimeError: result type Float can't be cast to the desired output type Long

But when I changed params[var] *= 0.1 to params[var] = params[var] * 0.1, the error disappeared.

Why would this happen?

I thought params[var] *= 0.1 had the same effect as params[var] = params[var] * 0.1.

CodePudding user response:

First, look at the first Long-typed entry in densenet201's state dict: you will find features.norm0.num_batches_tracked, which records how many mini-batches have been seen during training and is used to compute the running mean and variance when the model contains BatchNorm layers. This entry is a Long tensor and cannot become a float, because it behaves like a counter.
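You can confirm this with a bare BatchNorm layer instead of the full densenet201 (a minimal sketch; any module containing BatchNorm behaves the same way):

```python
import torch

# Every BatchNorm layer keeps an integer buffer, num_batches_tracked,
# counting how many mini-batches it has processed.
bn = torch.nn.BatchNorm2d(num_features=3)
state = bn.state_dict()

print(state['num_batches_tracked'].dtype)  # torch.int64 (i.e. Long)
print(state['weight'].dtype)               # torch.float32
```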

Second, in PyTorch, there are two types of operations:

  • Non-in-place operations: the result of the calculation is assigned to a new tensor, e.g. x = x + 1 or x = x / 2. The memory location of x after the assignment differs from the one before, because a new tensor was created.
  • In-place operations: the calculation is applied directly to the original tensor's storage, without creating a copy, e.g. x += 1 or x /= 2.
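The same distinction reproduces the error with a plain Long tensor, independent of any model:

```python
import torch

x = torch.tensor([5])   # integer literal -> LongTensor (torch.int64)

y = x * 0.1             # non-in-place: type promotion creates a new float tensor
print(y.dtype)          # torch.float32

try:
    x *= 0.1            # in-place: a Float result cannot be written into Long storage
except RuntimeError as e:
    print(e)            # result type Float can't be cast to the desired output type Long
```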

Let's move to your example to understand what happened:

  1. Non-in-place operation:

    model = torchvision.models.densenet201(num_classes=10)
    params = model.state_dict()
    name = 'features.norm0.num_batches_tracked'
    
    print(id(params[name]))  # 140247785908560
    params[name] = params[name] + 0.1
    print(id(params[name]))  # 140247785908368  
    print(params[name].type()) # changed to torch.FloatTensor
    
  2. In-place operation:

    print(id(params[name]))  # 140247785908560
    params[name] += 1
    print(id(params[name]))  # 140247785908560 
    print(params[name].type()) # still torch.LongTensor
    
    params[name] *= 0.1     # trying to change the original tensor's type to float in place raises the error
    
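If the goal is to scale only the learnable parameters, one workaround (a sketch, assuming the integer counters such as num_batches_tracked should be left untouched; shown on a toy model rather than densenet201) is to skip non-floating-point entries:

```python
import torch

# A toy model whose state_dict contains both float tensors and the
# Long counter num_batches_tracked, like densenet201 does.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
params = model.state_dict()

for name in params:
    if params[name].is_floating_point():
        params[name] *= 0.1   # in-place is fine on float tensors
    # integer buffers (counters) are skipped, avoiding the cast error
```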

Finally, some remarks:

  • In-place operations save some memory, but can be problematic when computing derivatives because of an immediate loss of history; hence, their use is discouraged.
  • You should be cautious when you decide to use in-place operations since they overwrite the original content.
  • If you use pandas, this is a bit similar to the inplace=True in pandas :).

To read more about in-place operations, see the PyTorch documentation and the related forum discussion.
