I get an error whilst trying to plot the graph.
My code is in this Colab notebook:
https://colab.research.google.com/drive/1nILxtGSSCmOKrcHg3-_SL0l2bsIwUM1p?usp=sharing
I think I'm missing a 'tensor.cpu()' call somewhere, but I can't pinpoint where. Everything else works :/ Can anyone help, please?
Answer:
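The problem is that torch.sum(...) returns a tensor, not a plain number, and that tensor lives on the same device as its inputs (the GPU, in your case). Adding it to running_corrects/val_running_corrects therefore turns them into CUDA tensors, which can't be plotted directly. In the lines where you add to running_corrects/val_running_corrects, change torch.sum(...) to torch.sum(...).item(), which returns a plain Python number, and then update the rest of the code accordingly. There are probably other ways to do it, but this should work. Here is a very similar question, by the way: https://stackoverflow.com/a/72704295/14787618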
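A minimal sketch of the .item() fix, assuming a classification loop that counts matches between predicted and true labels. The notebook's actual loop isn't shown here, so epoch_accuracy, preds, labels and train_acc_history are hypothetical stand-ins; only running_corrects and torch.sum(...).item() come from the answer above.

```python
import torch

# Use the GPU when available (as in Colab); the sketch also runs on CPU only.
device = "cuda" if torch.cuda.is_available() else "cpu"


def epoch_accuracy(preds, labels):
    """Hypothetical helper: accuracy over a list of (preds, labels) batches.

    torch.sum(...) gives a 0-dim tensor on the same device as its inputs;
    .item() copies that single value to the host as a Python int, so
    running_corrects stays a plain number instead of becoming a CUDA tensor.
    """
    running_corrects = 0
    for batch_preds, batch_labels in zip(preds, labels):
        running_corrects += torch.sum(batch_preds == batch_labels).item()
    total = sum(len(batch) for batch in labels)
    return running_corrects / total


# Hypothetical data: two batches of four predictions each.
preds = [torch.tensor([0, 1, 1, 0], device=device),
         torch.tensor([1, 1, 0, 0], device=device)]
labels = [torch.tensor([0, 1, 0, 0], device=device),
          torch.tensor([1, 0, 0, 0], device=device)]

train_acc_history = []
train_acc_history.append(epoch_accuracy(preds, labels))
print(train_acc_history)  # [0.75]
```

Because train_acc_history now holds ordinary Python floats, it can be passed straight to plt.plot(...). The alternative is to keep the tensors and call .cpu() on them before plotting, but converting with .item() at the point of accumulation means nothing downstream has to care about devices.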