Custom xticks in seaborn heatmap

Time:02-04

I have the following heatmap (just a minimal working example; my real data is huge!)

import pandas as pd
import seaborn as sns

df = pd.DataFrame({'set1': ['0', '2', '2'],
                   'set2': ['1', '2', '0'],
                   'set3': ['0', '2', '1'],
                   'set4': ['1', '4', '1']
                  }).T.astype(float)

sns.heatmap(df, yticklabels=df.index, xticklabels=df.columns)

How can I show xticks only for the columns where all rows are >= 2? In this example, that means showing only the '1' xtick.

So in this image, the '0' and '2' column names should not appear, only '1', because it is the only column where all values are greater than or equal to 2.

The problem is that the x axis gets too crowded. I want to show only the xticklabels of the columns where all values are >= 2, while still plotting everything.


CodePudding user response:

Are you looking to show the same heatmap, but only show xticklabels where ALL values are >= 2? One way to do this is to not pass df.columns to heatmap, but to build a list of labels that is blank for every column you want to hide. See if this is what you are looking for...

import pandas as pd
import seaborn as sns

df = pd.DataFrame({'set1': ['0', '2', '2'],
                   'set2': ['1', '2', '0'],
                   'set3': ['0', '2', '1'],
                   'set4': ['1', '4', '1']
                  }).T.astype(float)

## Columns in which every value is >= 2 (computed once, outside the loop)
keep = set(df[df >= 2].T.dropna().index)

cols = [] ## Declare a new list to be used for xticklabels
for col in df.columns:
    if col in keep:
        cols.append(col) ## If all values are >= 2, then add to label list
    else:
        cols.append('') ## Else, add blank

sns.heatmap(df, yticklabels=df.index, xticklabels=cols) ## Plot using new list


CodePudding user response:

Mask the DataFrame

  • Removes every column in which not all values are >= the specified value

# create a Boolean mask of df
mask = df.ge(2)

# apply the mask to df and dropna
dfm = df[mask].dropna(axis=1)

# plot the masked df
ax = sns.heatmap(dfm)


mask

          0     1      2
set1  False  True   True
set2  False  True  False
set3  False  True  False
set4  False  True  False
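
The same column selection can also be written without filling the frame with NaN first, by reducing the mask with .all() and indexing the columns with .loc. This is a minimal sketch, assuming the df from the question; the names keep, dfm_direct and dfm_dropna are only illustrative.

```python
import pandas as pd

df = pd.DataFrame({'set1': ['0', '2', '2'],
                   'set2': ['1', '2', '0'],
                   'set3': ['0', '2', '1'],
                   'set4': ['1', '4', '1']
                  }).T.astype(float)

# True for each column in which every row is >= 2
keep = df.ge(2).all()

# select those columns directly; no NaN values are introduced
dfm_direct = df.loc[:, keep]

# the NaN-based approach shown above, for comparison
dfm_dropna = df[df.ge(2)].dropna(axis=1)
```

Both frames should contain only column 1 here, and either one can be passed to sns.heatmap.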

Mask the xtick labels

  • Labels of columns in which not all values are >= the specified value are replaced with ''.

# create a Boolean mask of df
mask = df.ge(2).all()

# use the mask to update a list of labels
cols = [col if m else '' for (col, m) in zip(df.columns, mask)]

# plot with custom labels
ax = sns.heatmap(df, xticklabels=cols)


mask

0    False
1     True
2    False
dtype: bool
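
With many columns, blank labels still leave a tick mark under every column. One possible variation is to place ticks only at the kept columns. The sketch below relies on the assumption that the heatmap centers column i at x = i + 0.5; the helper names tick_positions and apply_xticks are hypothetical, and ax stands for the Axes returned by sns.heatmap(df, xticklabels=False).

```python
import numpy as np
import pandas as pd

def tick_positions(df, threshold=2):
    """Return (positions, labels) for columns whose values are all >= threshold."""
    mask = df.ge(threshold).all().to_numpy()
    # assumption: column i of the heatmap is centered at x = i + 0.5
    positions = np.flatnonzero(mask) + 0.5
    labels = df.columns[mask].tolist()
    return positions, labels

def apply_xticks(ax, positions, labels):
    """Put ticks and labels only at the given x positions of an existing Axes."""
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    return ax

df = pd.DataFrame({'set1': ['0', '2', '2'],
                   'set2': ['1', '2', '0'],
                   'set3': ['0', '2', '1'],
                   'set4': ['1', '4', '1']
                  }).T.astype(float)

positions, labels = tick_positions(df)
```

After plotting with ax = sns.heatmap(df, xticklabels=False), calling apply_xticks(ax, positions, labels) would label only the '1' column in this example.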