Overwriting vs mutating PyTorch weights


I'm trying to understand why I cannot directly overwrite the weights of a torch layer. Consider the following example:

import torch
from torch import nn

net = nn.Linear(3, 1)
weights = torch.zeros(1, 3)

# Overwriting does not work
net.state_dict()["weight"] = weights  # nothing happens
print(f"{net.state_dict()['weight']=}")

# But mutating does work
net.state_dict()["weight"][0] = weights  # indexing works
print(f"{net.state_dict()['weight']=}")

# Output:
# net.state_dict()['weight']=tensor([[ 0.5464, -0.4110, -0.1063]])
# net.state_dict()['weight']=tensor([[0., 0., 0.]])

I'm confused since state_dict()["weight"] is just a torch tensor, so I feel I'm missing something really obvious here.

CodePudding user response:

I don't have torch installed right now, but try something like this from some saved code I have. I believe you need to make deep copies, like so:

import copy


def zero_injection(initial_weights, trained_weights, mask):
    '''Zero all weights, then inject the trained values where the mask is False.'''
    # deep-copy both state dicts so the original modules stay untouched
    initial_weights_copy = copy.deepcopy(initial_weights.state_dict())
    trained_weights_copy = copy.deepcopy(trained_weights.state_dict())

    # set every value to zero (in place, so the dict keeps the same tensors)
    for key in initial_weights_copy:
        initial_weights_copy[key].zero_()

    state_dict = {}
    for key in initial_weights_copy:
        # where the mask is True, keep the zeroed value;
        # where it is False, take the trained value instead
        state_dict[key] = initial_weights_copy[key].cuda().where(
            mask[key].cuda(), trained_weights_copy[key].cuda()
        )

    return state_dict
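
A hypothetical usage could look like this (assuming a CUDA device is available; the two Linear modules and the random mask are made up for illustration):

import torch
from torch import nn

initial = nn.Linear(3, 1)
trained = nn.Linear(3, 1)

# one boolean tensor per state_dict key: True keeps the zeroed value,
# False injects the corresponding trained value
mask = {k: torch.rand_like(v) > 0.5 for k, v in initial.state_dict().items()}

new_state = zero_injection(initial, trained, mask)
initial.load_state_dict(new_state)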

CodePudding user response:

This is because net.state_dict() first creates a collections.OrderedDict object, then stores the module's weight tensor(s) in it, and returns the dict:

state_dict = net.state_dict()
print(type(state_dict))    # <class 'collections.OrderedDict'>

When you "overwrite" (it's in fact not an overwrite; it's assignment in python) this ordered dict, you reassign an int 0 to the key 'weights' of this ordered dict. The data in that tensor is not modified, it's just not referred to by the ordered dict.

When you check whether the tensor is modified by:

print(f"{net.state_dict()['weight']}")

a new ordered dict different from the one you have modified is created, so you see the unchanged tensor.
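
You can check both points quickly: every call builds a fresh dict, yet the tensors inside still share storage with the module's parameters (a minimal sketch, reusing net from the question):

sd1 = net.state_dict()
sd2 = net.state_dict()
print(sd1 is sd2)  # False: a new OrderedDict is built on every call
print(sd1["weight"].data_ptr() == net.weight.data_ptr())  # True: same underlying storage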

However, when you use indexing like this:

net.state_dict()["weight"][0] = weights  # indexing works

then it is no longer an assignment to the ordered dict. Instead, the tensor's __setitem__ method is called, which modifies the underlying memory in place. Other in-place tensor APIs such as copy_ can achieve the same result.
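
So if the goal is to actually replace the layer's weights, two standard approaches look like this (a short sketch, reusing the names from the question):

import torch
from torch import nn

net = nn.Linear(3, 1)
weights = torch.zeros(1, 3)

# Option 1: copy into the parameter in place, outside of autograd
with torch.no_grad():
    net.weight.copy_(weights)

# Option 2: rebind the key in a state dict, then load it back;
# load_state_dict copies the data into the existing parameters
sd = net.state_dict()
sd["weight"] = weights
net.load_state_dict(sd)

print(net.weight)  # the parameter is now all zeros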

A clear explanation of the difference between a = b and a[:] = b when a is a tensor/array can be found here: https://stackoverflow.com/a/68978622/11790637
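
In short, the distinction drawn there looks like this (a minimal sketch):

import torch

a = torch.ones(3)
c = a                  # second reference to the same tensor
a = torch.zeros(3)     # rebinds the name 'a'; c still holds the ones
print(c)               # tensor([1., 1., 1.])

a = torch.ones(3)
c = a
a[:] = torch.zeros(3)  # writes into a's memory in place
print(c)               # tensor([0., 0., 0.])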
