How to select a specific number of digits from MNIST in R?

Time:04-30

I need to select 100 examples of each digit from the MNIST database. I've tried the following code, but instead of giving me 100 examples of the digit 0, for example, it gives me a very long integer vector.

library(keras)
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y

d <- x_train[sample(y_train == 0, 100, replace = FALSE)]

Can someone help me, please?

CodePudding user response:

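You've run into a classic R gotcha. sample(y_train == 0, 100) draws 100 TRUE/FALSE values from the logical vector, not 100 positions of zeros. When a logical index is shorter than the object it indexes, R recycles it to the object's full length. x_train has 60000 x 28 x 28 = 47,040,000 elements, so your length-100 index is repeated 470,400 times across the whole array, and the result is a very long vector of pixel values rather than 100 images.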

set.seed(42)
library(keras)
mnist <- dataset_mnist()
#> Loaded Tensorflow version 2.8.0
x_train <- mnist$train$x
y_train <- mnist$train$y

## create a logical index for x_train, length 100
ind_logical <- sample(y_train == 0, 100, replace = FALSE)

## subset x_train with this length 100 index
d <- x_train[ind_logical]

length(ind_logical)
#> [1] 100
## good so far...

length(d)
#> [1] 2352000
## wtf?


sum(ind_logical)
#> [1] 5
## this is how many elements are TRUE, and should be extracted, wtf???

d_rep <- x_train[rep(ind_logical, length.out = length(x_train))]
## this is what R is doing internally, to be "helpful"

identical(d_rep, d)
#> [1] TRUE
## it's recycled the logical vector to be the same length as the input array
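To see both behaviors in isolation, here is a minimal sketch on hypothetical toy data (y and x below are made-up values, not MNIST): sampling from a logical vector returns logical values, a short logical index is recycled, and which() turns the condition into positions.

```r
# Hypothetical toy data: six labels and one value per "image"
y <- c(0, 3, 0, 7, 1, 0)
x <- seq_along(y) * 10   # 10 20 30 40 50 60

# sample() on a logical vector draws TRUE/FALSE values, not positions
set.seed(1)
flags <- sample(y == 0, 3)
is.logical(flags)
#> [1] TRUE

# A logical index shorter than x is recycled to length(x):
# c(TRUE, FALSE) acts like TRUE, FALSE, TRUE, FALSE, TRUE, FALSE
x[c(TRUE, FALSE)]
#> [1] 10 30 50

# which() turns the condition into positions, which are never recycled
which(y == 0)
#> [1] 1 3 6
```

Sampling from which(y == 0) therefore picks positions of zeros, which is what you actually want.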

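A safer approach is to turn the condition into positions with which(), sample from those positions, and index only the first dimension of the array, so that each selected image keeps all of its 28 x 28 pixels:

ind_number <- sample(which(y_train == 0), 100, replace = FALSE)

## select whole images: rows ind_number, all pixel rows, all pixel columns
d <- x_train[ind_number, , ]
dim(d)
#> [1] 100  28  28
## note: x_train[ind_number] alone returns 100 single pixels
## (the top-left pixel of each image), not 100 images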
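Since the question asks for 100 examples of each digit rather than only zeros, here is a minimal sketch that repeats the which() + sample() step for all ten labels. sample_per_digit() is a hypothetical helper, not a keras function, and the data are simulated (1500 random 28 x 28 images, 150 per digit) so the example runs without downloading MNIST; with keras you would pass mnist$train$x and mnist$train$y instead.

```r
# sample_per_digit() is a hypothetical helper, not part of keras.
# It draws n images per label using row positions, then slices only
# the first dimension so each image keeps its full 28 x 28 pixels.
sample_per_digit <- function(x, y, n = 100, digits = 0:9) {
  idx <- unlist(lapply(digits, function(d) {
    sample(which(y == d), n)   # n positions of label d, without replacement
  }))
  list(x = x[idx, , , drop = FALSE], y = y[idx])
}

set.seed(42)
# Simulated stand-in for mnist$train (assumption: 150 images per digit)
y_sim <- sample(rep(0:9, each = 150))
x_sim <- array(sample(0:255, length(y_sim) * 28 * 28, replace = TRUE),
               dim = c(length(y_sim), 28, 28))

s <- sample_per_digit(x_sim, y_sim, n = 100)
dim(s$x)
#> [1] 1000   28   28
table(s$y)   # 100 of each digit

# With keras: s <- sample_per_digit(mnist$train$x, mnist$train$y, n = 100)
```

One caveat of this design: sample() treats a single number k as 1:k, so the helper assumes every digit has more than one example, which holds for MNIST.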

CodePudding user response:

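You can keep only the rows that match your condition with filter() and then use slice_sample() to draw 100 of them at random. Because as.data.frame() turns each 28 x 28 image into one row of 784 pixel columns, the result d is a 100 x 784 data frame rather than an array: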

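library(dplyr)
d <- x_train %>% 
  as.data.frame() %>%        # 60000 x 784: one row per image
  filter(y_train == 0) %>%   # keep rows labelled 0 (y_train comes from the workspace)
  slice_sample(n = 100)      # draw 100 of those rows at random, without replacement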
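If you later need the images back as matrices, a row of that data frame can be refolded with matrix(), because as.data.frame() flattens the pixel dimensions in column-major order, the same order matrix() fills in. A minimal sketch on a hypothetical 6 x 3 x 3 array of made-up values:

```r
# Hypothetical small array: 6 "images" of 3 x 3 pixels
x <- array(1:54, dim = c(6, 3, 3))

# as.data.frame() on a 3-d array keeps the first dimension as rows and
# flattens the remaining dimensions (column-major) into columns
df <- as.data.frame(x)
dim(df)
#> [1] 6 9

# Refolding one row with matrix() recovers the original 3 x 3 image
img <- matrix(unlist(df[2, ]), nrow = 3, ncol = 3)
identical(img, x[2, , ])
#> [1] TRUE
```

For the MNIST result d, the same idea would be matrix(unlist(d[1, ]), nrow = 28, ncol = 28).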

