How can I replicate plot.lda() with R `tidymodels`?

Time:11-13

I would like to replicate the plot produced by the plot.lda() method using ggplot2 and tidymodels. Is there an elegant way to get this plot?

I think I can fake the augment() function, which does not have an lda method, by calling predict() and binding the result onto the original data.

Here is an example with the base R and tidymodels code:

library(ISLR2)
library(MASS)

# First base R
train <- Smarket$Year < 2005

lda.fit <-
  lda(
    Direction ~ Lag1 + Lag2,
    data = Smarket,
    subset = train
  )

plot(lda.fit)


# Next tidymodels

library(tidyverse)
library(tidymodels)
library(discrim)

lda_spec <- discrim_linear() %>%
  set_mode("classification") %>%
  set_engine("MASS")

the_rec <- recipe(
  Direction ~ Lag1 + Lag2,
  data = Smarket
)

the_workflow <- workflow() %>% 
  add_recipe(the_rec) %>% 
  add_model(lda_spec)

Smarket_train <- Smarket %>%
  filter(Year != 2005)

the_workflow_fit_lda_fit <- 
  fit(the_workflow, data = Smarket_train) %>% 
  extract_fit_parsnip()

# now my attempt to do the plot

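# type = "raw" returns the output of MASS's predict.lda(); its third
# element (x) holds the LD1 scores
predictions <- predict(the_workflow_fit_lda_fit,
                       new_data = Smarket_train,
                       type = "raw")[[3]] %>%
  as.vector()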

bind_cols(Smarket_train, .fitted = predictions) %>%
  ggplot(aes(x = .fitted)) +
  geom_histogram(aes(y = after_stat(density)), binwidth = .5) +
  scale_x_continuous(breaks = seq(-4, 4, by = 2)) +
  facet_grid(vars(Direction)) +
  xlab("") +
  ylab("Density")

There must be a better way to do this. Any thoughts?

Answer:

You can do this by combining the extract_fit_*() functions with parsnip::repair_call(). The plot.lda() method relies on the $call object stored in the LDA fit, so we need to adjust that call: the one produced through tidymodels differs from the one you get by calling lda() directly.
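To see why the stored call matters, here is a minimal sketch using only the base R fit (assuming ISLR2 and MASS are installed). For formula fits, MASS's plot.lda() rebuilds the training data with model.frame(), and MASS's model.frame() method for lda objects does this by re-evaluating the $call stored in the fit, including its data and subset arguments.

```r
library(ISLR2)  # provides the Smarket data
library(MASS)   # provides lda() and its model.frame() method

train <- Smarket$Year < 2005

lda.fit <- lda(Direction ~ Lag1 + Lag2, data = Smarket, subset = train)

# The fit keeps the original call, including `data` and `subset`
lda.fit$call

# model.frame() re-evaluates that call to recover the training rows;
# plot.lda() goes through this same step for formula-based fits
mf <- model.frame(lda.fit)
dim(mf)
```

If the call points at an object that cannot be found when it is re-evaluated, this step fails, which is why the tidymodels fit needs its call repaired before plotting.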

library(ISLR2)
library(MASS)

# First base R
train <- Smarket$Year < 2005

lda.fit <-
  lda(
    Direction ~ Lag1 + Lag2,
    data = Smarket,
    subset = train
  )

# Next tidymodels

library(tidyverse)
library(tidymodels)
library(discrim)

lda_spec <- discrim_linear() %>%
  set_mode("classification") %>%
  set_engine("MASS")

the_rec <- recipe(
  Direction ~ Lag1 + Lag2,
  data = Smarket
)

the_workflow <- workflow() %>% 
  add_recipe(the_rec) %>% 
  add_model(lda_spec)

Smarket_train <- Smarket %>%
  filter(Year != 2005)

the_workflow_fit_lda_fit <- 
  fit(the_workflow, data = Smarket_train)

After fitting both models, we can inspect their $call objects and see that they differ.

lda.fit$call
#> lda(formula = Direction ~ Lag1 + Lag2, data = Smarket, subset = train)

extract_fit_engine(the_workflow_fit_lda_fit)$call
#> lda(formula = ..y ~ ., data = data)

parsnip::repair_call() replaces the data argument of the stored call with the data we pass in. We also rename the response column (Direction) to ..y so that the data matches the ..y ~ . formula in the call.

the_workflow_fit_lda_fit %>%
  extract_fit_parsnip() %>%
  parsnip::repair_call(rename(Smarket_train, ..y = Direction)) %>%
  extract_fit_engine() %>%
  plot()
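To check what repair_call() actually changed, a minimal sketch (assuming ISLR2, MASS, tidymodels and discrim are installed) rebuilds the same workflow in compact form under the hypothetical name lda_wf_fit, prints the engine call before and after the repair, and then calls model.frame() on the repaired fit, which is the step plot.lda() uses to recover the training data.

```r
library(ISLR2)       # Smarket data
library(MASS)        # lda() and its model.frame() method
library(tidymodels)  # attaches dplyr, recipes, parsnip, workflows
library(discrim)     # discrim_linear()

Smarket_train <- Smarket %>% filter(Year != 2005)

# Same model as above, written compactly (lda_wf_fit is a hypothetical name)
lda_wf_fit <- workflow() %>%
  add_recipe(recipe(Direction ~ Lag1 + Lag2, data = Smarket)) %>%
  add_model(discrim_linear() %>% set_mode("classification") %>% set_engine("MASS")) %>%
  fit(data = Smarket_train)

# Before the repair: `data` refers to an object internal to parsnip
extract_fit_engine(lda_wf_fit)$call

# After the repair: `data` is the renamed training set supplied here
repaired <- lda_wf_fit %>%
  extract_fit_parsnip() %>%
  parsnip::repair_call(rename(Smarket_train, ..y = Direction)) %>%
  extract_fit_engine()
repaired$call

# Re-evaluating the repaired call now recovers the training rows
mf <- model.frame(repaired)
```

Because the stored formula is ..y ~ ., the rebuilt model frame contains every column of Smarket_train, but plot.lda() only uses the predictors recorded in the fit's terms (Lag1 and Lag2).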

Created on 2021-11-12 by the reprex package (v2.0.1)
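Since the question asked specifically for ggplot2, here is a minimal sketch of a ggplot2 counterpart, assuming the same packages as above. It predicts directly from the fitted workflow with type = "raw", which (as in the question's code) returns MASS's predict.lda() output, and draws one density-scaled histogram of the LD1 scores per class. The names lda_wf_fit, raw_pred, ld_scores and ld_plot are hypothetical. The histograms may be shifted slightly relative to plot.lda(), because predict.lda() and plot.lda() can center the scores differently.

```r
library(ISLR2)       # Smarket data
library(MASS)        # lda() engine
library(tidymodels)  # attaches ggplot2, dplyr, recipes, parsnip, workflows
library(discrim)     # discrim_linear()

Smarket_train <- Smarket %>% filter(Year != 2005)

lda_wf_fit <- workflow() %>%
  add_recipe(recipe(Direction ~ Lag1 + Lag2, data = Smarket)) %>%
  add_model(discrim_linear() %>% set_mode("classification") %>% set_engine("MASS")) %>%
  fit(data = Smarket_train)

# type = "raw" passes through predict.lda()'s list; `x` holds the LD1 scores
raw_pred <- predict(lda_wf_fit, new_data = Smarket_train, type = "raw")
ld_scores <- Smarket_train %>%
  mutate(LD1 = as.vector(raw_pred$x))

# One density-scaled histogram of LD1 per class, similar to plot.lda() with two groups
ld_plot <- ggplot(ld_scores, aes(x = LD1)) +
  geom_histogram(aes(y = after_stat(density)), binwidth = 0.5) +
  facet_grid(rows = vars(Direction)) +
  labs(x = "LD1", y = "Density")
ld_plot
```

This avoids repairing the call altogether, at the cost of drawing the plot yourself instead of reusing plot.lda().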
