I am trying to make a multicolor line plot using matplotlib. The color should change according to the value in a specific column of my dataframe:
time | v1 | v2 | state |
---|---|---|---|
0 | 3.5 | 8 | 0 |
1 | 3.8 | 8.5 | 0 |
2 | 4.2 | 9 | 1 |
3 | 5 | 12 | 0 |
4 | 8 | 10 | 2 |
My current code, which just displays the plot normally, without the colors:
cols=['v1','v2']
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(15, 15))
df.plot(x='time',y=cols,subplots=True, ax=axes)
plt.legend()
plt.xticks(rotation=45)
plt.show()
The result should look something like the 2nd graph, with the line changing color according to the state column, using 3 distinct colors (red, blue, green).
CodePudding user response:
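import matplotlib.pyplot as plt

# Draw one segment per pair of consecutive rows, colored by the state of the row it ends on
for state, prev, cur in zip(df['state'].iloc[1:], df.index[:-1], df.index[1:]):
    if state == 0:
        color = 'blue'
    elif state == 1:
        color = 'orange'
    else:
        color = 'green'
    # x: the two time values; y: v1 and v2 at both rows (one segment per column)
    plt.plot([df["time"][prev], df["time"][cur]], df.loc[[prev, cur], ['v1', 'v2']], c=color)
plt.xticks(rotation=45)
plt.show()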
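A variation on the loop above, as a minimal sketch rather than a definitive implementation: consecutive rows with the same state can be grouped into runs, so each run is drawn with a single `plot` call instead of one call per segment. The names `STATE_COLORS`, `state_runs` and `plot_by_state` are hypothetical helpers introduced here for illustration; the coloring rule (a segment takes the color of the row it ends on) is the same as in the loop above, and `ax` is assumed to be a matplotlib Axes.

```python
from itertools import groupby

# Hypothetical mapping; the state values 0, 1, 2 come from the question's table.
STATE_COLORS = {0: "blue", 1: "orange", 2: "green"}


def state_runs(states):
    """Return (state, start, stop) index ranges of consecutive equal states."""
    runs = []
    start = 0
    for state, group in groupby(states):
        length = len(list(group))
        runs.append((state, start, start + length))
        start += length
    return runs


def plot_by_state(ax, times, values, states, colors=STATE_COLORS):
    """Draw one line per run of equal states on the given Axes."""
    for state, start, stop in state_runs(states):
        # Start one point earlier so the run connects to the previous point.
        lo = max(start - 1, 0)
        ax.plot(times[lo:stop], values[lo:stop], color=colors.get(state, "black"))
```

With the question's dataframe this could be called once per subplot, for example `plot_by_state(axes[0], df["time"].to_numpy(), df["v1"].to_numpy(), df["state"].to_numpy())`, and likewise for `v2` on `axes[1]`.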
CodePudding user response:
If you want to avoid for loops:
from matplotlib.collections import LineCollection
# other imports...
df[["time_shift", "v1_shift", "v2_shift"]] = df.shift(-1)[["time", "v1", "v2"]]
df = df.dropna()
# build separate line segments: each row's point paired with the next row's point
lines1 = zip(df[["time", "v1"]].values, df[["time_shift", "v1_shift"]].values)
lines2 = zip(df[["time", "v2"]].values, df[["time_shift", "v2_shift"]].values)
# map "state" to RGB values, use black if mapping does not exist
color_map = {
    0: (0.8, 0.1, 0.1),
    1: (0.1, 0.8, 0.1),
    2: (0.1, 0.1, 0.8),
}
colors = df["state"].apply(lambda x: color_map.get(x, (0, 0, 0))).tolist()
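# include the shifted columns so the end point of the last segment stays inside the limits
xlim = (df[["time", "time_shift"]].to_numpy().min(), df[["time", "time_shift"]].to_numpy().max())
ylim1 = (df[["v1", "v1_shift"]].to_numpy().min(), df[["v1", "v1_shift"]].to_numpy().max())
ylim2 = (df[["v2", "v2_shift"]].to_numpy().min(), df[["v2", "v2_shift"]].to_numpy().max())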
fig, ax = plt.subplots(nrows=2, ncols=1)
ax[0].set_xlim(*xlim)
ax[0].set_ylim(*ylim1)
ax[0].add_collection(LineCollection(lines1, linestyle="solid", colors=colors))
ax[1].set_xlim(*xlim)
ax[1].set_ylim(*ylim2)
ax[1].add_collection(LineCollection(lines2, linestyle="solid", colors=colors))
plt.show()
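The same idea can be sketched without adding shifted columns to the dataframe. This is only an illustration under stated assumptions: `build_segments`, `segment_colors` and `add_colored_line` are hypothetical helper names, each segment takes the color of the state at its starting point (as in the code above), and `ax.autoscale()` is used instead of setting the limits by hand.

```python
def build_segments(times, values):
    """Pair each (time, value) point with its successor."""
    points = list(zip(times, values))
    return list(zip(points[:-1], points[1:]))


def segment_colors(states, color_map, default=(0, 0, 0)):
    """One color per segment, taken from the state at the segment's start."""
    return [color_map.get(s, default) for s in list(states)[:-1]]


def add_colored_line(ax, times, values, states, color_map):
    """Add a multicolor line to a matplotlib Axes and rescale the view."""
    # Imported here so the two helpers above work without matplotlib.
    from matplotlib.collections import LineCollection

    lc = LineCollection(
        build_segments(times, values),
        colors=segment_colors(states, color_map),
        linestyle="solid",
    )
    ax.add_collection(lc)
    ax.autoscale()  # fit the axis limits to the added segments
    return lc
```

Assuming the `color_map` dictionary and the two-row `ax` array from above, a call such as `add_colored_line(ax[0], df["time"], df["v1"], df["state"], color_map)` would draw `v1`, and the same call with `df["v2"]` on `ax[1]` would draw `v2`.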