Plot a single line in multiple colors

Time:05-01

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng()

data = pd.DataFrame({
    'group': pd.Categorical(['a', 'b', 'b', 'a', 'a', 'a', 'b', 'b']),
})
data['value'] = rng.uniform(size=len(data))

Using either Matplotlib or Seaborn, is there a straightforward way to plot this data as a single line, but where the line is colored according to the group? It's not really important where exactly the color changes in between two points, as long as it's consistent.

CodePudding user response:

You could make a dictionary of colours and plot the line between adjacent points in a for loop:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng()

data = pd.DataFrame({
    'group': pd.Categorical(['a', 'b', 'b', 'a', 'a', 'a', 'b', 'b']),
})
colours = {'a': 'red', 'b':'blue'}
data['value'] = rng.uniform(size=len(data))
plt.figure()
# Draw each segment separately, coloured by the group of its starting point.
for i in range(len(data['value']) - 1):
    plt.plot([i, i + 1], [data['value'][i], data['value'][i + 1]], color=colours[data['group'][i]])
plt.show()

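If calling plt.plot once per segment feels heavy, the same idea can probably be expressed with Matplotlib's LineCollection, which draws all the segments as one artist. What follows is only a sketch under a few assumptions: the helper name build_segments is made up for illustration, the seed 0 is there just to make the run repeatable, and each segment takes the colour of its starting point, as in the loop above.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection


def build_segments(values):
    """Hypothetical helper: return an (n-1, 2, 2) array of [[x0, y0], [x1, y1]] pairs."""
    x = np.arange(len(values))
    points = np.column_stack([x, values])
    # Pair every point with its successor to form one segment per gap.
    return np.stack([points[:-1], points[1:]], axis=1)


rng = np.random.default_rng(0)  # fixed seed is an assumption, for repeatability only

data = pd.DataFrame({
    'group': pd.Categorical(['a', 'b', 'b', 'a', 'a', 'a', 'b', 'b']),
})
data['value'] = rng.uniform(size=len(data))

colours = {'a': 'red', 'b': 'blue'}
segments = build_segments(data['value'].to_numpy())
# One colour per segment, taken from the group of the segment's starting point.
segment_colours = [colours[g] for g in data['group'].iloc[:-1]]

fig, ax = plt.subplots()
ax.add_collection(LineCollection(segments, colors=segment_colours, linewidths=2))
ax.autoscale()  # collections do not rescale the axes on their own
```

Call plt.show() afterwards to display the figure. As with the loop, the group of the last point is never used for a colour, because no segment starts there.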

CodePudding user response:

I don't know whether this qualifies as "straightforward", but:

from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng()

data = pd.DataFrame({
    'group': pd.Categorical(['a', 'b', 'b', 'a', 'a', 'a', 'b', 'a']),
})
data['value'] = rng.uniform(size=len(data))

f, ax = plt.subplots()

# Draw one segment per pair of adjacent points, coloured by the starting point's group.
for i in range(len(data) - 1):
    ax.plot([data.index[i], data.index[i + 1]],
            [data['value'].iat[i], data['value'].iat[i + 1]],
            color=f'C{data.group.cat.codes.iat[i]}', linewidth=2, marker='o')

# To remain consistent, the last point's marker should also get the correct color.
# Here, I changed the last point's group to 'a' as an example.
ax.plot([data.index[-1]] * 2, [data['value'].iat[-1]] * 2,
        color=f'C{data.group.cat.codes.iat[-1]}', linewidth=2, marker='o')

legend_lines = [Line2D([0], [0], color=f'C{code}', lw=2) for code in data['group'].unique().codes]
legend_labels = [g for g in data['group'].unique()]
plt.legend(legend_lines, legend_labels, title='group')
plt.show()
