I am trying to draw a surface chart with matplotlib, but I have some issues. Here is the code:
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure()
# 3D axes; the '3d' projection comes from mpl_toolkits.mplot3d.
ax = fig.add_subplot(111, projection='3d')
K=1024
M=1024*K
# Earlier 2x2 test data, kept commented out. Its Z is the top-left 2x2 corner
# of the full Z below, which suggests Z[i, j] pairs with (X[i], Y[j]) in the
# full data set (note the roles of X and Y are swapped in this old version).
#X = np.array([ 4*K, 8*K ])
#Y = np.array([ 1, 4 ])
#Z = np.array([[0.1925, 0.1848], [0.1807, 0.1966]])
# X has 10 entries, Y has 17 entries (4K .. 256M), Z is 10 rows x 17 columns.
X = np.array([1, 4, 7, 10, 13, 16, 19, 22, 25, 28])
Y = np.array([4*K, 8*K, 16*K, 32*K, 64*K, 128*K, 256*K, 512*K, 1*M, 2*M, 4*M, 8*M, 16*M, 32*M, 64*M, 128*M, 256*M])
Z= np.array([[0.1925, 0.1848, 0.1955, 0.3149, 0.3149, 0.3146, 0.3149, 0.7689, 0.9035, 0.9883, 1.1056, 1.158, 1.1766, 1.197, 1.1996, 1.1979, 1.204],
[0.1807, 0.1966, 0.2026, 0.3286, 0.3287, 0.3297, 0.3296, 0.7029, 0.8511, 0.9371, 1.075, 1.1416, 1.164, 1.1813, 1.1789, 1.1857, 1.1847],
[0.2685, 0.2594, 0.277, 0.3624, 0.3647, 0.3644, 0.3641, 0.7091, 0.8583, 0.9198, 1.0631, 1.1231, 1.1538, 1.1665, 1.173, 1.1775, 1.1765],
[0.3846, 0.3542, 0.3871, 0.4351, 0.4361, 0.4363, 0.4374, 0.7554, 0.8977, 0.9534, 1.0816, 1.1301, 1.1526, 1.1649, 1.1683, 1.1734, 1.1754],
[0.5504, 0.4978, 0.5304, 0.569, 0.5553, 0.5668, 0.5665, 0.836, 0.9877, 1.0203, 1.1456, 1.1987, 1.2169, 1.2277, 1.2277, 1.2333, 1.2383],
[0.7146, 0.6435, 0.6744, 0.6977, 0.6939, 0.7009, 0.7062, 0.9823, 1.1053, 1.1281, 1.2141, 1.2532, 1.2637, 1.2746, 1.2737, 1.2821, 1.3354],
[1.2495, 1.0737, 1.0957, 1.1029, 1.0825, 1.083, 1.0879, 1.2768, 1.4372, 1.5031, 1.5488, 1.5618, 1.5702, 1.5699, 1.5753, 1.5829, 1.646],
[1.988, 1.6089, 1.4796, 1.4147, 1.3781, 1.3613, 1.3468, 1.5427, 1.6185, 1.7221, 1.7601, 1.796, 1.8099, 1.8121, 1.8164, 1.8309, 1.8226],
[2.8714, 2.7638, 2.7661, 2.1941, 1.9076, 1.7578, 1.6938, 1.837, 1.9554, 2.0082, 2.0058, 2.0111, 2.0065, 2.0083, 2.009, 1.9993, 2.008],
[3.1903, 3.0468, 3.0542, 3.0647, 3.0645, 3.0871, 2.5002, 2.4192, 2.2776, 2.221, 2.2625, 2.2613, 2.2489, 2.2468, 2.2723, 2.268, 2.2883]])
# Plot a basic wireframe.
# NOTE(review): these 17 labels correspond to the 17 entries of Y, not the
# 10 entries of X -- this was probably meant to be set_yticks(Y, ...).
ax.set_xticks(X, ['4K', '8K', '16K', '32K', '64K', '128K', '256K', '512K', '1M', '2M', '4M', '8M', '16M', '32M', '64M', '128M', '256M'])
# plot_wireframe broadcasts X, Y and Z against each other. 1D X (10,) and
# 1D Y (17,) cannot be broadcast together, which is the ValueError quoted in
# the traceback; X and Y need to be 2D grids with Z's (10, 17) shape
# (e.g. built with np.meshgrid).
ax.plot_wireframe(X, Y, Z)
plt.show()
But I get the following error, and I don't know how to proceed from here:
Traceback (most recent call last):
File "wire3d.py", line 40, in <module>
ax.plot_wireframe(X, Y, Z)
File "/usr/local/lib/python3.8/dist-packages/matplotlib/_api/deprecation.py", line 415, in wrapper
return func(*inner_args, **inner_kwargs)
File "/usr/local/lib/python3.8/dist-packages/mpl_toolkits/mplot3d/axes3d.py", line 1836, in plot_wireframe
X, Y, Z = np.broadcast_arrays(X, Y, Z)
File "<__array_function__ internals>", line 180, in broadcast_arrays
File "/usr/local/lib/python3.8/dist-packages/numpy/lib/stride_tricks.py", line 540, in broadcast_arrays
shape = _broadcast_shape(*args)
File "/usr/local/lib/python3.8/dist-packages/numpy/lib/stride_tricks.py", line 422, in _broadcast_shape
b = np.broadcast(*args[:32])
ValueError: shape mismatch: objects cannot be broadcast to a single shape. Mismatch is between arg 0 with shape (10,) and arg 1 with shape (17,).
Help please!
Answer:
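Based on your data, it looks like you meant to call `set_yticks(Y, ...)` instead of `set_xticks(X, ...)`: the 17 labels correspond to the 17 values in `Y`, not the 10 values in `X`. (The `plot_wireframe` error itself additionally requires 2D `X` and `Y` arrays, e.g. built with `np.meshgrid`.)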
Answer:
All the arrays you pass to `plot_wireframe` must have the same 2D shape. Use `np.meshgrid()` to build them, and keep the original 1D vectors of X and Y for labelling:
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

K = 1024
M = 1024 * K

# 1D sample coordinates: X_v indexes the rows of Z, Y_v the columns.
X_v = np.array([1, 4, 7, 10, 13, 16, 19, 22, 25, 28])
Y_v = np.array([4*K, 8*K, 16*K, 32*K, 64*K, 128*K, 256*K, 512*K, 1*M, 2*M, 4*M, 8*M, 16*M, 32*M, 64*M, 128*M, 256*M])
Z = np.array([[0.1925, 0.1848, 0.1955, 0.3149, 0.3149, 0.3146, 0.3149, 0.7689, 0.9035, 0.9883, 1.1056, 1.158, 1.1766, 1.197, 1.1996, 1.1979, 1.204],
[0.1807, 0.1966, 0.2026, 0.3286, 0.3287, 0.3297, 0.3296, 0.7029, 0.8511, 0.9371, 1.075, 1.1416, 1.164, 1.1813, 1.1789, 1.1857, 1.1847],
[0.2685, 0.2594, 0.277, 0.3624, 0.3647, 0.3644, 0.3641, 0.7091, 0.8583, 0.9198, 1.0631, 1.1231, 1.1538, 1.1665, 1.173, 1.1775, 1.1765],
[0.3846, 0.3542, 0.3871, 0.4351, 0.4361, 0.4363, 0.4374, 0.7554, 0.8977, 0.9534, 1.0816, 1.1301, 1.1526, 1.1649, 1.1683, 1.1734, 1.1754],
[0.5504, 0.4978, 0.5304, 0.569, 0.5553, 0.5668, 0.5665, 0.836, 0.9877, 1.0203, 1.1456, 1.1987, 1.2169, 1.2277, 1.2277, 1.2333, 1.2383],
[0.7146, 0.6435, 0.6744, 0.6977, 0.6939, 0.7009, 0.7062, 0.9823, 1.1053, 1.1281, 1.2141, 1.2532, 1.2637, 1.2746, 1.2737, 1.2821, 1.3354],
[1.2495, 1.0737, 1.0957, 1.1029, 1.0825, 1.083, 1.0879, 1.2768, 1.4372, 1.5031, 1.5488, 1.5618, 1.5702, 1.5699, 1.5753, 1.5829, 1.646],
[1.988, 1.6089, 1.4796, 1.4147, 1.3781, 1.3613, 1.3468, 1.5427, 1.6185, 1.7221, 1.7601, 1.796, 1.8099, 1.8121, 1.8164, 1.8309, 1.8226],
[2.8714, 2.7638, 2.7661, 2.1941, 1.9076, 1.7578, 1.6938, 1.837, 1.9554, 2.0082, 2.0058, 2.0111, 2.0065, 2.0083, 2.009, 1.9993, 2.008],
[3.1903, 3.0468, 3.0542, 3.0647, 3.0645, 3.0871, 2.5002, 2.4192, 2.2776, 2.221, 2.2625, 2.2613, 2.2489, 2.2468, 2.2723, 2.268, 2.2883]])

# Y_v spans 4K..256M, a factor of 2**16. On a linear axis the first eight
# columns (4K..512K) all fall within ~0.2% of the axis range, so their grid
# lines and tick labels pile up on top of each other. Place the Y samples at
# log2(size) instead and relabel the ticks with the real sizes. (3D axes have
# historically not supported set_yscale('log') well, hence the manual mapping.)
Y_pos = np.log2(Y_v)

# Derive the labels from the data so tick positions and labels can never
# disagree in length (the mismatch that caused the original problem).
y_labels = [f"{v // M}M" if v >= M else f"{v // K}K" for v in Y_v]

# plot_wireframe needs 2D X, Y, Z of identical shape. indexing='ij' makes both
# grids (len(X_v), len(Y_v)), i.e. the same row/column layout as Z.
X, Y = np.meshgrid(X_v, Y_pos, indexing='ij')
if Z.shape != X.shape:
    raise ValueError(
        f"Z has shape {Z.shape}, expected {X.shape} = (len(X_v), len(Y_v))"
    )

# Plot a basic wireframe.
# set_yticks + set_yticklabels (instead of set_yticks(ticks, labels)) also
# works on matplotlib < 3.5, where the 2nd positional argument is `minor`.
ax.set_yticks(Y_pos)
ax.set_yticklabels(y_labels)
ax.plot_wireframe(X, Y, Z)
plt.show()