I would like to replicate the plot produced by the plot.lda() method using ggplot2 and tidymodels. Is there an elegant way to get the plot?
I think I can fake the augment() function, which does not have an lda method, by using predict() and binding the result onto the original data.
Here is an example with both the base R and the tidymodels code:
library(ISLR2)
library(MASS)
# First base R
train <- Smarket$Year < 2005
lda.fit <-
lda(
Direction ~ Lag1 + Lag2,
data = Smarket,
subset = train
)
plot(lda.fit)
# Next tidymodels
library(tidyverse)
library(tidymodels)
library(discrim)
lda_spec <- discrim_linear() %>%
set_mode("classification") %>%
set_engine("MASS")
the_rec <- recipe(
Direction ~ Lag1 + Lag2,
data = Smarket
)
the_workflow <- workflow() %>%
add_recipe(the_rec) %>%
add_model(lda_spec)
Smarket_train <- Smarket %>%
filter(Year != 2005)
the_workflow_fit_lda_fit <-
fit(the_workflow, data = Smarket_train) %>%
extract_fit_parsnip()
# now my attempt to do the plot
# type = "raw" returns predict.lda()'s list; its third element, x, holds the LD1 scores
predictions <- predict(the_workflow_fit_lda_fit,
new_data = Smarket_train,
type = "raw"
)[[3]] %>%
as.vector()
bind_cols(Smarket_train, .fitted = predictions) %>%
ggplot(aes(x = .fitted)) +
geom_histogram(aes(y = after_stat(density)), binwidth = .5) +
scale_x_continuous(breaks = seq(-4, 4, by = 2)) +
facet_grid(vars(Direction)) +
xlab("") +
ylab("Density")
There must be a better way to do this. Any thoughts?
Answer:
You can do this by using a combination of extract_fit_*() and parsnip::repair_call(). The plot.lda() method uses the $call object in the LDA fit, which we need to adjust because the call object created by tidymodels is different from the one created by calling lda() directly.
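A minimal base-R sketch of why the call matters, assuming only MASS and a hypothetical toy data frame (toy) standing in for Smarket. For a formula fit, plot.lda() rebuilds its input with model.frame(), and for an lda object that re-evaluates the stored call, so whatever data the call names has to be findable at plot time.

```r
library(MASS)

set.seed(42)
# Hypothetical toy data standing in for Smarket
toy <- data.frame(
  Direction = factor(rep(c("Down", "Up"), each = 50)),
  Lag1 = c(rnorm(50, mean = -0.5), rnorm(50, mean = 0.5)),
  Lag2 = rnorm(100)
)

fit <- lda(Direction ~ Lag1 + Lag2, data = toy)
print(fit$call)  # the call records the name `toy`, not the data itself

# model.frame() on an lda fit re-evaluates that call, so this works ...
mf_ok <- model.frame(fit)

# ... but once `toy` can no longer be found, rebuilding the data fails,
# and plot(fit) would fail for the same reason.
rm(toy)
mf_err <- tryCatch(model.frame(fit), error = function(e) e)
```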
library(ISLR2)
library(MASS)
# First base R
train <- Smarket$Year < 2005
lda.fit <-
lda(
Direction ~ Lag1 + Lag2,
data = Smarket,
subset = train
)
# Next tidymodels
library(tidyverse)
library(tidymodels)
library(discrim)
lda_spec <- discrim_linear() %>%
set_mode("classification") %>%
set_engine("MASS")
the_rec <- recipe(
Direction ~ Lag1 + Lag2,
data = Smarket
)
the_workflow <- workflow() %>%
add_recipe(the_rec) %>%
add_model(lda_spec)
Smarket_train <- Smarket %>%
filter(Year != 2005)
the_workflow_fit_lda_fit <-
fit(the_workflow, data = Smarket_train)
After fitting both models, we can inspect their $call objects and see that they differ.
lda.fit$call
#> lda(formula = Direction ~ Lag1 + Lag2, data = Smarket, subset = train)
extract_fit_engine(the_workflow_fit_lda_fit)$call
#> lda(formula = ..y ~ ., data = data)
The parsnip::repair_call() function replaces the data argument of the stored call with the data we pass in. We also rename the response column to ..y so that the data matches the ..y ~ . formula in the call.
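The same repair idea can be sketched in base R, assuming only MASS and a hypothetical train_df whose response is already named ..y. The call is deliberately pointed at a missing object (data_that_does_not_exist, a made-up name) to mimic the broken tidymodels call, then pointed back at a data frame that exists. This only illustrates the idea; parsnip's repair_call() may implement it differently.

```r
library(MASS)

set.seed(42)
# Hypothetical training data; the response is already called ..y
train_df <- data.frame(
  ..y  = factor(rep(c("Down", "Up"), each = 50)),
  Lag1 = c(rnorm(50, mean = -0.5), rnorm(50, mean = 0.5)),
  Lag2 = rnorm(100)
)

fit <- lda(..y ~ ., data = train_df)

# Mimic a call whose data argument cannot be resolved at plot time
fit$call$data <- quote(data_that_does_not_exist)
broken <- tryCatch(model.frame(fit), error = function(e) e)

# "Repair" the call by pointing its data argument at an object that exists
fit$call$data <- quote(train_df)
fixed <- model.frame(fit)

plot(fit)  # two classes, so this draws LD1 histograms (via ldahist())
```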
the_workflow_fit_lda_fit %>%
extract_fit_parsnip() %>%
parsnip::repair_call(rename(Smarket_train, ..y = Direction)) %>%
extract_fit_engine() %>%
plot()
Created on 2021-11-12 by the reprex package (v2.0.1)
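For a ggplot2 version of the same figure, one option is to skip predict() and compute the LD1 scores the way plot.lda() does. As far as I can tell from the MASS source, plot.lda() centers the predictors on the unweighted mean of the class means (colMeans(fit$means)) and then multiplies by fit$scaling; treat that as an assumption worth checking against your MASS version. The sketch uses the base-R lda() fit with ISLR2 and ggplot2; the bin width of 0.5 is arbitrary and will not match ldahist()'s default binning exactly.

```r
library(ISLR2)
library(MASS)
library(ggplot2)

train <- Smarket$Year < 2005
lda.fit <- lda(Direction ~ Lag1 + Lag2, data = Smarket, subset = train)
Smarket_train <- Smarket[train, ]

# Project the training predictors onto LD1 as plot.lda() appears to do:
# center on the unweighted mean of the class means, then apply the scaling.
X <- as.matrix(Smarket_train[, c("Lag1", "Lag2")])
ld <- scale(X, center = colMeans(lda.fit$means), scale = FALSE) %*% lda.fit$scaling

plot_df <- data.frame(LD1 = ld[, 1], Direction = Smarket_train$Direction)

# One density-scaled histogram per class, stacked like plot.lda()'s output
p <- ggplot(plot_df, aes(x = LD1)) +
  geom_histogram(aes(y = after_stat(density)), binwidth = 0.5) +
  facet_grid(rows = vars(Direction)) +
  labs(x = "LD1", y = "Density")
print(p)
```

If I read MASS's predict.lda() correctly, its x component (the [[3]] element used in the question) centers on the prior-weighted mean of the class means instead, so those scores can be shifted by a small constant relative to the ones plot.lda() draws.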