When I ran the code below:
import torchvision
model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
    params[var] *= 0.1
a RuntimeError was reported:
RuntimeError: result type Float can't be cast to the desired output type Long
But when I changed params[var] *= 0.1 to params[var] = params[var] * 0.1, the error disappeared.
Why does this happen? I thought params[var] *= 0.1 had the same effect as params[var] = params[var] * 0.1.
CodePudding user response:
First, let's look at the first long-typed parameter in densenet201: you will find features.norm0.num_batches_tracked, which records how many mini-batches have been seen during training and is used when computing the running mean and variance of the BatchNormalization layers in the model. This parameter is a long (integer) tensor and cannot be a float, because it behaves like a counter.
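You can confirm this yourself. Below is a minimal sketch (the key name comes from densenet201's state_dict) that prints the dtype of that counter and lists the non-floating-point entries:
import torchvision

model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()

# The BatchNorm batch counter is stored as a 64-bit integer ("Long") tensor.
print(params['features.norm0.num_batches_tracked'].dtype)   # torch.int64

# Every BatchNorm layer contributes one such integer counter; these are the
# entries that break the in-place multiplication by 0.1.
long_keys = [k for k, v in params.items() if not v.is_floating_point()]
print(long_keys[:3])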
Second, in PyTorch, there are two types of operations:
- Non-inplace operations: the result of the computation is assigned to a new copy of the variable, e.g. x = x + 1 or x = x / 2. The memory location of x before the assignment is not the same as the memory location after it, because a new tensor has been created from the original variable.
- Inplace operations: the computation is applied directly to the original tensor, without making any copy, e.g. x += 1 or x /= 2 (see the minimal sketch right after this list).
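Before going back to densenet201, here is a minimal sketch of the same problem on a plain integer tensor, independent of any model:
import torch

x = torch.tensor(3)    # integer literal -> dtype torch.int64 (Long)

y = x * 0.1            # non-inplace: the result is a brand-new float tensor
print(y.dtype)         # torch.float32

x *= 0.1               # in-place: a float result cannot be stored back into a Long tensor
# RuntimeError: result type Float can't be cast to the desired output type Long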
Let's move to your example to understand what happened:
Non-inplace operation:
model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
name = 'features.norm0.num_batches_tracked'
print(id(params[name]))        # 140247785908560
params[name] = params[name] + 0.1
print(id(params[name]))        # 140247785908368 -> a different object
print(params[name].type())     # changed to torch.FloatTensor
Inplace operation:
print(id(params[name]))        # 140247785908560
params[name] += 1              # in-place: an integer result fits in the Long tensor
print(id(params[name]))        # 140247785908560 -> same object
print(params[name].type())     # still torch.LongTensor
params[name] += 0.1            # you want to change the original copy's type to float -> you get the error
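As a side note (this goes beyond the question, so take it as a sketch rather than the only fix): if the goal is simply to scale the learnable weights, you can skip the non-floating-point entries so the counters are left untouched:
for var in params:
    if params[var].is_floating_point():   # skips integer counters such as num_batches_tracked
        params[var] *= 0.1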
Finally, some remarks:
- In-place operations save some memory, but they can be problematic when computing derivatives because of the immediate loss of history; hence, their use is discouraged (source). See the small sketch after this list.
- You should be cautious when you decide to use in-place operations, since they overwrite the original content.
- If you use pandas, this is a bit similar to inplace=True in pandas :).
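To make the autograd remark above concrete, here is a small sketch; whether an in-place update is rejected depends on whether autograd still needs the overwritten value:
import torch

x = torch.ones(3, requires_grad=True)

# In-place update of a leaf tensor that requires grad is rejected immediately:
# x += 1  -> RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

# In-place update of an intermediate result whose value is saved for the
# backward pass fails later, when backward() is called:
y = x.exp()     # the backward of exp() reuses y itself
y *= 2          # overwrites the value autograd saved
y.sum().backward()
# RuntimeError: one of the variables needed for gradient computation has been
# modified by an inplace operation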
This source is a good resource to read more about in-place operations, and this discussion (source) is also worth reading.