Wondering how I can plot a seaborn plot onto a different matplotlib plot. Currently I have two plots (one a heatmap, the other a soccer pitch), but when I plot the heatmap onto the pitch, I get the results below. (Plotting the pitch onto the heatmap isn't pretty either.) Any ideas how to fix it?
Note: Plots don't need a colorbar and the grid structure isn't required either. Just care about the heatmap covering the entire space of the pitch. Thanks!
import pandas as pd
import numpy as np
from mplsoccer import Pitch
import seaborn as sns
nmf_shot_W = pd.read_csv('https://raw.githubusercontent.com/lucas-nelson-uiuc/datasets/main/nmf_show_W.csv').iloc[:, 1:]
nmf_shot_ThierryHenry = pd.read_csv('https://raw.githubusercontent.com/lucas-nelson-uiuc/datasets/main/nmf_show_Hth.csv')['Thierry Henry']
pitch = Pitch(pitch_type='statsbomb', line_zorder=2,
pitch_color='#22312b', line_color='#efefef')
dfdfdf = np.array(np.matmul(nmf_shot_W, nmf_shot_ThierryHenry)).reshape((24,25))
g_ax = sns.heatmap(dfdfdf)
pitch.draw(ax=g_ax)
Current output:
Desired output:
Answer:
pitch.heatmap expects a stats dictionary of binned data, bin mesh, and bin centers:
stats (dict) – The keys are statistic (the calculated statistic), x_grid and y_grid (the bin's edges), and cx and cy (the bin centers).
In the mplsoccer heatmap demos, they construct this stats object using pitch.bin_statistic because they have raw data. However, you already have binned data (the "calculated statistic"), so reconstruct the stats object manually by building the mesh and centers:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mplsoccer import Pitch
nmf_shot_W = pd.read_csv('71878281/nmf_show_W.csv', index_col=0)
nmf_shot_ThierryHenry = pd.read_csv('71878281/nmf_show_Hth.csv')['Thierry Henry']
statistic = np.dot(nmf_shot_W, nmf_shot_ThierryHenry.to_numpy()).reshape((24, 25))
# construct stats object from binned data, bin mesh, and bin centers
y, x = statistic.shape
x_grid = np.linspace(0, 120, x + 1)
y_grid = np.linspace(0, 80, y + 1)
cx = x_grid[:-1] + 0.5 * (x_grid[1] - x_grid[0])
cy = y_grid[:-1] + 0.5 * (y_grid[1] - y_grid[0])
stats = dict(statistic=statistic, x_grid=x_grid, y_grid=y_grid, cx=cx, cy=cy)
# use pitch.draw and pitch.heatmap as per mplsoccer demo
pitch = Pitch(pitch_type='statsbomb', line_zorder=2, pitch_color='#22312b', line_color='#efefef')
fig, ax = pitch.draw(figsize=(6.6, 4.125))
pcm = pitch.heatmap(stats, ax=ax, cmap='plasma')
cbar = fig.colorbar(pcm, ax=ax, shrink=0.6)
cbar.outline.set_edgecolor('#efefef')
cbar.ax.yaxis.set_tick_params(color='#efefef')
plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='#efefef')
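If you need to do this for more than one player, the mesh and center construction above could be wrapped in a small helper. This is only a sketch: make_stats is a hypothetical name (not part of mplsoccer), the default dimensions assume a statsbomb pitch (x from 0 to 120, y from 0 to 80), and the demo array is synthetic so the snippet runs without the CSV files.

```python
import numpy as np


def make_stats(statistic, pitch_length=120.0, pitch_width=80.0):
    """Wrap an already-binned 2D array in the dict layout used above.

    Hypothetical helper, not part of mplsoccer. The defaults assume a
    statsbomb pitch; rows are bins along y, columns are bins along x.
    """
    statistic = np.asarray(statistic)
    n_rows, n_cols = statistic.shape
    # Bin edges: one more edge than there are bins along each axis.
    x_grid = np.linspace(0, pitch_length, n_cols + 1)
    y_grid = np.linspace(0, pitch_width, n_rows + 1)
    # Bin centers: midpoints between consecutive edges.
    cx = 0.5 * (x_grid[:-1] + x_grid[1:])
    cy = 0.5 * (y_grid[:-1] + y_grid[1:])
    return dict(statistic=statistic, x_grid=x_grid, y_grid=y_grid, cx=cx, cy=cy)


# Synthetic stand-in for the (24, 25) array computed from the NMF factors.
demo = np.arange(24 * 25, dtype=float).reshape((24, 25))
stats = make_stats(demo)
print(stats['x_grid'].shape, stats['y_grid'].shape)
```

The returned dict can be passed to pitch.heatmap(stats, ax=ax, cmap='plasma') in place of the hand-built one. Since the question doesn't need a colorbar, the fig.colorbar and cbar lines can simply be left out.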