I have a custom parallel training function that takes pairs of train and test data and builds a separate model for each pair. The problem is that an array (a Python list) does not seem to be able to store the following kind of data. How do I create a list that can hold it?
import tensorflow as tf

td, vd = [None] * 5, [None] * 5  # lists that will hold the five train/test datasets

def create_dataset():
    ...
    ...
    train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_data = train_data.cache().shuffle(buffer_size).batch(batch_size).repeat()
    test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    test_data = test_data.batch(batch_size).repeat()
    return train_data, test_data

for i in range(0, 5):
    td[i], vd[i] = create_dataset()

model = create_model()  # create the model
datasets = [(td[0], vd[0]), (td[1], vd[1]), (td[2], vd[2]), (td[3], vd[3]), (td[4], vd[4])]
parallel_trainer(model, datasets)
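As a side note, the five (train, test) pairs can be collected without preallocating td and vd or writing out every index by hand. The sketch below is a plain-Python illustration only: make_fake_dataset is a hypothetical stand-in for create_dataset that returns labeled strings instead of tf.data.Dataset objects, so the pairing logic runs without TensorFlow.

```python
from typing import List, Tuple


def make_fake_dataset(i: int) -> Tuple[str, str]:
    """Hypothetical stand-in for create_dataset(); returns placeholder
    names instead of real tf.data.Dataset objects."""
    return f"train_{i}", f"test_{i}"


# Build the list of (train, test) pairs directly, one call per model.
datasets: List[Tuple[str, str]] = [make_fake_dataset(i) for i in range(5)]

# Equivalent approach when train and test sets already live in two lists:
td = [f"train_{i}" for i in range(5)]
vd = [f"test_{i}" for i in range(5)]
paired = list(zip(td, vd))  # [(td[0], vd[0]), (td[1], vd[1]), ...]
```

A Python list can hold arbitrary objects, including tf.data.Dataset instances, so the list itself is not what raises the TypeError.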
The parameters of parallel_trainer are defined like this:
def parallel_trainer(model, XY_train_datasets: list[tuple]):
    ...
Defining my datasets like this returns an error:
TypeError: 'type' object is not subscriptable
How do I create a list of my train and test data so that this error is resolved? The solution may be obvious, but I am fairly new to this.
Thanks in advance.
Answer:
Subscripting the builtin list type in an annotation, as in list[tuple], requires Python 3.9 or newer (PEP 585).
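The version dependence can be checked directly. This small sketch evaluates list[tuple] the same way Python evaluates a parameter annotation when the def statement runs; the helper name try_subscript_list is made up for illustration.

```python
import sys


def try_subscript_list():
    """Evaluate list[tuple] and return either the result or the TypeError."""
    try:
        return list[tuple]  # a generic alias on Python 3.9+ (PEP 585)
    except TypeError as exc:  # Python 3.8 and older: 'type' object is not subscriptable
        return exc


result = try_subscript_list()
print(sys.version_info[:2], repr(result))
```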
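On older versions you can fix this by using the List hint from the typing module instead (from typing import List) and writing the annotation as List[tuple].
Python evaluates annotations when the def statement runs, so before 3.9 it treats list[tuple] as an attempt to index the list type itself rather than as a description of the list's contents, which raises TypeError: 'type' object is not subscriptable.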
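A minimal sketch of the fix, keeping the parameter name from the question. Only the annotation changes; the function body here is a hypothetical placeholder that just counts the pairs, since the real training logic isn't shown. Tuple and Any are also imported from typing so the tuple's contents can be spelled out.

```python
from typing import Any, List, Tuple


def parallel_trainer(model: Any, XY_train_datasets: List[Tuple[Any, Any]]) -> int:
    """Placeholder body: a real trainer would fit one model per (train, test) pair."""
    for train_data, test_data in XY_train_datasets:
        # A real implementation would train on train_data
        # and validate on test_data here.
        pass
    return len(XY_train_datasets)


n_pairs = parallel_trainer(None, [("train_0", "test_0"), ("train_1", "test_1")])
```

typing.List works on every Python 3 version that ships the typing module (3.5+), and from 3.9 the builtin list[tuple] spelling works as well. On 3.7 and 3.8, another option is to put from __future__ import annotations at the very top of the module, which stores annotations as strings instead of evaluating them when the function is defined.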