Good evening,
I am a beginner with PyTorch Lightning and I am trying to implement a NN and plot its curves (loss and accuracy) on the various sets.
The code is this:
def training_step(self, train_batch, batch_idx):
    X, y = train_batch
    y_copy = y  # integer y, kept for the accuracy metric
    X = X.type(torch.float32)
    y = y.type(torch.float32)
    # forward pass
    y_pred = self.forward(X).squeeze()
    # accuracy
    accuracy = Accuracy()
    acc = accuracy(y_pred, y_copy)
    # compute loss
    loss = self.loss_fun(y_pred, y)
    self.log_dict({'train_loss': loss, 'train_accuracy': acc}, on_step=False, on_epoch=True, prog_bar=True, logger=True)
    return loss
def validation_step(self, validation_batch, batch_idx):
    X, y = validation_batch
    X = X.type(torch.float32)
    # forward pass
    y_pred = self.forward(X).squeeze()
    # compute metrics
    accuracy = Accuracy()
    acc = accuracy(y_pred, y)
    loss = self.loss_fun(y_pred, y)
    self.log_dict({'validation_loss': loss, 'validation_accuracy': acc}, on_step=False, on_epoch=True, prog_bar=True, logger=True)
    return loss
def test_step(self, test_batch, batch_idx):
    X, y = test_batch
    X = X.type(torch.float32)
    # forward pass
    y_pred = self.forward(X).squeeze()
    # compute metrics
    accuracy = Accuracy()
    acc = accuracy(y_pred, y)
    loss = self.loss_fun(y_pred, y)
    self.log_dict({'test_loss': loss, 'test_accuracy': acc}, on_epoch=True, prog_bar=True, logger=True)
    return loss
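For context, these three methods live inside my LightningModule. A simplified sketch of the surrounding class (the architecture and loss function below are placeholders, my real model differs):

import torch
from torch import nn
import pytorch_lightning as pl
from torchmetrics import Accuracy  # note: newer torchmetrics versions need Accuracy(task="binary")

class Net(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # placeholder architecture; the real model differs
        self.model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1))
        # placeholder loss for binary classification
        self.loss_fun = nn.BCEWithLogitsLoss()

    def forward(self, X):
        return self.model(X)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)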
After training the NN, I run this piece of code:
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
del metrics["step"]
metrics
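To plot the curves I then do roughly the following (this assumes the layout that Lightning's CSVLogger writes: one column per metric, an epoch column, and NaN on rows where a metric was not logged):

import matplotlib.pyplot as plt

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
for col, ax in [("train_loss", ax_loss), ("train_accuracy", ax_acc)]:
    vals = metrics[["epoch", col]].dropna()  # drop rows where this metric was not logged
    ax.plot(vals["epoch"], vals[col], label=col)
    ax.set_xlabel("epoch")
    ax.legend()
plt.show()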
It is fine that there is only one accuracy and one loss on the validation set, because I am performing a hold-out CV. On the test set, though, I noticed that the value test_accuracy=0.97 is the mean of the accuracies over the whole epoch. With that I can't see the intermediate values, so I can't plot any curve. Seeing them would also be useful later, when I do cross-validation with KFold.
Why is it taking the mean, and how can I see the intermediate results? For the training_step it works properly; I can't figure out why the logger doesn't produce the same output for the test_step.
Can someone help me, please?
CodePudding user response:
It looks like the test_step method is logging the metrics with on_epoch=True, which means the logged values are averaged over the entire epoch and logged only once per epoch. To log the metrics at each step instead, set on_step=True and on_epoch=False in the test_step method like this (note that setting on_epoch=False alone is not enough, since on_step also defaults to False in test_step):
self.log_dict({'test_loss': loss, 'test_accuracy': acc}, on_step=True, on_epoch=False, prog_bar=True, logger=True)
This logs the loss and accuracy for every test batch, so you can see the intermediate values and plot a curve from them.
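For completeness, a minimal sketch of the adjusted test_step (identical to yours apart from the logging flags):

def test_step(self, test_batch, batch_idx):
    X, y = test_batch
    X = X.type(torch.float32)
    y_pred = self.forward(X).squeeze()
    accuracy = Accuracy()
    acc = accuracy(y_pred, y)
    loss = self.loss_fun(y_pred, y)
    # on_step=True / on_epoch=False: one row per test batch instead of a single epoch-level average
    self.log_dict({'test_loss': loss, 'test_accuracy': acc},
                  on_step=True, on_epoch=False, prog_bar=True, logger=True)
    return loss

The per-step values then show up as separate rows in metrics.csv, keyed by the step column, so keep that column (rather than deleting it) if you want to plot against it.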