Matplot histogram of Pytorch Tensor

Time:03-14

I have a tensor of size 10 with only 2 distinct values: 0 and 1.

I want to plot a histogram of the tensor above, simply using matplotlib.pyplot.hist. This is my code:

```python
import torch
import matplotlib.pyplot as plt
t = torch.tensor([0., 1., 1., 1., 0., 1., 1., 0., 0., 0.])
print(t)
plt.hist(t, bins=2)
plt.show()
```

And the output:

[Image: the resulting histogram]

Why are there so many values in the histogram? Where did the rest of the values come from? How can I plot a correct histogram for my tensor?

Answer:

plt.hist(t, bins=2) is not meant to work with PyTorch tensors directly. For it to work properly, pass t.numpy() or t.tolist() instead. As far as I can tell, the way to compute a histogram with PyTorch is the torch.histc() function, and you can then plot the counts with plt.bar() as follows:

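```python
import torch
import matplotlib.pyplot as plt

t = torch.tensor([0., 0., 1., 1., 0., 1., 1., 0., 0., 0.])

bins = 2
# Count the values in `bins` equal-width bins spanning [min, max]
hist = torch.histc(t, bins=bins, min=0, max=1)

x = range(bins)
plt.bar(x, hist.numpy(), align='center')  # give Matplotlib a NumPy array, not a tensor
plt.xticks(x)
plt.xlabel('Bins')
plt.ylabel('Count')
plt.show()
```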
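
The simpler route mentioned above, converting the tensor before calling plt.hist, can be sketched like this. It is a minimal example assuming a CPU tensor of 0.0/1.0 values like the one above; the Agg backend line and the file name tensor_hist.png are only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it to get a window
import matplotlib.pyplot as plt
import torch

t = torch.tensor([0., 0., 1., 1., 0., 1., 1., 0., 0., 0.])

# Convert to NumPy first; detach().cpu() also covers tensors that
# require grad or live on a GPU.
values = t.detach().cpu().numpy()

# A plain 1-D array is treated as a single dataset of 10 values,
# so bins=2 gives one bar for the 0s and one for the 1s.
counts, edges, patches = plt.hist(values, bins=2)
print(counts)  # [6. 4.]
plt.xlabel("Value")
plt.ylabel("Count")
plt.savefig("tensor_hist.png")  # use plt.show() in an interactive session

# t.tolist() (a list of Python floats) works the same way.
plt.figure()
counts_from_list, _, _ = plt.hist(t.tolist(), bins=2)
```

The counts match what torch.histc(t, bins=2, min=0, max=1) returns, so either plt.hist on converted data or plt.bar on histc output should give the same picture.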

Some sources for plotting a histogram can be seen here and here. I could not find the root cause of this, so if someone can explain it, that would be great; but as far as I am aware, this is the way to plot a tensor.
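
One plausible explanation for the extra bars, which seems to depend on the Matplotlib version, is that plt.hist iterates over the tensor, gets 0-d tensors back, and, because torch.Tensor defines __iter__, concludes that each element is itself a sequence. It would then draw ten one-value datasets in ten colors instead of one dataset of ten values. The sketch below does not call Matplotlib's internals; it only checks those two tensor facts and reproduces the ten-dataset call explicitly, so treat it as a hypothesis check rather than a confirmed diagnosis.

```python
import collections.abc

import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import torch

t = torch.tensor([0., 1., 1., 1., 0., 1., 1., 0., 0., 0.])

# Iterating a 1-D tensor yields 0-d tensors, not Python floats.
first = next(iter(t))
print(type(first), first.dim())  # <class 'torch.Tensor'> 0

# torch.Tensor defines __iter__, so even a 0-d tensor passes an
# isinstance(..., Iterable) check, although actually iterating it fails.
print(isinstance(first, collections.abc.Iterable))  # True
try:
    iter(first)
except TypeError as err:
    print("TypeError:", err)

# If hist treats each element as its own dataset, the call behaves like
# passing ten one-element arrays: ten datasets, one color each.
per_element = [x.reshape(1).numpy() for x in t]
counts, edges, patches = plt.hist(per_element, bins=2)
print(len(counts))  # 10
```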

Note: I changed the tensor to contain four 1.0 values and six 0.0 values so that the difference is visible.
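
Because the tensor only holds the integers 0 and 1, the 6/4 split can also be checked with torch.bincount, a different function from torch.histc that counts occurrences of each non-negative integer. This is a minimal sketch assuming the values are whole numbers that can be safely cast with .long(); the output file name is arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import torch

t = torch.tensor([0., 0., 1., 1., 0., 1., 1., 0., 0., 0.])

# bincount needs non-negative integers, so cast the 0.0/1.0 floats first.
counts = torch.bincount(t.long())
print(counts)  # tensor([6, 4])

# One bar per distinct value, labelled with the value itself.
labels = [str(v) for v in range(len(counts))]
plt.bar(labels, counts.numpy())
plt.xlabel("Value")
plt.ylabel("Count")
plt.savefig("tensor_bincount.png")  # use plt.show() in an interactive session
```

For non-integer data, torch.histc or converting to NumPy remains the more general route.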
