I've got some data, grouped by category (e.g. "a", "b", "c", etc.), and I'd like to draw lines between each pair of points within each category.
Basically, each category has a "before" and "after" value, so I've split it that way with hue. This is the plot now, but eventually I want each "before" and "after" value for a given category to be joined with a line (i.e. a_before joins to a_after, b_before joins to b_after, etc).
sns.pointplot(x='category', y='correlation',
              hue='time', linestyles='', dodge=.3, data=sample_data)
I set linestyles to '' because otherwise it joins all the points rather than only the paired points. Is there a way to do this with seaborn?
Thanks!
edit: I'd like it to look something like this:
CodePudding user response:
Matplotlib stores the generated points in the lines field of the ax. sns.pointplot() always generates (possibly empty) confidence intervals, which also get stored in the lines. The same positions are also stored in ax.collections.
You can loop through collections[0] and collections[1] to access the exact position of the (dodged) points. Then, you can draw lines between them:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sample_data = pd.DataFrame({'category': ['a', 'a', 'b', 'b', 'c', 'c'],
                            'correlation': [0.33, 0.58, 0.51, 0.7, 0.49, 0.72],
                            'time': ['before', 'after', 'before', 'after', 'before', 'after']})
ax = sns.pointplot(x='category', y='correlation', hue='time', palette=['skyblue', 'dodgerblue'],
                   linestyles='', dodge=.3, data=sample_data)
# collections[0] holds the 'before' points, collections[1] the 'after' points
for (x0, y0), (x1, y1) in zip(ax.collections[0].get_offsets(), ax.collections[1].get_offsets()):
    ax.plot([x0, x1], [y0, y1], color='black', ls=':', zorder=0)
ax.axhline(0, color='black', ls='--')
ax.set_ylim(-1, 1)
plt.show()
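The loop above depends on how pointplot happens to store its artists, which may differ between seaborn versions. If ax.collections turns out not to hold the points, a possible fallback is to compute the endpoints from the data instead. The sketch below is only that: paired_segments and draw_pairs are hypothetical helper names, and it assumes that categories sit at x = 0, 1, 2, ... in order of first appearance and that the two hue levels are shifted by -dodge/2 and +dodge/2, which is my understanding of how dodge behaves for two hue levels.

```python
def paired_segments(rows, dodge=0.3, first="before", second="after"):
    """Return one ((x0, y0), (x1, y1)) segment per category.

    rows: iterable of (category, time, value) tuples.
    Assumption: categories are placed at x = 0, 1, 2, ... in order of first
    appearance, and the two hue levels are shifted by -dodge/2 and +dodge/2.
    """
    positions = {}  # category -> integer x position
    values = {}     # (category, time) -> y value
    for category, time, value in rows:
        positions.setdefault(category, len(positions))
        values[(category, time)] = value

    segments = []
    for category, x in positions.items():
        start = (x - dodge / 2, values[(category, first)])
        end = (x + dodge / 2, values[(category, second)])
        segments.append((start, end))
    return segments


def draw_pairs(ax, segments, **line_kwargs):
    """Draw each segment on a Matplotlib Axes, behind the markers by default."""
    style = {"color": "black", "ls": ":", "zorder": 0, **line_kwargs}
    for (x0, y0), (x1, y1) in segments:
        ax.plot([x0, x1], [y0, y1], **style)


# Same numbers as sample_data above, written as plain tuples
rows = [("a", "before", 0.33), ("a", "after", 0.58),
        ("b", "before", 0.51), ("b", "after", 0.70),
        ("c", "before", 0.49), ("c", "after", 0.72)]
segments = paired_segments(rows, dodge=0.3)
```

With the ax returned by sns.pointplot, calling draw_pairs(ax, segments) should give the same dotted connectors, provided the dodge value passed here matches the one given to pointplot.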