I generated some data to perform a regression on it:
library(tidyverse)
library(nnet)
# Generating the data --------------------------
set.seed(100)
helicopter <- rnorm(20, mean = 35, sd = 3)
car <- rnorm(20, mean = 30, sd = 3)
bus <- rnorm(20, mean = 25, sd = 3)
bike <- rnorm(20, mean = 20, sd = 3)
transportation_data <- data.frame(helicopter, car, bus, bike) %>%
pivot_longer(cols = 1:4, values_to = "income", names_to = "mode")
# Setting up the regression -------------------
transportation_regression <- multinom(mode~income, data = transportation_data)
So far, so good. I now want to plot the regression results (probability of choosing a certain mode of transportation based on income) using stat_function():
ins <- coef(transportation_regression)[1:3]
betas <- coef(transportation_regression)[4:6]
transportation_data %>%
ggplot(aes(x = income)) +
stat_function(fun = function(x) { 1 / (1 + sum(exp(ins + betas * x))) }, aes(color = "bike")) +
stat_function(fun = function(x) { exp(ins[1] + betas[1] * x) / (1 + sum(exp(ins + betas * x))) }, aes(color = "bus")) +
stat_function(fun = function(x) { exp(ins[2] + betas[2] * x) / (1 + sum(exp(ins + betas * x))) }, aes(color = "car")) +
stat_function(fun = function(x) { exp(ins[3] + betas[3] * x) / (1 + sum(exp(ins + betas * x))) }, aes(color = "helicopter"))
The resulting plot is obviously wrong, and I also get a warning whose meaning I don't understand:
Warning: longer object length is not a multiple of shorter object length
When I use the same functions, but predict data points first, everything works just fine:
income <- seq(0,50,0.1)
result <- matrix( , nrow = length(income), ncol = 4)
i <- 1
for(x in income){
result[i,1] <- 1 / (1 + sum(exp(ins + betas * x))) # bike
result[i,2] <- exp(ins[1] + betas[1] * x) / (1 + sum(exp(ins + betas * x))) # bus
result[i,3] <- exp(ins[2] + betas[2] * x) / (1 + sum(exp(ins + betas * x))) # car
result[i,4] <- exp(ins[3] + betas[3] * x) / (1 + sum(exp(ins + betas * x))) # helicopter
i <- i + 1
}
cbind(income, as.data.frame(result)) %>%
pivot_longer(cols = V1:V4) %>%
ggplot(aes(x = income, y = value, color = name)) +
geom_line()
Why doesn't the stat_function() version work in ggplot?
Answer:
I think it's just a misunderstanding of how the function works. Here's an example of using stat_function() to generate the right result:
library(tidyverse)
library(nnet)
# Generating the data --------------------------
set.seed(100)
helicopter <- rnorm(20, mean = 35, sd = 3)
car <- rnorm(20, mean = 30, sd = 3)
bus <- rnorm(20, mean = 25, sd = 3)
bike <- rnorm(20, mean = 20, sd = 3)
transportation_data <- data.frame(helicopter, car, bus, bike) %>%
pivot_longer(cols = 1:4, values_to = "income", names_to = "mode")
# Setting up the regression -------------------
transportation_regression <- multinom(mode~income, data = transportation_data)
#> # weights: 12 (6 variable)
#> initial value 110.903549
#> iter 10 value 48.674542
#> iter 20 value 46.980349
#> iter 30 value 46.766625
#> iter 40 value 46.734782
#> iter 50 value 46.732249
#> final value 46.732163
#> converged
ins <- coef(transportation_regression)[1:3]
betas <- coef(transportation_regression)[4:6]
transportation_data %>%
ggplot(aes(x = income)) +
stat_function(fun = function(x) { 1 / (1 + exp(ins[1] + betas[1] * x) + exp(ins[2] + betas[2] * x) + exp(ins[3] + betas[3] * x)) }, aes(color = "bike")) +
stat_function(fun = function(x) { exp(ins[1] + betas[1] * x) / (1 + exp(ins[1] + betas[1] * x) + exp(ins[2] + betas[2] * x) + exp(ins[3] + betas[3] * x)) }, aes(color = "bus")) +
stat_function(fun = function(x) { exp(ins[2] + betas[2] * x) / (1 + exp(ins[1] + betas[1] * x) + exp(ins[2] + betas[2] * x) + exp(ins[3] + betas[3] * x)) }, aes(color = "car")) +
stat_function(fun = function(x) { exp(ins[3] + betas[3] * x) / (1 + exp(ins[1] + betas[1] * x) + exp(ins[2] + betas[2] * x) + exp(ins[3] + betas[3] * x)) }, aes(color = "helicopter"))
There were a couple of problems originally. Take, for example, the first instance of stat_function():
stat_function(fun = function(x) {
1 / (1 + sum(exp(ins + betas * x))) },
aes(color = "bike"))
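You're expecting ins + betas * x to give you the three terms ins[1] + betas[1] * x, ins[2] + betas[2] * x and ins[3] + betas[3] * x for each value of x, but it doesn't. stat_function() calls your function once with a whole vector of x values (101 of them by default), so R recycles ins and betas to make them as long as x, then multiplies the recycled betas by x and adds the recycled ins element by element. Because 101 is not a multiple of 3, the recycling also triggers the "longer object length is not a multiple of shorter object length" warning. In your loop, x is a single number each time, which is why the same expression works there.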
The other problem was the sum() around exp(ins + betas * x). Rather than summing the three terms separately for each value of x, it sums every element of the output, making a single scalar value that is then used as the denominator at every x.
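To see both problems without ggplot, here is a minimal sketch in base R. The coefficient values are made up for illustration (they are not the fitted values from the model above), and the grid of 101 points mimics what stat_function() passes to fun by default.

```r
# Hypothetical intercepts and slopes (NOT the fitted values from the model above)
ins   <- c(-10, -20, -30)
betas <- c(0.5, 0.8, 1.0)

# stat_function() calls fun once with a whole grid of x values
x <- seq(0, 50, length.out = 101)

# Problem 1: ins and betas are recycled along x, so element i uses
# coefficient ((i - 1) %% 3) + 1. R warns because 101 is not a multiple of 3
# (the warning is suppressed here only to keep the output clean).
recycled <- suppressWarnings(ins + betas * x)
length(recycled)            # 101: one value per x, not three per x

# Problem 2: sum() collapses everything into a single number
length(sum(exp(recycled)))  # 1

# What was intended: one row per x, one column per non-reference mode
eta <- outer(x, betas) + matrix(ins, nrow = length(x), ncol = 3, byrow = TRUE)
denom <- 1 + rowSums(exp(eta))
length(denom)               # 101: one denominator per x
```

The last three lines build the denominator the way the explicit three-term version does, with one value per x.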
You could also make it a bit more general using matrix calculations:
b <- coef(transportation_regression)
transportation_data %>%
ggplot(aes(x = income))
stat_function(fun = function(x) { 1 / (1 + rowSums(exp(cbind(1, x) %*% t(b)))) }, aes(color = "bike")) +
stat_function(fun = function(x) { exp(ins[1] + betas[1] * x) / (1 + rowSums(exp(cbind(1, x) %*% t(b)))) }, aes(color = "bus")) +
stat_function(fun = function(x) { exp(ins[2] + betas[2] * x) / (1 + rowSums(exp(cbind(1, x) %*% t(b)))) }, aes(color = "car")) +
stat_function(fun = function(x) { exp(ins[3] + betas[3] * x) / (1 + rowSums(exp(cbind(1, x) %*% t(b)))) }, aes(color = "helicopter"))
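If you want to avoid repeating the formula four times, the matrix version can be wrapped in a small helper. This is only a sketch: mode_prob() is a hypothetical name, and the coefficient matrix below is made up so the example runs on its own. It assumes b has the shape that coef() returns for this model (one row per non-reference mode, columns for intercept and slope) and that "bike" is the reference category.

```r
# Hypothetical 3 x 2 coefficient matrix shaped like coef() of the multinom fit:
# rows are the non-reference modes, columns are intercept and slope
b <- matrix(c(-10, -20, -30,
               0.5, 0.8, 1.0),
            nrow = 3,
            dimnames = list(c("bus", "car", "helicopter"),
                            c("(Intercept)", "income")))

# mode_prob() is a hypothetical helper: probability of `mode` at each x,
# with `reference` as the baseline category (its linear predictor is 0)
mode_prob <- function(x, mode, b, reference = "bike") {
  eta <- cbind(1, x) %*% t(b)       # one row per x, one column per row of b
  denom <- 1 + rowSums(exp(eta))    # one denominator per x
  if (mode == reference) 1 / denom else exp(eta[, mode]) / denom
}

x <- seq(0, 50, by = 0.1)
modes <- c("bike", rownames(b))
probs <- sapply(modes, function(m) mode_prob(x, m, b))

# Sanity check: the four probabilities at each income should sum to 1
# (up to floating-point error)
range(rowSums(probs))
```

With the fitted model you would use b <- coef(transportation_regression) instead of the made-up matrix, and each layer could then be written as stat_function(fun = mode_prob, args = list(mode = "bus", b = b), aes(color = "bus")).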
Created on 2023-02-04 by the reprex package (v2.0.1)