I want to use tfrecord to deal with heavy MRI images but I don't know how to. Below is my code, the error and data link. (Sorry if you find the code is a bit long).
About the data:
- 484 training images, each has a shape of (240, 240, 155, 4), these 4 numbers are the height, width, number of layers and channels respectively.
- 484 labels, each has a shape of (240, 240, 155)
First I rearrange my data,
image_data_path = './drive/MyDrive/Brain Tumour/Task01_BrainTumour/imagesTr/'
label_data_path = './drive/MyDrive/Brain Tumour/Task01_BrainTumour/labelsTr/'
image_paths = [image_data_path name
for name in os.listdir(image_data_path)
if not name.startswith(".")]
label_paths = [label_data_path name
for name in os.listdir(label_data_path)
if not name.startswith(".")]
image_paths = sorted(image_paths)
label_paths = sorted(label_paths)
And define a function to load 1 nii file. I use nibabel.
def load_one_sample(image_path, label_path):
image = nib.load(image_path).get_fdata()
label = nib.load(label_path).get_fdata().astype(int) # the original dtype is float64
return image, label
Here I write some helper functions, 'float' for images and 'int' for labels:
def float_feature(value):
return tf.train.Feature(float_list = tf.train.FloatList(value = value))
def int64_feature(value):
return tf.train.Feature(int64_list = tf.train.Int64List(value = value))
def create_example(image_path, label_path):
image, label = load_one_sample(image_path, label_path)
image, label = image.ravel(), label.ravel()
feature = {'image': float_feature(image),
'label': int64_feature(label)}
example = tf.train.Example(features = tf.train.Features(feature = feature))
return example
def parse_tfrecord(example):
feature = {'image': tf.io.FixedLenFeature([240, 240, 155, 4], tf.float32),
'label': tf.io.FixedLenFeature([240, 240, 155], tf.int64)}
parsed_example = tf.io.parse_single_example(example, feature)
return parsed_example
Then start to convert to and read tfrecord with only one example:
test_writer = tf.io.TFRecordWriter('test.tfrecords')
example = create_example(image_paths[0], label_paths[0])
test_writer.write(example.SerializeToString())
serialised_example = tf.data.TFRecordDataset('test.tfrecords')
parsed_example = serialised_example.map(parse_tfrecord)
Finally I try plotting one image and I got this error message:
for features in parsed_example.take(1):
plt.imshow(features['image'][:, :, 100, 0])
Error: truncated record at 0' failed with Read less bytes than requested [Op:IteratorGetNext]