I have a tensor of size 10, with only two values: 0 and 1.
I want to plot a histogram of the tensor above, simply using matplotlib.pyplot.hist. This is my code:
```python
import torch
import matplotlib.pyplot as plt
t = torch.tensor([0., 1., 1., 1., 0., 1., 1., 0., 0., 0.])
print(t)
plt.hist(t, bins=2)
plt.show()
```
And the output:
Why are there so many values in the histogram? Where did the rest of the values come from? How can I plot a correct histogram for my tensor?
Answer:
The plt.hist(t, bins=2) function is not meant to work with tensors directly. For it to work properly, try passing t.numpy() or t.tolist() instead. As far as I could find, one way to compute a histogram with PyTorch is the torch.histc() function, and you can then plot the result with the plt.bar() function as follows:
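```python
import torch
import matplotlib.pyplot as plt

t = torch.tensor([0., 0., 1., 1., 0., 1., 1., 0., 0., 0.])
bins = 2
# Count the values in `bins` equal-width bins spanning [min, max]
hist = torch.histc(t, bins=bins, min=0, max=1)
x = range(bins)
# One bar per bin; convert the counts to NumPy for matplotlib
plt.bar(x, hist.numpy(), align='center')
plt.xticks(x)
plt.xlabel('Bins')
plt.ylabel('Count')
plt.show()
```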
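To see what torch.histc actually returns for the tensor used above, here is a small sketch that prints the counts. The torch.bincount cross-check is a different technique that I am adding only for comparison; it assumes the data contains nothing but small non-negative integers (here 0 and 1).

```python
import torch

# Same tensor as in the answer's code: six 0.0 values and four 1.0 values
t = torch.tensor([0., 0., 1., 1., 0., 1., 1., 0., 0., 0.])

# torch.histc splits [min, max] into `bins` equal-width bins: [0, 0.5) and [0.5, 1].
# The value 1.0 equals `max`, so it is counted in the last bin.
hist = torch.histc(t, bins=2, min=0, max=1)
print(hist)  # tensor([6., 4.])

# Alternative for integer-valued data only (not histc):
# torch.bincount counts the occurrences of each non-negative integer directly.
counts = torch.bincount(t.long())
print(counts)  # tensor([6, 4])
```

Both give six zeros in the first bin and four ones in the second, which is what the bar chart should show.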
Some sources for plotting a histogram can be seen here and here. I could not find the root cause of the problem, and if someone could explain it, that would be great, but as far as I am aware, this is the way to plot a tensor.
Note that I changed the tensor to have four 1.0 values and six 0.0 values, so that the two bars have visibly different heights.
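The simpler fix mentioned at the start of this answer, converting the tensor before calling plt.hist, can be sketched as follows with the question's original tensor. This is a minimal sketch assuming a reasonably recent matplotlib and PyTorch; t.tolist() should work the same way as t.numpy() here.

```python
import torch
import matplotlib.pyplot as plt

# The tensor from the question: five 0.0 values and five 1.0 values
t = torch.tensor([0., 1., 1., 1., 0., 1., 1., 0., 0., 0.])

# Pass a 1-D NumPy array instead of the tensor, so plt.hist
# treats the data as a single dataset
counts, edges, _ = plt.hist(t.numpy(), bins=2)
plt.xlabel('Value')
plt.ylabel('Count')
plt.show()

print(counts)  # [5. 5.]
```

plt.hist returns the per-bin counts and the bin edges along with the drawn patches, so you can check that exactly two bars were drawn, each with height 5.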