Properly format 3d subplots in python

Time:02-16

I have 8 plots (seven 3D plots and one 2D plot).

I want to arrange them in a 4 x 2 grid. Here is my code:

import matplotlib.pyplot as plt

sensor_data = self._util_sensor(sub_df)
fig = plt.figure()
fig.tight_layout()

ax = fig.add_subplot(1, 2, 1, projection = '3d')
ax.scatter(sensor_data['chest'][0], sensor_data['chest'][1], sensor_data['chest'][2])
ax.set_title('Chest')
            
ax = fig.add_subplot(1, 2, 2)
ax.scatter(sensor_data['ecg'][0], sensor_data['ecg'][1])
ax.set_title('ECG')
            
ax = fig.add_subplot(2, 2, 1, projection = '3d')
ax.scatter(sensor_data['left_accel'][0], sensor_data['left_accel'][1], sensor_data['left_accel'][2])
ax.set_title('left accel')
            
ax = fig.add_subplot(2, 2, 2, projection = '3d')
ax.scatter(sensor_data['left_gyro'][0], sensor_data['left_gyro'][1], sensor_data['left_gyro'][2])
ax.set_title('left gyro')
            
ax = fig.add_subplot(3, 2, 1, projection = '3d')
ax.scatter(sensor_data['left_mag'][0], sensor_data['left_mag'][1], sensor_data['left_mag'][2])
ax.set_title('left mag')
            
ax = fig.add_subplot(3, 2, 2, projection = '3d')
ax.scatter(sensor_data['right_accel'][0], sensor_data['right_accel'][1], sensor_data['right_accel'][2])
ax.set_title('right accel')

ax = fig.add_subplot(4, 2, 1, projection = '3d')
ax.scatter(sensor_data['right_gyro'][0], sensor_data['right_gyro'][1], sensor_data['right_gyro'][2])
ax.set_title('right gyro')

ax = fig.add_subplot(4, 2, 2, projection = '3d')
ax.scatter(sensor_data['right_mag'][0], sensor_data['right_mag'][1], sensor_data['right_mag'][2])
ax.set_title('right mag')
plt.show()

The result is the following image (not reproduced here). How do I format these properly?

Answer:

The problem is how you pass the rows, columns and index to add_subplot: every call should describe the same grid, and only the index should change. I hope this helps:

import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
    ax = fig.add_subplot(3, 2, i + 1, projection='3d')  # index runs 1..6
plt.show()

The first two numbers are the total number of rows and columns in the grid; the third number is the index of the subplot, running from 1 to 6 (= rows * columns), counted left to right and then top to bottom.
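As a quick check of the rows/columns/index convention, here is a small sketch that builds a 4 x 2 grid (the layout the question asks for) and reads back which cell matplotlib assigned to each index via `get_subplotspec()`. It assumes matplotlib 3.2 or newer (for the `rowspan`/`colspan` properties) and uses the non-interactive Agg backend so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; no window is opened
import matplotlib.pyplot as plt

nrows, ncols = 4, 2
fig = plt.figure()

# Map each 1-based subplot index to the 0-based (row, column) cell it occupies.
positions = {}
for index in range(1, nrows * ncols + 1):
    ax = fig.add_subplot(nrows, ncols, index)
    spec = ax.get_subplotspec()
    positions[index] = (spec.rowspan.start, spec.colspan.start)

for index, (row, col) in positions.items():
    print(f"index {index} -> row {row}, column {col}")
```

The indices fill the grid row by row: 1 and 2 are the first row, 3 and 4 the second, and so on down to 8 in the bottom-right cell.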

This is how it looks: [image: a 3 x 2 grid of empty 3D axes]
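To see why the code in the question produces overlapping plots, this sketch adds two of its calls, `add_subplot(1, 2, 1)` and `add_subplot(2, 2, 1)`, to one figure and checks whether their boxes intersect using `Axes.get_position()` and `Bbox.overlaps()`. For brevity it uses 2D axes; the cell for each call is computed from (rows, columns, index) in the same way when `projection='3d'` is passed. For contrast, two cells of one consistent 4 x 2 grid are checked as well.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; no window is opened
import matplotlib.pyplot as plt

# Two calls from the question: they describe different grids.
fig = plt.figure()
ax_a = fig.add_subplot(1, 2, 1)  # left half of a 1 x 2 grid
ax_b = fig.add_subplot(2, 2, 1)  # top-left cell of a 2 x 2 grid

# get_position() returns each axis's bounding box in figure coordinates.
mixed_overlap = ax_a.get_position().overlaps(ax_b.get_position())
print("mixed grids overlap:", mixed_overlap)

# Two neighbouring cells of one 4 x 2 grid, as in the fixed code.
fig2 = plt.figure()
ax_c = fig2.add_subplot(4, 2, 1)
ax_d = fig2.add_subplot(4, 2, 2)
same_grid_overlap = ax_c.get_position().overlaps(ax_d.get_position())
print("same grid overlap:", same_grid_overlap)
```

Because `fig.add_subplot` does not remove existing axes, each call with a different grid just stacks a new axis on top of an earlier one.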

To answer the question fully, this is what needs to be done:

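import matplotlib.pyplot as plt

# One 4 x 2 grid for all eight plots: the grid size never changes,
# only the index (1..8) does.
fig = plt.figure(figsize=(6, 12))
ax1 = fig.add_subplot(4, 2, 1, projection='3d')
ax2 = fig.add_subplot(4, 2, 2)
ax3 = fig.add_subplot(4, 2, 3, projection='3d')
ax4 = fig.add_subplot(4, 2, 4, projection='3d')
ax5 = fig.add_subplot(4, 2, 5, projection='3d')
ax6 = fig.add_subplot(4, 2, 6, projection='3d')
ax7 = fig.add_subplot(4, 2, 7, projection='3d')
ax8 = fig.add_subplot(4, 2, 8, projection='3d')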

You can then draw on each of these axes, e.g.:

ax2.plot([1, 3, 2, 7])
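Putting it together for the question's eight sensors, here is a sketch that builds the whole 4 x 2 figure in a loop. The `sensor_data` dictionary below is hypothetical stand-in data (random points), because the asker's `self._util_sensor(sub_df)` is not shown; it assumes each entry is indexable as `[0]`, `[1]`, `[2]` for x, y, z (only x and y for `'ecg'`), as in the question's code. `tight_layout()` is called after all subplots exist, since calling it before any subplot is added, as the question does, leaves it nothing to lay out.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; remove this line and call plt.show() to display
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers '3d' on older matplotlib

# Hypothetical stand-in for self._util_sensor(sub_df): random samples per sensor.
rng = np.random.default_rng(0)
keys_3d = ['chest', 'left_accel', 'left_gyro', 'left_mag',
           'right_accel', 'right_gyro', 'right_mag']
sensor_data = {key: rng.normal(size=(3, 50)) for key in keys_3d}
sensor_data['ecg'] = rng.normal(size=(2, 50))  # 2D sensor: x and y only

# (key, title) pairs in the order they fill the grid, row by row.
panels = [('chest', 'Chest'), ('ecg', 'ECG'),
          ('left_accel', 'left accel'), ('left_gyro', 'left gyro'),
          ('left_mag', 'left mag'), ('right_accel', 'right accel'),
          ('right_gyro', 'right gyro'), ('right_mag', 'right mag')]

nrows, ncols = 4, 2
fig = plt.figure(figsize=(8, 16))
axes = {}
for index, (key, title) in enumerate(panels, start=1):
    data = sensor_data[key]
    if len(data) == 3:
        # Same (4, 2) grid on every call; only the index changes.
        ax = fig.add_subplot(nrows, ncols, index, projection='3d')
        ax.scatter(data[0], data[1], data[2])
    else:
        ax = fig.add_subplot(nrows, ncols, index)
        ax.scatter(data[0], data[1])
    ax.set_title(title)
    axes[key] = ax

fig.tight_layout()  # lay out spacing once every subplot exists
```

Driving the layout from one `panels` list keeps the grid size in a single place, so changing to another arrangement only means editing `nrows` and `ncols`.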