I have a 24k x 10 weight matrix 'W' that I want to plot like this:
import matplotlib.pyplot as plt
plt.imshow(W, cmap='summer', interpolation='nearest')
plt.title("Weights")
plt.show()
Here's what W looks like:
print(type(W))
print(W.shape)
<class 'numpy.ndarray'>
(24684, 10)
But the plotted image is entirely squished into a thin vertical strip.
How can I fix this? I'd like to stretch the image out so it is rectangular or square, and show the x-axis dimensions (even if there are only 10 of them).
Thanks in advance.
Answer:
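You can fix this by setting the image's aspect to 'auto'.
By default, imshow uses an aspect ratio of 1 (aspect='equal'), so every cell of the array is drawn as a square. With 24684 rows and only 10 columns, the image ends up about 2,468 times taller than it is wide, leaving almost nothing visible along the x-axis. With aspect='auto', the image is instead stretched to fill the axes.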
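To see why the default produces a thin strip, here is a minimal sketch that draws an array of the same shape with imshow's default aspect and inspects the resulting axes. The random W is a hypothetical stand-in for the real weight matrix, and the Agg backend is chosen only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for the question's W: random values, same shape.
W = np.random.default_rng(0).random((24684, 10))

# The aspect imshow falls back to when none is passed ('equal' by default).
print(matplotlib.rcParams["image.aspect"])

fig, ax = plt.subplots()
ax.imshow(W, cmap="summer", interpolation="nearest")  # default aspect

# With an aspect of 1 every cell is square, so the drawn image is
# (rows / columns) times taller than it is wide.
ratio = ax.get_data_ratio()
print(ax.get_aspect(), ratio)
```

The data ratio works out to 24684 / 10, which is why the axes box collapses into a sliver when the aspect is fixed at 1.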
So in your code, use:
plt.imshow(W, cmap='summer', interpolation='nearest', aspect='auto')
instead of:
plt.imshow(W, cmap='summer', interpolation='nearest')
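Putting it together, a minimal sketch of the full plot might look like the following. It again uses a hypothetical random W of the same shape; the figure size, axis labels, colorbar and the output filename weights.png are illustrative choices, not part of the original question. Setting one x tick per column addresses the request to show all 10 x-axis positions.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for the question's W: random values, same shape.
W = np.random.default_rng(0).random((24684, 10))

fig, ax = plt.subplots(figsize=(6, 6))  # illustrative, roughly square figure
# aspect="auto" lets the image stretch to fill the axes instead of
# forcing square cells.
im = ax.imshow(W, cmap="summer", interpolation="nearest", aspect="auto")
ax.set_title("Weights")
ax.set_xticks(np.arange(W.shape[1]))  # one tick per column: 0..9
ax.set_xlabel("column")
ax.set_ylabel("row")
fig.colorbar(im, ax=ax)  # shows which colour corresponds to which weight
fig.savefig("weights.png")  # or plt.show() with an interactive backend
```

Because aspect='auto' only changes how the array is mapped onto the axes, you can reshape the plot freely through figsize without changing the data.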