Imagine a dataframe like this:
i X Y Z label
1 23 45 23 0
2 56 67 24 0
3 34 87 25 0
4 43 78 26 0
5 45 45 37 1
6 34 98 38 1
7 23 45 39 1
8 34 76 40 1
9 54 87 41 1
I know how to visualize X, Y, Z using matplotlib, but I want to set the color of each point using the label column. For example, all rows labeled 0 should be green and all rows labeled 1 should be orange. I'm pretty new to Python, and it would be great if you could give an example implementation.
Thank you so much for your help.
CodePudding user response:
You can do this by assigning a color to each label and drawing one scatter plot per label, like below:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # registers the 3D projection on older Matplotlib versions

# Random demo data: 100 rows labeled 0 followed by 100 rows labeled 1
np.random.seed(365)
df = pd.DataFrame({
    'X': np.random.rand(200),
    'Y': np.random.rand(200),
    'Z': np.random.rand(200),
    'label': np.hstack((np.zeros(100), np.ones(100)))})

fig = plt.figure(figsize=(12, 7))
ax = fig.add_subplot(projection='3d')

# Label-to-color mapping from the question: 0 is green, 1 is orange
colors = {0: 'green', 1: 'orange'}

# Draw one scatter call per label so each group gets its own color
for lab in df['label'].unique():
    mask = df['label'] == lab
    ax.scatter(xs=df.loc[mask, 'X'],
               ys=df.loc[mask, 'Y'],
               zs=df.loc[mask, 'Z'],
               color=colors[lab])

plt.show()
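As an alternative to looping over the labels, here is a minimal sketch that builds one color per row and passes the whole list to a single scatter call through the `c` argument. It uses the sample rows from the question. The names `colors` and `point_colors` are just illustrative, and the axis labels are an addition that was not part of the original request.

```python
import matplotlib.pyplot as plt
import pandas as pd

# The sample rows from the question
df = pd.DataFrame({
    'X': [23, 56, 34, 43, 45, 34, 23, 34, 54],
    'Y': [45, 67, 87, 78, 45, 98, 45, 76, 87],
    'Z': [23, 24, 25, 26, 37, 38, 39, 40, 41],
    'label': [0, 0, 0, 0, 1, 1, 1, 1, 1],
})

# Translate every label into a color name, giving one color per row
colors = {0: 'green', 1: 'orange'}
point_colors = df['label'].map(colors).tolist()

fig = plt.figure(figsize=(12, 7))
ax = fig.add_subplot(projection='3d')

# A single scatter call; c takes a list with one color per point
ax.scatter(df['X'], df['Y'], df['Z'], c=point_colors)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

plt.show()
```

The loop version is handy if you later want a legend entry per label, since each scatter call can take its own `label=` argument. The single-call version is shorter when you only need the colors.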