I've been trying to make either an overlay plot or a multi-plot grid of filtered data from a pandas DataFrame, but so far I've only been able to generate the plots separately. The code that generates the separate plots is the following:
```python
import matplotlib.pyplot as plt
import seaborn as sns

# Get the class counts for all objects (get_class_counts is my own helper)
class_counts = get_class_counts(clean_df, 0.4)
# Select the top 5 most common objects
top_5_class_counts = class_counts.head(5)
# Keep only the image and class_name columns
df_filtered = df[['image', 'class_name']]
# Merge the class counts with the image and class_name columns
merged_df = df_filtered.merge(top_5_class_counts, on='class_name')
# Group the data by the class_name column
grouped_df = merged_df.groupby('class_name')
# Iterate over the groups; each displot call draws its own separate figure
for name, group in grouped_df:
    # Count the number of times each image appears in the group and store the results
    image_counts = group.groupby(['image', 'class_name']).size().reset_index(name='count')
    # Draw a KDE of the count column using seaborn's displot function
    # (multiple='stack' has no effect here because there is no hue variable)
    sns.displot(image_counts, x='count', kind='kde', multiple='stack')
    plt.show()
```
Any help would be appreciated.
A snippet of the merged dataframe:
image | class_name | class_id | count |
---|---|---|---|
berl_000000.png | person | 0 | 1462 |
berl_000002.png | person | 0 | 1462 |
berl_000002.png | person | 0 | 1462 |
berl_000003.png | person | 0 | 1462 |
berl_000003.png | person | 0 | 1462 |
zur_000119.png | truck | 7 | 189 |
zur_000116.png | truck | 7 | 189 |
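To show which table the plot is actually built from, here is a minimal sketch that runs a single groupby over the seven rows in the snippet above (only these rows are known; the full data is not shown). Note that the merged count column (objects per class) is not the value being plotted: the per-image count comes from size(), which just counts rows in each class/image pair.

```python
import pandas as pd

# The rows shown in the merged-DataFrame snippet from the question
merged_df = pd.DataFrame({
    'image': ['berl_000000.png', 'berl_000002.png', 'berl_000002.png',
              'berl_000003.png', 'berl_000003.png', 'zur_000119.png', 'zur_000116.png'],
    'class_name': ['person'] * 5 + ['truck'] * 2,
    'class_id': [0] * 5 + [7] * 2,
    'count': [1462] * 5 + [189] * 2,
})

# Count how often each image appears per class, for every class at once.
# size() counts rows per (class_name, image) pair, so the existing 'count'
# column (objects per class) does not affect the result.
image_counts = (
    merged_df.groupby(['class_name', 'image'])
    .size()
    .reset_index(name='count')
)
print(image_counts)
```

On these rows, berl_000002.png and berl_000003.png each get a count of 2 for person, and every other image gets 1.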
Answer:
Edit: Thanks for editing your question. You can do the work from your for loop in a single groupby, without looping, and build one table that you can then plot from:
```python
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

##################
#Create fake data#
##################
num_classes = 8
counts_per_class = 1000
num_images = 100

class_names = np.repeat([f'class_{i+1}' for i in range(num_classes)], counts_per_class)
images = [f'image_{c+1}.png' for c in np.random.randint(0, num_images, num_classes*counts_per_class)]

#this table has a row for each class/image and can have duplicate rows like:
#this is just like your merged_df I think
# class_name image
# class_1    image_10.png
# class_1    image_11.png
# class_1    image_11.png
# class_1    image_12.png
# ...        ...
df = pd.DataFrame({
    'class_name': class_names,
    'image': images,
})

#Do the work you're doing in your loop all at once w/out a loop
#this table has a unique row per class/image and a new count column
# class_name image        count
# class_1    image_10.png 1
# class_1    image_11.png 2
# class_1    image_12.png 1
# ...        ...
count_df = df.groupby(['class_name', 'image']).size().reset_index(name='count')

###################################
#Make the overlay displot with hue#
###################################
sns.displot(
    x='count',
    hue='class_name',
    kind='kde',
    data=count_df,
)
plt.show()
plt.close()
```
You can get the same overlay plot with a for loop if you prefer. The workflow would be: (1) create an empty list, (2) loop through the classes as you do now, appending each per-class table to the list, and (3) after the loop, combine the list into one table with pd.concat and plot that table.
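A minimal sketch of that loop-based workflow, on hypothetical fake data with the same columns as merged_df. It collects the per-class tables in a list and joins them once with pd.concat, because DataFrame.append, the usual way to grow a table row by row, was removed in pandas 2.0. The result matches the single groupby used above, so it can be passed to the same sns.displot(..., hue='class_name') call.

```python
import numpy as np
import pandas as pd

# Hypothetical fake data with the same columns as the question's merged_df
rng = np.random.default_rng(0)
df = pd.DataFrame({
    'class_name': rng.choice(['person', 'car', 'truck'], size=600),
    'image': [f'image_{i}.png' for i in rng.integers(0, 50, size=600)],
})

# (1) collect the per-class tables in a list
per_class_tables = []
# (2) loop over the classes, doing the same work as the question's loop
for name, group in df.groupby('class_name'):
    image_counts = group.groupby(['image', 'class_name']).size().reset_index(name='count')
    per_class_tables.append(image_counts)
# (3) combine once after the loop; this table is what gets plotted
looped_df = pd.concat(per_class_tables, ignore_index=True)

# The loop-free version from the answer, for comparison
vectorized_df = df.groupby(['class_name', 'image']).size().reset_index(name='count')
```

The two tables differ only in column order, which does not matter to seaborn since columns are selected by name.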