I'm trying to implement the Learner object and its steps, and I'm facing an issue with the loss.backward() function: it raises an AttributeError: 'NoneType' object has no attribute 'data'.
The entire process works when I follow Chapter 04, MNIST Basics. However, implementing it within a class raises this error. Could anybody explain why this occurs and how to fix it?
Here's the code:
class Basic_Optim:
    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr

    def step(self):
        for p in self.params:
            p.data -= self.lr * p.grad.data

    def zero(self):
        for p in self.params:
            p.grad = None

class Learner_self:
    def __init__(self, train, valid, model, loss, metric, params, lr):
        self.x = train
        self.y = valid
        self.model = model
        self.loss = loss
        self.metric = metric
        self.opt_func = Basic_Optim(params, lr)

    def fit(self, epochs):
        for epoch in range(epochs):
            self.train_data()
            score = self.valid_data()
            print(score, end = ' | ')

    def train_data(self):
        for x, y in self.x:
            preds = self.model(x)
            loss = self.loss(preds, y)
            loss_b = loss.backward()
            print(f'Loss: {loss:.4f}, Loss Backward: {loss_b}')
            self.opt_func.step()
            self.opt_func.zero()

    def valid_data(self):
        accuracy = [self.metric(xb, yb) for xb, yb in self.y]
        return round(torch.stack(accuracy).mean().item(), 4)

learn = Learner_self(dl, valid_dl, simple_net, mnist_loss, metric=batch_accuracy,
                     params=linear_model.parameters(), lr = 1)
learn.fit(10)
The print statement inside train_data outputs:
Loss: 0.0516, Loss Backward: None
It then raises the AttributeError shared above.
Please let me know if you need any more details. Every other function, such as mnist_loss, batch_accuracy, and simple_net, is exactly the same as in the book.
Thank you in advance.
CodePudding user response:
It seems that your optimizer and your trainer do not work on the same model.
You have model=simple_net, while the parameters for the optimizer are those of a different model: params=linear_model.parameters(). Because linear_model never takes part in the forward pass, loss.backward() leaves the .grad of its parameters as None, and p.grad.data in step() then fails.
Try passing params=simple_net.parameters(); that is, make sure the trainer's params are those of model.
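To see the mismatch in isolation, here is a minimal sketch. It is not the book's MNIST setup: the layer sizes, the random data, the squared-error loss, and the learning rate are all made up for illustration, and only the names linear_model and simple_net are borrowed from the question.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two models in the question (sizes are assumptions).
linear_model = nn.Linear(4, 1)
simple_net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Made-up batch of inputs and targets.
x = torch.randn(16, 4)
y = torch.randn(16, 1)

# The forward pass uses simple_net only, so only its parameters are in the graph.
loss = ((simple_net(x) - y) ** 2).mean()
loss.backward()

# linear_model was never used, so backward() did not populate its gradients.
unused_grads = [p.grad for p in linear_model.parameters()]
# simple_net produced the predictions, so each of its parameters has a gradient.
used_grads = [p.grad for p in simple_net.parameters()]
print(all(g is None for g in unused_grads))
print(all(g is not None for g in used_grads))

lr = 0.1

# Stepping over the wrong parameters reproduces the error from the question.
error_message = ""
try:
    for p in linear_model.parameters():
        p.data -= lr * p.grad.data  # p.grad is None here
except AttributeError as exc:
    error_message = str(exc)
print(error_message)

# Stepping over the parameters of the model that was actually used works.
with torch.no_grad():
    for p in simple_net.parameters():
        p -= lr * p.grad
        p.grad = None  # reset for the next batch, as zero() does in the question
```

Note that loss.backward() always returns None, so the "Loss Backward: None" in the printout is expected and is not related to the error. One way to rule out this kind of mismatch would be to drop the params argument and have the learner build its optimizer from self.model.parameters().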