Stratified data splitting in R

Time:11-25

I've been using caret::createDataPartition() to split the data in a stratified way. Now I'm trying another approach that I found here on Stack Overflow, splitstackshape::stratified(). I'm interested in it because it lets me stratify on features that I choose manually, which is very handy.

I have a problem with splitting the data:

library(splitstackshape)

set.seed(40)
Train = stratified(Data, c('age','gender','treatment_1','treatment_2','cancers'), 0.75)

This produces the training set, but how do I get the test set? I tried the createDataPartition way:

INDEX = stratified(Data, c('age','gender','treatment_1','treatment_2','cancers'), 0.75)
Train = Data[INDEX , ]
Test = Data[-INDEX ,]

But that doesn't work, because stratified() returns the actual training data, not an index.

So how do I get the test data using this function? thanks!

CodePudding user response:

If you add a unique sequential row identifier to the data, you can use it to extract the rows that were not selected for the training data frame as follows. We'll use mtcars for a reproducible example.

library(splitstackshape)
set.seed(19108379) # for reproducibility

# add a unique sequential ID to track rows in the sample, using mtcars

mtcars$rowId <- 1:nrow(mtcars)

# take a stratified sample by cyl

train <- stratified(mtcars,"cyl",size = 0.6)

test <- mtcars[!(mtcars$rowId %in% train$rowId),]

nrow(train) + nrow(test) # should add to 32

...and the output:

> nrow(train) + nrow(test) # should add to 32
[1] 32

Next level of detail...

The stratified() function extracts a set of rows based on the by groups passed to the function. By adding a rowId field we can track the observations that are included in the training data.

> # list the rows included in the sample
> train$rowId
 [1]  6 11 10  4  3 27 18  8  9 21 28 23 17 16 29 22 15  7 14
> nrow(train)
[1] 19

We then build the test data frame with the extract operator, using ! to negate the %in% match so that only rows absent from the training data are kept:

> # illustrate the selection criteria used to extract rows not in the training data
> !(mtcars$rowId %in% train$rowId)
 [1]  TRUE  TRUE FALSE FALSE  TRUE FALSE FALSE FALSE FALSE FALSE FALSE  TRUE  TRUE FALSE
[15] FALSE FALSE FALSE FALSE  TRUE  TRUE FALSE FALSE FALSE  TRUE  TRUE  TRUE FALSE FALSE
[29] FALSE  TRUE  TRUE  TRUE
> 

Finally we count the number of rows to be included in the test data frame, given the selection criteria, which should equal 32 - 19 or 13:

> # count rows to be included in test data frame 
> sum(!(mtcars$rowId %in% train$rowId)) # should add to 13
[1] 13
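The same row identifier approach carries over to the case in the question, where several columns define the strata. The following is a minimal sketch under stated assumptions: mtcars stands in for the asker's Data, cyl and am stand in for the stratification columns, and the names cars, train_multi and test_multi are made up for this illustration.

```r
library(splitstackshape)

set.seed(40) # for reproducibility

# work on a copy so the original mtcars is left untouched (illustrative name)
cars <- mtcars
cars$rowId <- seq_len(nrow(cars))

# stratify on the combination of two columns, as in the question
train_multi <- stratified(cars, c("cyl", "am"), size = 0.6)

# every row whose rowId was not sampled goes to the test set
test_multi <- cars[!(cars$rowId %in% train_multi$rowId), ]

# sanity checks: the two sets cover all rows and share none
nrow(train_multi) + nrow(test_multi) == nrow(cars)
length(intersect(train_multi$rowId, test_multi$rowId)) == 0
```

Both checks should return TRUE. For the data in the question, the column vector c('age','gender','treatment_1','treatment_2','cancers') would take the place of c("cyl", "am").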

Comparison to bothSets argument

Another answer noted that the stratified() function includes an argument, bothSets, that generates a list with both the sampled data and the remaining data. We can demonstrate equivalence of the two approaches as follows.

# alternative answer: use the package's bothSets argument
set.seed(19108379)
sampleData <- stratified(mtcars,"cyl",size = 0.6,bothSets = TRUE)

# compare rowIds in test vs. SAMP2 data frames
sampleData$SAMP2$rowId
test$rowId

...and the output:

> sampleData$SAMP2$rowId
 [1]  1  2  5 12 13 19 20 24 25 26 30 31 32
> test$rowId
 [1]  1  2  5 12 13 19 20 24 25 26 30 31 32
> 

A Final Comment

It's important to note that caret::createDataPartition() splits the data according to values of the dependent variable, so the training and test partitions have approximately the same distribution of the dependent variable.

In contrast, stratified() partitions according to combinations of one or more features, i.e. the independent variables. Partitioning based on independent variables has the potential to introduce variability in the distributions of values of the dependent variable across the training and test partitions. That is, the distribution of dependent variable values in the training partition may be significantly different from the dependent variable distribution in the test partition.
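One way to see whether this matters for a particular split is to compare the dependent variable's distribution in the two partitions. The sketch below uses base R only and is an illustration of the idea, not the internals of stratified(): the helper split_by_group is hypothetical, cyl is the stratification feature, and am is treated as a stand-in dependent variable.

```r
# Hypothetical helper: sample a fraction of the row indices within each group.
# Base R illustration only; this is not how stratified() is implemented.
split_by_group <- function(data, group_col, frac) {
  idx_by_group <- split(seq_len(nrow(data)), data[[group_col]])
  train_idx <- unlist(lapply(idx_by_group, function(idx) {
    n <- round(length(idx) * frac)
    idx[sample.int(length(idx), n)]
  }), use.names = FALSE)
  test_idx <- setdiff(seq_len(nrow(data)), train_idx)
  list(train = data[train_idx, , drop = FALSE],
       test  = data[test_idx, , drop = FALSE])
}

set.seed(1)
parts <- split_by_group(mtcars, "cyl", 0.6)

# proportions of the assumed dependent variable (am) in each partition
prop.table(table(parts$train$am))
prop.table(table(parts$test$am))
```

If the two sets of proportions differ by more than you are comfortable with, stratifying on the dependent variable, or including it among the stratification columns, may be the safer choice.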

CodePudding user response:

If you want to stay within the package, just add the bothSets argument:

library(splitstackshape)
stratified(mtcars, "am", size = 0.75, bothSets = TRUE)

which returns a list with both samples: the sampled rows and the remaining rows.
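As a short sketch of how that list might be used, the two elements can be pulled out by name. The element name SAMP2 appears in the other answer's output; SAMP1 for the sampled rows is taken from the package's naming of the returned list, and the variable names both, train_am and test_am are made up here.

```r
library(splitstackshape)

set.seed(40) # for reproducibility
both <- stratified(mtcars, "am", size = 0.75, bothSets = TRUE)

train_am <- both$SAMP1 # the stratified sample
test_am  <- both$SAMP2 # the rows that were not sampled

# the two pieces should account for every row of mtcars
nrow(train_am) + nrow(test_am) == nrow(mtcars)
```

This avoids adding a row identifier by hand, at the cost of depending on the list element names.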
