The problem I am facing is this: I have a DataFrame, sector_features_, shown in picture 1.
After running t-SNE on it, I have a 2-D DataFrame that I plot as a scatter plot. I don't know how to color the points using the original label information held in the index (seen in picture 1). Each index entry is a tuple containing the {country} and {year} the observation belongs to. Ideally I would like to color by country only or by year only, to see how this changes the visualisation.
The DataFrame containing the reduced (t-SNE) data has two columns, one per embedding dimension.
I am using matplotlib and seaborn. I have seen some solutions that use Altair, but I am not sure how to proceed.
The imports are:
import pandas as pd
import numpy as np
import random as rd
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
from sklearn import preprocessing
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
Answer:
Judging from your screenshot, I'm guessing you have a DataFrame with a MultiIndex. Using an example dataset:
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.datasets import make_blobs
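data, _ = make_blobs(n_samples=100, n_features=5, centers=4, cluster_std=3.5)  # 100 rows, 5 features
data = pd.DataFrame(data)
data['country'] = np.repeat(['A', 'B', 'C', 'D'], 25)  # 25 rows per country
data['year'] = np.repeat(np.arange(1, 26), 4)  # years 1..25, each repeated 4 times
data = data.set_index(['country', 'year'])  # two-level (country, year) MultiIndex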
The index looks like this:
data.index[:10]
MultiIndex([('A', 1),
            ('A', 1),
            ('A', 1),
            ('A', 1),
            ('A', 2),
            ('A', 2),
            ('A', 2),
            ('A', 2),
            ('A', 3),
            ('A', 3)],
           names=['country', 'year'])
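If the index in your own frame is a flat Index of (country, year) tuples rather than a true MultiIndex, it can be converted first, after which each level can be pulled out on its own. The sketch below assumes that situation; the frame df, its columns f1/f2 and the values are hypothetical stand-ins for sector_features_.

```python
import pandas as pd

# Hypothetical frame whose index is a *flat* Index of (country, year) tuples.
# tupleize_cols=False stops pandas from turning the tuples into a MultiIndex.
df = pd.DataFrame(
    {"f1": [0.1, 0.2, 0.3], "f2": [1.0, 0.5, 0.2]},
    index=pd.Index([("A", 2000), ("A", 2001), ("B", 2000)], tupleize_cols=False),
)

# Convert the tuples into a proper two-level MultiIndex with named levels.
if not isinstance(df.index, pd.MultiIndex):
    df.index = pd.MultiIndex.from_tuples(df.index, names=["country", "year"])

# Each level can now be extracted separately, e.g. to use as a color label.
countries = df.index.get_level_values("country")
years = df.index.get_level_values("year")
print(countries.tolist())  # ['A', 'A', 'B']
print(years.tolist())      # [2000, 2001, 2000]
```

seaborn also accepts an array-like for hue, so a level extracted this way can be passed directly as hue=countries, provided it has the same length and row order as the x and y values.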
Perform t-SNE:
X_embedded = TSNE(n_components=2,init='pca',learning_rate='auto').fit_transform(data.values)
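You can then call reset_index() to move the two index levels (here ['country', 'year']) back into ordinary columns and concatenate them with the t-SNE results. The rows line up because fit_transform returns the embedded points in the same order as the input, and both frames carry a default 0..n-1 index: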
tsne_result = pd.concat([
data.reset_index()[['country','year']],
pd.DataFrame(X_embedded,columns=['tsne1','tsne2'])
],axis=1)
And plot:
import matplotlib.pyplot as plt
sns.scatterplot(data=tsne_result, x="tsne1", y="tsne2", hue="country")  # use hue="year" to color by year instead
plt.show()