I want to make a plot in seaborn but I am having some difficulties. The data has 2 variables: time (2 levels) and state (2 levels). I want to plot time on the x axis and state as separate subplots, showing individual data lines. Finally, to the right of these I want to show a plot of the difference between time 2 and time 1 for each level of state. I cannot get it to work, because I cannot get the second plot to appear on the right. Here is my attempt:
import numpy as np
import pandas as pd
import seaborn as sns
# Just making some fake data
ids = [1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5]
times = [1,1,2,2,1,1,2,2,1,1,2,2,1,1,2,2,1,1,2,2]
states = ['A', 'B', 'A', 'B'] * 5
np.random.seed(121)
resps = [(i*t) + np.random.normal() for i, t in zip(ids, times)]
DATA = {
'identity': ids,
'time': times,
'state': states,
'resps': resps
}
df = pd.DataFrame(DATA)
# Done with data
g = sns.relplot(
data=df, kind='line',
col='state', x='time', y='resps', units='identity',
estimator=None, alpha=.5, height=5, aspect=.7)
# Draw a thick mean line onto each Axes
g.map(sns.lineplot,"time", "resps", lw=5, ci=None)
# Make a wide data to make the difference
wide = df.set_index(['identity', 'state', 'time']).unstack().reset_index()
A = wide['state']=='A'
B = wide['state']=='B'
wide['diffA'] = wide[A][('resps', 2)] - wide[A][('resps', 1)]
wide['diffB'] = wide[B][('resps', 2)] - wide[B][('resps', 1)]
wide['difference'] = wide[['diffA', 'diffB']].sum(axis=1)
wide = wide.drop(columns=[('diffA', ''), ('diffB', '')])
sns.pointplot(x='state', y='difference', data=wide, join=False)
Is there no way to put them together, even though they use different data? I did try matplotlib and got slightly better results, but there was still a problem: I wanted the two left plots to share a y axis, but not the difference plot. It also created a lot of work, because I want to be flexible for different numbers of levels of the state variable; I only kept it to 2 here for simplicity. Here is a Paint sketch of what I want (sorry for the poor quality), hopefully with some more control over appearance, but this is secondary:
Is there a simpler, reliable way to do this? Thanks!
CodePudding user response:
The problem is that sns.relplot operates at the figure level. This means it creates its own figure object, and we cannot control which axes it draws on. If you want to leverage seaborn to create the lines without using "pure" matplotlib, you can copy the lines onto matplotlib axes:
import numpy as np
import pandas as pd
import seaborn as sns
# Just making some fake data
ids = [1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5]
times = [1,1,2,2,1,1,2,2,1,1,2,2,1,1,2,2,1,1,2,2]
states = ['A', 'B', 'A', 'B'] * 5
np.random.seed(121)
resps = [(i*t) + np.random.normal() for i, t in zip(ids, times)]
DATA = {
'identity': ids,
'time': times,
'state': states,
'resps': resps
}
df = pd.DataFrame(DATA)
# Done with data
g = sns.relplot(
data=df, kind='line',
col='state', x='time', y='resps', units='identity',
estimator=None, alpha=.5, height=5, aspect=.7)
# Draw a thick mean line onto each Axes
g.map(sns.lineplot,"time", "resps", lw=5, ci=None)
# Make a wide data to make the difference
wide = df.set_index(['identity', 'state', 'time']).unstack().reset_index()
A = wide['state']=='A'
B = wide['state']=='B'
wide['diffA'] = wide[A][('resps', 2)] - wide[A][('resps', 1)]
wide['diffB'] = wide[B][('resps', 2)] - wide[B][('resps', 1)]
wide['difference'] = wide[['diffA', 'diffB']].sum(axis=1)
wide = wide.drop(columns=[('diffA', ''), ('diffB', '')])
# New code ----------------------------------------
import matplotlib.pyplot as plt
plt.close(g.figure)
fig = plt.figure(figsize=(12, 4))
ax1 = fig.add_subplot(1, 3, 1)
ax2 = fig.add_subplot(1, 3, 2, sharey=ax1)
ax3 = fig.add_subplot(1, 3, 3)
for ax, g_ax in zip([ax1, ax2], g.axes[0]):
    # Copy every line seaborn drew on this facet onto the new Axes
    for line in g_ax.get_lines():
        ax.plot(line.get_data()[0], line.get_data()[1], color=line.get_color(), lw=line.get_linewidth(), alpha=line.get_alpha())
    ax.set_title(g_ax.get_title())
sns.pointplot(ax=ax3, x='state', y='difference', data=wide, join=False)
# End of new code ----------------------------------
plt.show()