How to make a multi-colored line in matplotlib

Time:08-20

I am trying to make a multicolor line plot using matplotlib. The color should change depending on the value in the state column of my DataFrame:

time  v1   v2   state
0     3.5  8    0
1     3.8  8.5  0
2     4.2  9    1
3     5    12   0
4     8    10   2
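To reproduce this, the table above can be built as a DataFrame roughly like this. It is only a minimal sketch: the values are copied from the table, and `df` is the name the other snippets on this page use.

```python
import pandas as pd

# Sample data copied from the table in the question.
df = pd.DataFrame({
    "time": [0, 1, 2, 3, 4],
    "v1": [3.5, 3.8, 4.2, 5.0, 8.0],
    "v2": [8.0, 8.5, 9.0, 12.0, 10.0],
    "state": [0, 0, 1, 0, 2],
})
print(df)
```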

My code so far, which just displays the plot normally, without the colors:

cols=['v1','v2']
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(15, 15))
df.plot(x='time',y=cols,subplots=True, ax=axes)   
plt.legend()
plt.xticks(rotation=45)
plt.show()

The result should look something like the 2nd graph, with the line changing color according to the state column, using 3 distinct colors (red, blue, green).

[image: result]

CodePudding user response:

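import matplotlib.pyplot as plt

# df is the DataFrame from the question.
# Draw each segment separately, colored by the state at the segment's end point.
for state, prev, cur in zip(df['state'].iloc[1:], df.index[:-1], df.index[1:]):
    if state == 0:
        color = 'blue'
    elif state == 1:
        color = 'orange'
    else:
        color = 'green'
    # Plots both v1 and v2 for this segment, in the same color.
    plt.plot([df['time'][prev], df['time'][cur]], df.loc[[prev, cur], ['v1', 'v2']], c=color)
plt.xticks(rotation=45)
plt.show()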
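Since the question asks for one subplot per column, the same segment-by-segment idea could be adapted along these lines. This is only a sketch: the `plot_by_state` helper and the `state_colors` mapping are made-up names for illustration, and the red/blue/green assignment is an assumption based on the colors mentioned in the question.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Sample data from the question.
df = pd.DataFrame({
    "time": [0, 1, 2, 3, 4],
    "v1": [3.5, 3.8, 4.2, 5.0, 8.0],
    "v2": [8.0, 8.5, 9.0, 12.0, 10.0],
    "state": [0, 0, 1, 0, 2],
})

# Assumed state-to-color mapping; unknown states fall back to black.
state_colors = {0: "red", 1: "blue", 2: "green"}


def plot_by_state(df, cols, x="time", state="state"):
    """One subplot per column; each segment is colored by the state at its end point."""
    fig, axes = plt.subplots(nrows=len(cols), ncols=1, sharex=True, squeeze=False)
    for ax, col in zip(axes[:, 0], cols):
        for i in range(1, len(df)):
            seg = df.iloc[i - 1 : i + 1]  # the two rows bounding this segment
            color = state_colors.get(df[state].iloc[i], "black")
            ax.plot(seg[x], seg[col], c=color)
        ax.set_ylabel(col)
    return fig, axes[:, 0]


fig, axes = plot_by_state(df, ["v1", "v2"])
```

Call `plt.show()` afterwards to display the figure. Each segment is a separate `Line2D`, so this is fine for a few hundred points but gets slow for long series.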

CodePudding user response:

If you want to avoid for loops:

import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

# pair each row with the next one; the last row has no successor and is dropped
df[["time_shift", "v1_shift", "v2_shift"]] = df.shift(-1)[["time", "v1", "v2"]]
df = df.dropna()

# build separate line segments, each as a (start point, end point) pair
lines1 = list(zip(df[["time", "v1"]].values, df[["time_shift", "v1_shift"]].values))
lines2 = list(zip(df[["time", "v2"]].values, df[["time_shift", "v2_shift"]].values))

# map "state" to RGB values, use black if mapping does not exist
color_map = {
    0: (0.8, 0.1, 0.1),
    1: (0.1, 0.8, 0.1),
    2: (0.1, 0.1, 0.8),
}
colors = df["state"].apply(lambda x: color_map.get(x, (0, 0, 0))).tolist()

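# limits must cover both ends of every segment; the last row was dropped above,
# so its values only survive in the *_shift columns
xlim = (df["time"].min(), df["time_shift"].max())
ylim1 = (df[["v1", "v1_shift"]].min().min(), df[["v1", "v1_shift"]].max().max())
ylim2 = (df[["v2", "v2_shift"]].min().min(), df[["v2", "v2_shift"]].max().max())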

fig, ax = plt.subplots(nrows=2, ncols=1)
ax[0].set_xlim(*xlim)
ax[0].set_ylim(*ylim1)
ax[0].add_collection(LineCollection(lines1, linestyle="solid", colors=colors))
ax[1].set_xlim(*xlim)
ax[1].set_ylim(*ylim2)
ax[1].add_collection(LineCollection(lines2, linestyle="solid", colors=colors))
plt.show()
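A variation on the LineCollection approach, in case it helps: the segments can also be built with NumPy instead of shifted columns, and a legend can be added with proxy `Line2D` handles, since a LineCollection does not produce one legend entry per color. This is a sketch under assumptions: the `segments_for` helper and the named colors are invented for illustration, and `ax.autoscale_view()` is used here in place of setting the limits by hand.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.lines import Line2D

# Sample data from the question.
df = pd.DataFrame({
    "time": [0, 1, 2, 3, 4],
    "v1": [3.5, 3.8, 4.2, 5.0, 8.0],
    "v2": [8.0, 8.5, 9.0, 12.0, 10.0],
    "state": [0, 0, 1, 0, 2],
})

# Assumed state-to-color mapping; unknown states fall back to black.
color_map = {0: "red", 1: "blue", 2: "green"}


def segments_for(x, y):
    """Return an (n-1, 2, 2) array: one (start, end) pair of (x, y) points per segment."""
    points = np.column_stack([x, y]).reshape(-1, 1, 2)
    return np.concatenate([points[:-1], points[1:]], axis=1)


# Color each segment by the state at its start point, as in the answer above.
colors = [color_map.get(s, "black") for s in df["state"].iloc[:-1]]

fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True)
for ax, col in zip(axes, ["v1", "v2"]):
    segs = segments_for(df["time"].to_numpy(), df[col].to_numpy())
    ax.add_collection(LineCollection(segs, colors=colors))
    ax.autoscale_view()  # fit the view to the collection's data limits
    ax.set_ylabel(col)

# Proxy artists so the legend shows one entry per state.
handles = [Line2D([0], [0], color=c, label=f"state {s}") for s, c in color_map.items()]
axes[0].legend(handles=handles)
```

Call `plt.show()` afterwards to display the figure.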