Until now I have worked in the R environment, where I created the chart shown below.
library("rJava")
library("xlsxjars")
library("xlsx")
require(tidyr)
library("ggplot2")
# Example data set `df1`: 31 rows with the columns Groups, Country,
# PCT_Direct_Tax, income_tax and EU_28 (a 0/1 indicator).
# Built inside local() so the helper vectors do not leak into the workspace.
df1 <- local({
  group_levels <- c("Group_1", "Group_2", "Group_3")

  country_levels <- c(
    "Austria", "Belgium", "Bulgaria", "Croatia", "Cyprus", "Czechia",
    "Denmark", "Estonia", "Finland", "France", "Germany", "Greece",
    "Hungary", "Iceland", "Ireland", "Italy", "Latvia", "Lithuania",
    "Luxembourg", "Malta", "Netherlands", "North Macedonia", "Norway",
    "Poland", "Portugal", "Romania", "Slovakia", "Slovenia", "Spain",
    "Sweden", "United Kingdom"
  )

  # Position of each row's country within `country_levels`.
  country_codes <- c(
    2L, 3L, 6L, 7L, 11L, 8L, 15L, 12L, 29L, 10L, 4L,
    16L, 5L, 17L, 18L, 19L, 13L, 20L, 21L, 1L, 24L, 25L,
    26L, 28L, 27L, 9L, 30L, 31L, 14L, 23L, 22L
  )

  data.frame(
    # The first 11 rows are Group_1, the next 11 Group_2, the last 9 Group_3.
    Groups = factor(
      rep(group_levels, times = c(11L, 11L, 9L)),
      levels = group_levels
    ),
    Country = factor(country_levels[country_codes], levels = country_levels),
    PCT_Direct_Tax = c(
      17.663, 6.048, 8.038, 29.092, 13.517, 7.409, 10.829, 10.155,
      11.05, 13.904, 6.477, 14.184, 9.109, 7.41, 5.674, 16.528,
      6.735, 13.56, 12.699, 13.562, 7.848, 10.145, 4.939, 7.866,
      7.267, 16.231, 18.622, 14.213, 18.887, 17.741, 4.5
    ),
    income_tax = c(
      104.760235996058, 101.665237369244, 113.316015798358,
      103.222974550114, 90.0358198452755, 84.2840209478504,
      80.9471275560435, 95.9558367315368, 98.6546600681184,
      112.105602072378, 102.667541663828, 90.3520755939945,
      116.772956715611, 101.484483234044, 93.3312408112594,
      105.91143369624, 81.4327609978202, 96.9090350595422,
      108.112192891704, 93.6795560712191, 101.587464123904,
      108.80182909181, 86.071431257456, 87.5336162400315,
      86.9970564933148, 105.33998752386, 100.568081949912,
      94.8944733487053, 98.6353656794101, 83.7358939227726,
      92.3836627060381
    ),
    EU_28 = c(
      1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
      0, 0
    )
  )
})
# Scatter plot of income_tax against PCT_Direct_Tax with one labelled point
# per country and dashed reference lines at the median of each variable.
# The layers must be joined with `+`; without it every line is evaluated as a
# separate expression and no combined chart is produced.
ggplot(df1, aes(x = PCT_Direct_Tax, y = income_tax)) +
  # `size` is a constant, so it is set outside aes() (no spurious legend).
  geom_point(size = 5, shape = 21, fill = "lightblue") +
  # Country names, shifted slightly below their points.
  geom_text(aes(label = Country), nudge_y = -1, check_overlap = TRUE) +
  geom_hline(
    yintercept = median(df1$income_tax),
    linetype = "dashed", color = "black"
  ) +
  geom_vline(
    xintercept = median(df1$PCT_Direct_Tax),
    linetype = "dashed", color = "black"
  ) +
  # annotate() draws each "Median" label once; geom_text() with the inherited
  # data would draw it once per row on top of itself.
  annotate(
    "text",
    x = median(df1$PCT_Direct_Tax), y = max(df1$income_tax),
    label = "Median", hjust = 1, size = 4
  ) +
  annotate(
    "text",
    x = max(df1$PCT_Direct_Tax), y = median(df1$income_tax),
    label = "Median", hjust = 1, size = 4
  ) +
  theme_minimal()
Now I want to start working with Python and make the same chart with Matplotlib. I have already converted the data into Python syntax, but I am not a very proficient Matplotlib user.
# Chart data converted from the R data frame `df1`. The three lists are
# parallel: index i of each list describes the same row.
data = {
    # Countries in the row order of `df1`, i.e. the R factor codes decoded
    # through the factor's level labels. Listing the (alphabetical) level
    # labels themselves would pair every name with another country's values.
    'Country': [
        "Belgium", "Bulgaria", "Czechia", "Denmark", "Germany", "Estonia",
        "Ireland", "Greece", "Spain", "France", "Croatia", "Italy",
        "Cyprus", "Latvia", "Lithuania", "Luxembourg", "Hungary", "Malta",
        "Netherlands", "Austria", "Poland", "Portugal", "Romania",
        "Slovenia", "Slovakia", "Finland", "Sweden", "United Kingdom",
        "Iceland", "Norway", "North Macedonia",
    ],
    'PCT_Direct_Tax': [
        17.663, 6.048, 8.038, 29.092, 13.517,
        7.409, 10.829, 10.155, 11.05, 13.904, 6.477, 14.184, 9.109,
        7.41, 5.674, 16.528, 6.735, 13.56, 12.699, 13.562, 7.848,
        10.145, 4.939, 7.866, 7.267, 16.231, 18.622, 14.213, 18.887,
        17.741, 4.5,
    ],
    'income_tax': [
        104.760235996058, 101.665237369244,
        113.316015798358, 103.222974550114, 90.0358198452755, 84.2840209478504,
        80.9471275560435, 95.9558367315368, 98.6546600681184, 112.105602072378,
        102.667541663828, 90.3520755939945, 116.772956715611, 101.484483234044,
        93.3312408112594, 105.91143369624, 81.4327609978202, 96.9090350595422,
        108.112192891704, 93.6795560712191, 101.587464123904, 108.80182909181,
        86.071431257456, 87.5336162400315, 86.9970564933148, 105.33998752386,
        100.568081949912, 94.8944733487053, 98.6353656794101, 83.7358939227726,
        92.3836627060381,
    ],
}
So can anybody help me with how to make this chart in Matplotlib?
CodePudding user response:
This is what I got using the data you provided; it is not exactly the same as your chart:
# Importing libraries
import matplotlib.pyplot as plt
import numpy as np

# Scatter plot of income_tax against PCT_Direct_Tax
plt.figure(figsize=(15, 10))
plt.scatter(data["PCT_Direct_Tax"], data["income_tax"])

# Reference lines at the median of each variable
plt.axvline(np.median(data["PCT_Direct_Tax"]), c='b', label='Median PCT_Direct_Tax')
plt.axhline(np.median(data["income_tax"]), c='r', label='Median income_tax')

# Annotate every point with its country name, placed slightly above the
# marker (+0.2 in data units) so the text does not sit on top of it.
for country, x, y in zip(data["Country"], data["PCT_Direct_Tax"], data["income_tax"]):
    plt.annotate(country, (x, y + 0.2))

plt.xlabel("PCT_Direct_Tax")
plt.ylabel("income_tax")
plt.legend()
plt.show()