Since I'm plotting a subset of the dataset in each subplot, I would like the color of each class to be consistent in whichever plot it appears. For example, a blue color for the car mode in whichever subplot it appears, a black color for the bus mode in whichever plot it appears, etc.
As it is now, foot is red in subplot(2, 2, 1), and car is also red in subplot(2, 2, 2), although the rest are consistent.
CodePudding user response:
For this use case, seaborn accepts a dictionary as the palette. The dictionary assigns a color to each hue value, so a given class gets the same color in every subplot.
Here is an example of how such a dictionary could be created for your data:
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
df1 = pd.DataFrame({'tsne_one': np.random.randn(10),
                    'tsne-two': np.random.randn(10),
                    'mode': np.random.choice(['foot', 'metro', 'bike'], 10)})
df2 = pd.DataFrame({'tsne_one': np.random.randn(10),
                    'tsne-two': np.random.randn(10),
                    'mode': np.random.choice(['car', 'metro', 'bike'], 10)})
df3 = pd.DataFrame({'tsne_one': np.random.randn(10),
                    'tsne-two': np.random.randn(10),
                    'mode': np.random.choice(['foot', 'bus', 'metro', 'bike'], 10)})
df4 = pd.DataFrame({'tsne_one': np.random.randn(10),
                    'tsne-two': np.random.randn(10),
                    'mode': np.random.choice(['car', 'bus', 'metro', 'bike'], 10)})
# collect every mode that occurs in any of the dataframes
modes = pd.concat([df['mode'] for df in (df1, df2, df3, df4)], ignore_index=True).unique()
# one color per mode, shared by all subplots
colors = sns.color_palette('hls', len(modes))
palette = dict(zip(modes, colors))
fig, axs = plt.subplots(2, 2, figsize=(12, 6))
for df, ax in zip((df1, df2, df3, df4), axs.flatten()):
    # the same palette dictionary is passed to every subplot
    sns.scatterplot(x='tsne_one', y='tsne-two', hue='mode', data=df,
                    palette=palette, legend='full', alpha=0.7, ax=ax)
plt.tight_layout()
plt.show()
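If you want specific colors for some modes (blue for car and black for bus, as in the question), one option is to pin those entries in the dictionary and fill in the rest automatically. The following is only a minimal sketch: `build_palette` and `FALLBACK_COLORS` are hypothetical names made up for illustration, and the fallback colors are arbitrary matplotlib color names, not something taken from your data.

```python
# Hypothetical helper (not part of seaborn): pin chosen colors for some
# classes and assign fallback colors to the remaining ones.
FALLBACK_COLORS = ['tab:orange', 'tab:green', 'tab:red',
                   'tab:purple', 'tab:brown', 'tab:pink']


def build_palette(modes, fixed):
    """Return {mode: color}, keeping `fixed` as given and filling in the rest."""
    palette = dict(fixed)
    remaining = [mode for mode in modes if mode not in palette]
    if len(remaining) > len(FALLBACK_COLORS):
        raise ValueError('not enough fallback colors for the remaining modes')
    # zip stops at the shorter sequence, so each remaining mode gets one color
    for mode, color in zip(remaining, FALLBACK_COLORS):
        palette[mode] = color
    return palette


# Colors requested in the question: blue for car, black for bus.
palette = build_palette(
    modes=['foot', 'metro', 'bike', 'car', 'bus'],
    fixed={'car': 'blue', 'bus': 'black'},
)
print(palette)
```

The resulting dictionary can be passed as palette=palette to sns.scatterplot in the loop above, since dictionary values may be any color that matplotlib understands.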