Join paired points within each category in seaborn pointplot

Time:12-21

I've got some data, grouped by category (e.g. "a", "b", "c", etc.), and I'd like to draw lines between each pair of points within each category.

Basically, each category has a "before" and an "after" value, so I've split it that way with hue. This is the plot now, but eventually I want the "before" and "after" values for a given category to be joined with a line (i.e. a_before joins to a_after, b_before joins to b_after, etc.).

sns.pointplot(x='category', y='correlation',
    hue='time', linestyles='', dodge=.3, data=sample_data)

(image: point plot)

I set linestyles to '' because otherwise it joins all the points of each hue level rather than only the paired points. Is there a way to do this with seaborn?

Thanks!

Edit: I'd like it to look something like this:

(image: the desired plot)


Answer:

Matplotlib stores the generated points in the ax.lines list of the axes. sns.pointplot() always generates (possibly empty) confidence-interval lines as well, and these also get stored in ax.lines, so that list mixes points and error bars. The same point positions are also stored in ax.collections, one collection per hue level.
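The split between the two containers is plain matplotlib behavior: ax.plot() adds a Line2D to ax.lines, while ax.scatter() adds a PathCollection to ax.collections, and that collection's get_offsets() returns the plotted (x, y) positions. A minimal sketch without seaborn, using made-up coordinates purely for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()

# ax.plot() creates a Line2D artist and stores it in ax.lines
ax.plot([0, 1], [0.2, 0.4])

# ax.scatter() creates a PathCollection artist and stores it in ax.collections
ax.scatter([0.0, 1.0, 2.0], [0.3, 0.5, 0.7])

print(len(ax.lines))        # -> 1
print(len(ax.collections))  # -> 1

# get_offsets() returns one (x, y) row per scatter point
offsets = ax.collections[0].get_offsets()
print(np.asarray(offsets))
```

This is the property the answer relies on: if the pointplot markers live in a collection, its offsets give their exact dodged positions.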

You can loop over ax.collections[0] and ax.collections[1] (the "before" and "after" points) and use get_offsets() to get the exact positions of the (dodged) points. Then you can draw a line between each pair:

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

sample_data = pd.DataFrame({'category': ['a', 'a', 'b', 'b', 'c', 'c'],
                            'correlation': [0.33, 0.58, 0.51, 0.7, 0.49, 0.72],
                            'time': ['before', 'after', 'before', 'after', 'before', 'after']})
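ax = sns.pointplot(x='category', y='correlation', hue='time', palette=['skyblue', 'dodgerblue'],
                   linestyles='', dodge=.3, data=sample_data)
# collections[0] holds the "before" markers and collections[1] the "after" markers;
# their offsets are the dodged (x, y) positions. The artist layout can differ
# between seaborn versions, so inspect ax.collections if this fails.
for (x0, y0), (x1, y1) in zip(ax.collections[0].get_offsets(), ax.collections[1].get_offsets()):
    # dotted connector drawn behind the markers (zorder=0)
    ax.plot([x0, x1], [y0, y1], color='black', ls=':', zorder=0)
ax.axhline(0, color='black', ls='--')  # reference line at zero correlation
ax.set_ylim(-1, 1)
plt.show()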
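If you would rather not depend on where seaborn keeps its artists, a minimal alternative sketch skips pointplot and draws the same kind of picture with pandas and plain matplotlib. The ±0.15 offset (standing in for dodge=.3), the pivot-based reshaping and the colors are assumptions chosen to mimic the plot above, not seaborn's own code.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to use plt.show() interactively
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sample_data = pd.DataFrame({'category': ['a', 'a', 'b', 'b', 'c', 'c'],
                            'correlation': [0.33, 0.58, 0.51, 0.7, 0.49, 0.72],
                            'time': ['before', 'after', 'before', 'after', 'before', 'after']})

# Reshape to one row per category and one column per time point
wide = sample_data.pivot(index='category', columns='time', values='correlation')

x = np.arange(len(wide))  # category positions 0, 1, 2, ...
half_dodge = 0.15         # assumed offset, roughly mimicking dodge=.3

fig, ax = plt.subplots()

# One dotted connector per before/after pair, drawn behind the markers
for xi, (before, after) in zip(x, wide[['before', 'after']].to_numpy()):
    ax.plot([xi - half_dodge, xi + half_dodge], [before, after],
            color='black', ls=':', zorder=0)

# Markers for each time point, shifted left/right of the category position
ax.scatter(x - half_dodge, wide['before'], color='skyblue', label='before', zorder=2)
ax.scatter(x + half_dodge, wide['after'], color='dodgerblue', label='after', zorder=2)

ax.set_xticks(x)
ax.set_xticklabels(wide.index)
ax.set_xlabel('category')
ax.set_ylabel('correlation')
ax.axhline(0, color='black', ls='--')  # reference line at zero correlation
ax.set_ylim(-1, 1)
ax.legend(title='time')
```

Because each category/time combination has a single value here, no confidence intervals are lost by skipping pointplot. With repeated observations, DataFrame.pivot raises on duplicate entries, and pivot_table (which averages by default) would be the aggregating alternative.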

(image: sns.pointplot connecting different hues)
