How to split data into train, val and test?

Time:10-30

I am trying to split my cats and dogs dataset into train, validation and test sets with a 0.8, 0.1, 0.1 split. I've written this function to do it:

def split_data(main_dir, training_dir, validation_dir, test_dir):
    """
    Splits the data into train and test sets

    Args:
    main_dir (string):  path containing the images
    training_dir (string):  path to be used for training
    validation_dir (string):  path to be used for validation
    test_dir (string): path to be used for testing
    split_size (float): size of the dataset to be used for training
    """
    files = []
    for file in os.listdir(main_dir):
        if  os.path.getsize(os.path.join(main_dir, file)): # check if the file's size isn't 0
            files.append(file) # appends file name to a list

    shuffled_files = random.sample(files,  len(files)) # shuffles the data
    split = int(0.8 * len(shuffled_files)) #the training split casted into int for numeric rounding
    test_split = int(0.9 * len(shuffled_files))#the test split
    
    train = shuffled_files[:split] #training split
    validation = shuffled_files[split:test_split] # validation split
    test = shuffled_files[test_split:]
    
    for element in train:
        copyfile(os.path.join(main_dir,  element), os.path.join(training_dir, element))

    for element in validation:
        copyfile(os.path.join(main_dir,  element), os.path.join(validation_dir, element))
        
    for element in test:
        copyfile(os.path.join(main_dir,  element), os.path.join(validation_dir, element))

Here's the function call:

split_data(CAT_DIR, '/tmp/cats-v-dogs/training/cats','/tmp/cats-v-dogs/validation/cats', '/tmp/cats-v-dogs/testing/cats')
split_data(DOG_DIR, '/tmp/cats-v-dogs/training/dogs', '/tmp/cats-v-dogs/validation/dogs', '/tmp/cats-v-dogs/testing/dogs')

And here I print the number of files in each directory:

print(len(os.listdir('/tmp/cats-v-dogs/training/cats')))
print(len(os.listdir('/tmp/cats-v-dogs/training/dogs')))

print(len(os.listdir('/tmp/cats-v-dogs/validation/cats')))
print(len(os.listdir('/tmp/cats-v-dogs/validation/dogs')))

print(len(os.listdir('/tmp/cats-v-dogs/testing/cats')))
print(len(os.listdir('/tmp/cats-v-dogs/testing/dogs')))

Which gives the output:

10000
10000
2500
2500
0
0

The function is splitting the data 0.8/0.2 into the train and validation sets, but I need a 0.8/0.1/0.1 split across train, validation and test sets, and I can't see where I'm going wrong.

CodePudding user response:

Why don't you use the train_test_split function in sklearn? It's fast and useful.
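
train_test_split only produces two subsets per call, so a three-way split usually means calling it twice. Here is a minimal sketch of that idea, assuming scikit-learn is installed. The helper name three_way_split, the seed value and the sample file names are made up for illustration and are not from the question.

```python
from sklearn.model_selection import train_test_split


def three_way_split(files, seed=0):
    """Split a list of file names 0.8 / 0.1 / 0.1 (hypothetical helper)."""
    # First call: keep 80% for training and hold out the remaining 20%.
    train, holdout = train_test_split(files, test_size=0.2, random_state=seed)
    # Second call: cut the held-out 20% in half, giving 10% validation and 10% test.
    validation, test = train_test_split(holdout, test_size=0.5, random_state=seed)
    return train, validation, test


# Made-up file names standing in for os.listdir(CAT_DIR).
files = [f"cat.{i}.jpg" for i in range(100)]
train, validation, test = three_way_split(files)
print(len(train), len(validation), len(test))
```

This only decides which file goes where. You would still have to copy each list into its own directory, as the original function does.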

CodePudding user response:

You are going wrong in the lines:

for element in test:
        copyfile(os.path.join(main_dir,  element), os.path.join(validation_dir, element))

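The destination should be test_dir, not validation_dir. As written, the test files are copied into the validation directory, so it ends up with 20% of the data while the testing directory stays empty.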
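
For reference, here is one way the whole function could look with that fix applied. It is a sketch, not the only correct version. The train_frac and val_frac parameters, the os.path.isfile check, the os.makedirs call and the returned counts are my own additions, and the temporary directory at the bottom is just throwaway demo data.

```python
import os
import random
import tempfile
from shutil import copyfile


def split_data(main_dir, training_dir, validation_dir, test_dir,
               train_frac=0.8, val_frac=0.1):
    """Copy the non-empty files in main_dir into train, validation and test dirs."""
    files = []
    for name in os.listdir(main_dir):
        path = os.path.join(main_dir, name)
        # Keep regular files whose size isn't 0 (the isfile check is an addition).
        if os.path.isfile(path) and os.path.getsize(path) > 0:
            files.append(name)

    random.shuffle(files)
    train_end = int(train_frac * len(files))
    val_end = train_end + int(val_frac * len(files))

    # Pair every destination with its own slice so test files go to test_dir.
    subsets = [
        (training_dir, files[:train_end]),
        (validation_dir, files[train_end:val_end]),
        (test_dir, files[val_end:]),
    ]
    for target_dir, names in subsets:
        os.makedirs(target_dir, exist_ok=True)  # addition: create missing dirs
        for name in names:
            copyfile(os.path.join(main_dir, name), os.path.join(target_dir, name))
    return tuple(len(names) for _, names in subsets)


# Demo on throwaway data: 20 non-empty files plus one empty file that is skipped.
with tempfile.TemporaryDirectory() as root:
    src = os.path.join(root, "cats")
    os.makedirs(src)
    for i in range(20):
        with open(os.path.join(src, f"cat.{i}.jpg"), "wb") as fh:
            fh.write(b"x")
    open(os.path.join(src, "empty.jpg"), "wb").close()

    counts = split_data(src,
                        os.path.join(root, "training"),
                        os.path.join(root, "validation"),
                        os.path.join(root, "testing"))
    print(counts)
```

Looping over (directory, slice) pairs instead of writing three separate copy loops makes the copy-paste mistake from the question harder to repeat.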
