I'm trying to understand why I cannot directly overwrite the weights of a torch layer. Consider the following example:
import torch
from torch import nn
net = nn.Linear(3, 1)
weights = torch.zeros(1,3)
# Overwriting does not work
net.state_dict()["weight"] = weights # nothing happens
print(f"{net.state_dict()['weight']=}")
# But mutating does work
net.state_dict()["weight"][0] = weights # indexing works
print(f"{net.state_dict()['weight']=}")
#########
# output:
# net.state_dict()['weight']=tensor([[ 0.5464, -0.4110, -0.1063]])
# net.state_dict()['weight']=tensor([[0., 0., 0.]])
I'm confused since state_dict()["weight"] is just a torch tensor, so I feel I'm missing something really obvious here.
CodePudding user response:
I don't have torch installed right now, but try something like this from some saved code I have. I believe you need to make deep copies, like so:
import copy

def zero_injection(initial_weights, trained_weights, mask):
    '''Zeros all initial weights, then injects the trained values selected by the mask.'''
    # deep-copy both state dicts so the original modules are left untouched
    initial_weights_copy = copy.deepcopy(initial_weights.state_dict())
    trained_weights_copy = copy.deepcopy(trained_weights.state_dict())
    # set all the values to zero
    for key in initial_weights_copy:
        initial_weights_copy[key][initial_weights_copy[key] < 0] = 0
        initial_weights_copy[key][initial_weights_copy[key] > 0] = 0
    state_dict = {}
    for key in initial_weights_copy:
        # where the mask is True keep the zeroed initial value,
        # where it is False take the trained value
        state_dict[key] = initial_weights_copy[key].cuda().where(
            mask[key].cuda(), trained_weights_copy[key].cuda()
        )
    return state_dict
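A toy call might look like this (the two nets and the mask here are made up for illustration, and a GPU is assumed because of the .cuda() calls above):
initial_net = nn.Linear(3, 1)  # freshly initialised weights
trained_net = nn.Linear(3, 1)  # stands in for a trained model
# hypothetical boolean mask: True = keep the zeroed initial value
mask = {key: torch.rand_like(value) < 0.5
        for key, value in initial_net.state_dict().items()}
merged = zero_injection(initial_net, trained_net, mask)
initial_net.load_state_dict(merged)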
CodePudding user response:
This is because net.state_dict() first creates a fresh collections.OrderedDict, stores the module's weight tensor(s) in it, and returns that dict:
state_dict = net.state_dict()
print(type(state_dict)) # <class 'collections.OrderedDict'>
When you "overwrite" (it's in fact not an overwrite; it's assignment in python) this ordered dict, you reassign an int 0 to the key 'weights'
of this ordered dict. The data in that tensor is not modified, it's just not referred to by the ordered dict.
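The same rebinding behaviour can be seen with a plain dict and any object (a toy illustration, nothing torch-specific):
d = {"weight": [1.0, 2.0, 3.0]}
original = d["weight"]
d["weight"] = [0.0, 0.0, 0.0]  # rebinds the key to a new list
print(original)  # [1.0, 2.0, 3.0] -- the old object is untouched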
When you then check whether the tensor was modified with:
print(f"{net.state_dict()['weight']}")
a new ordered dict, different from the one you assigned to, is created, so you still see the unchanged tensor.
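You can verify this: each call builds a new dict, and (with the default keep_vars=False, which detaches the parameters) the tensors in it are new objects that still view the layer's underlying storage:
d1 = net.state_dict()
d2 = net.state_dict()
print(d1 is d2)  # False: a new dict on every call
print(d1["weight"] is d2["weight"])  # False: detach() creates new tensor objects
print(d1["weight"].data_ptr() == d2["weight"].data_ptr())  # True: same storage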
However, when you use indexing like this:
net.state_dict()["weight"][0] = weights # indexing works
then it's no longer an assignment to the ordered dict. Instead, the __setitem__ method of the tensor is called, which modifies the underlying memory in place. Other in-place tensor APIs such as copy_ can achieve the same result.
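For example, either of these should zero the layer's weights in place (the second is the more idiomatic way to overwrite a parameter):
# in-place copy through the state_dict view
net.state_dict()["weight"].copy_(weights)
# or write to the parameter directly
with torch.no_grad():
    net.weight.copy_(weights)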
A clear explanation of the difference between a = b and a[:] = b when a is a tensor/array can be found here: https://stackoverflow.com/a/68978622/11790637