You'll need this notebook to reproduce the error; it downloads the files below and runs the exact code described here.
- labels.csv: each row contains x0, y0, x1, y1 text coordinates, plus other columns that don't affect the outcome.
- yolo-train-0.tfrecord: contains 90% of the examples found in labels.csv; each example holds all the labels/rows that correspond to its image.
I'm experiencing a recurring error when iterating over a tfrecord dataset. After 2000-4000 iterations that successfully read batches from the dataset, I get the following error:
iteration: 3240 2022-02-14 04:25:15.376625: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at scatter_nd_op.cc:219 : INVALID_ARGUMENT: indices[189] = [6, 30, 38, 0] does not index into shape [8,38,38,3,6]
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/iterator_ops.py", line 800, in __next__
return self._next_internal()
File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/iterator_ops.py", line 786, in _next_internal
output_shapes=self._flat_output_shapes)
File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/gen_dataset_ops.py", line 2845, in iterator_get_next
_ops.raise_from_not_ok_status(e, name)
File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py", line 7107, in raise_from_not_ok_status
raise core._status_to_exception(e) from None # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[189] = [6, 30, 38, 0] does not index into shape [8,38,38,3,6]
[[{{function_node __inference_transform_targets_for_output_1051}}{{node TensorScatterUpdate}}]] [Op:IteratorGetNext]
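For what it's worth, in the target layout used below ([batch, grid_y, grid_x, anchor, 6]), a shape of [8, 38, 38, 3, 6] means valid grid indices run from 0 to 37, and the out-of-range value in [6, 30, 38, 0] is exactly the grid size. That is what the grid-cell arithmetic below yields when a normalized coordinate lands exactly on 1.0 (my reading of the message, not a confirmed diagnosis):
import tensorflow as tf

grid_size = 38
box_xy = tf.constant([1.0, 0.8])  # hypothetical box center sitting on the image edge
print(tf.cast(box_xy // (1 / grid_size), tf.int32).numpy())  # [38 30]: 38 is out of range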
It is near impossible to tell which exact inputs are causing the issue, thanks to tensorflow's brilliant graph execution. I tried pdb, tf.print statements, and many other desperate measures to identify which examples in labels.csv cause the problem and need to be excluded, but nothing looks particularly suspicious.
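One way to at least localize a failing batch is to take graph execution out of the picture. A sketch, assuming raw_dataset stands for the same pipeline built without the second .map (so it yields untransformed (image, label) batches) and anchors, masks, and input_shape are defined as in the listing below:
# raw_dataset is hypothetical: the pipeline below minus the final .map call.
# Applying transform_targets eagerly lets a plain Python try/except catch
# the error and report which batch raised it.
for i, (x, y) in enumerate(raw_dataset):
    try:
        transform_targets(y, anchors, masks, input_shape[0])
    except tf.errors.InvalidArgumentError as e:
        print(f'batch {i} raises: {e}')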
Here's what the notebook runs, which eventually results in the error above.
import numpy as np
import pandas as pd
import tensorflow as tf
def transform_images(x, image_shape):
    x = tf.image.resize(x, image_shape)
    return x / 255
@tf.function
def transform_targets_for_output(y_true, grid_size, anchor_indices):
    n = tf.shape(y_true)[0]
    y_true_out = tf.zeros((n, grid_size, grid_size, tf.shape(anchor_indices)[0], 6))
    anchor_indices = tf.cast(anchor_indices, tf.int32)
    indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
    updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
    idx = 0
    for i in tf.range(n):
        for j in tf.range(tf.shape(y_true)[1]):
            if tf.equal(y_true[i][j][2], 0):  # skip zero-padded rows
                continue
            anchor_eq = tf.equal(anchor_indices, tf.cast(y_true[i][j][5], tf.int32))
            if tf.reduce_any(anchor_eq):
                box = y_true[i][j][0:4]
                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2  # box center
                anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
                grid_xy = tf.cast(box_xy // (1 / grid_size), tf.int32)  # grid cell
                indexes = indexes.write(
                    idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]]
                )
                updates = updates.write(
                    idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]]
                )
                idx += 1
    return tf.tensor_scatter_nd_update(y_true_out, indexes.stack(), updates.stack())
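# Aside (toy values, not part of the pipeline): the scatter step above takes
# each collected [batch, grid_y, grid_x, anchor] index and writes the matching
# 6-vector [x0, y0, x1, y1, obj, class] into that cell of the zero tensor.
_toy_target = tf.zeros((1, 2, 2, 1, 6))
_toy_indexes = tf.constant([[0, 1, 0, 0]])
_toy_updates = tf.constant([[0.1, 0.2, 0.3, 0.4, 1.0, 5.0]])
print(tf.tensor_scatter_nd_update(_toy_target, _toy_indexes, _toy_updates)[0, 1, 0, 0])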
def transform_targets(y, anchors, anchor_masks, size):
    y_outs = []
    grid_size = size // 32
    # Match each box to its best anchor by IoU of (width, height) pairs.
    anchors = tf.cast(anchors, tf.float32)
    anchor_area = anchors[..., 0] * anchors[..., 1]
    box_wh = y[..., 2:4] - y[..., 0:2]
    box_wh = tf.tile(tf.expand_dims(box_wh, -2), (1, 1, tf.shape(anchors)[0], 1))
    box_area = box_wh[..., 0] * box_wh[..., 1]
    intersection = tf.minimum(box_wh[..., 0], anchors[..., 0]) * tf.minimum(
        box_wh[..., 1], anchors[..., 1]
    )
    iou = intersection / (box_area + anchor_area - intersection)
    anchor_idx = tf.cast(tf.argmax(iou, axis=-1), tf.float32)
    anchor_idx = tf.expand_dims(anchor_idx, axis=-1)
    y = tf.concat([y, anchor_idx], axis=-1)
    for anchor_indices in anchor_masks:
        y_outs.append(transform_targets_for_output(y, grid_size, anchor_indices))
        grid_size *= 2
    return tuple(y_outs)
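# Aside (made-up numbers, not part of the pipeline): a NumPy mirror of the
# anchor matching above, showing the IoU of one box's (w, h) against three
# anchor priors and the index of the best match.
_toy_wh = np.array([0.10, 0.15])  # hypothetical normalized box width/height
_toy_anchors = np.array([(10, 13), (16, 30), (33, 23)]) / 608.0
_toy_inter = np.minimum(_toy_wh[0], _toy_anchors[:, 0]) * np.minimum(_toy_wh[1], _toy_anchors[:, 1])
_toy_union = _toy_wh[0] * _toy_wh[1] + _toy_anchors[:, 0] * _toy_anchors[:, 1] - _toy_inter
print(np.argmax(_toy_inter / _toy_union))  # index of the best-matching prior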
def read_example(
    example,
    feature_map,
    class_table,
    max_boxes,
    image_shape,
):
    features = tf.io.parse_single_example(example, feature_map)
    image = tf.image.decode_png(features['image'], channels=3)
    image = tf.image.resize(image, image_shape)
    object_name = tf.sparse.to_dense(features['object_name'])
    label = tf.cast(class_table.lookup(object_name), tf.float32)
    label = tf.stack(
        [tf.sparse.to_dense(features[feature]) for feature in ['x0', 'y0', 'x1', 'y1']]
        + [label],
        1,
    )
    padding = [[0, max_boxes - tf.shape(label)[0]], [0, 0]]
    label = tf.pad(label, padding)
    return image, label
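# Aside (toy values, max_boxes = 4): the padding above makes per-image labels
# stackable into a batch. The zero rows it appends are exactly what the
# y_true[i][j][2] == 0 check in transform_targets_for_output skips.
_toy_label = tf.constant([[0.1, 0.2, 0.3, 0.4, 1.0],
                          [0.5, 0.5, 0.9, 0.9, 0.0]])
_toy_padded = tf.pad(_toy_label, [[0, 4 - _toy_label.shape[0]], [0, 0]])
print(_toy_padded.shape)  # (4, 5)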
def read_tfrecord(
    fp,
    classes_file,
    image_shape,
    max_boxes,
    shuffle_buffer_size,
    batch_size,
    anchors,
    masks,
    classes_delimiter='\n',
):
    text_initializer = tf.lookup.TextFileInitializer(
        classes_file, tf.string, 0, tf.int64, -1, delimiter=classes_delimiter
    )
    class_table = tf.lookup.StaticHashTable(text_initializer, -1)
    files = tf.data.Dataset.list_files(fp)
    dataset = files.flat_map(tf.data.TFRecordDataset)
    feature_map = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'x0': tf.io.VarLenFeature(tf.float32),
        'y0': tf.io.VarLenFeature(tf.float32),
        'x1': tf.io.VarLenFeature(tf.float32),
        'y1': tf.io.VarLenFeature(tf.float32),
        'object_name': tf.io.VarLenFeature(tf.string),
        'object_index': tf.io.VarLenFeature(tf.int64),
    }
    return (
        dataset.map(
            lambda x: read_example(x, feature_map, class_table, max_boxes, image_shape),
            tf.data.experimental.AUTOTUNE,
        )
        .batch(batch_size)
        .shuffle(shuffle_buffer_size)
        .map(
            lambda x, y: (
                transform_images(x, image_shape),
                transform_targets(y, anchors, masks, image_shape[0]),
            )
        )
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
if __name__ == '__main__':
    input_shape = (608, 608, 3)
    labels = pd.read_csv('labels.csv')
    classes_file = 'classes.txt'
    max_boxes = max([g[1].shape[0] for g in labels.groupby('image')])
    shuffle_buffer_size = 256
    batch_size = 8
    anchors = np.array(
        [
            (10, 13),
            (16, 30),
            (33, 23),
            (30, 61),
            (62, 45),
            (59, 119),
            (116, 90),
            (156, 198),
            (373, 326),
        ]
    ) / np.array(input_shape[:-1])
    masks = np.array([[6, 7, 8], [3, 4, 5], [0, 1, 2]])
    train_dataset = read_tfrecord(
        '/content/yolo-train-0.tfrecord',
        classes_file,
        input_shape[:-1],
        max_boxes,
        shuffle_buffer_size,
        batch_size,
        anchors,
        masks,
    )
    for i, _ in enumerate(train_dataset, 1):  # There should be around 11000 iterations
        print(f'\riteration: {i}', end='')
Is there a way to filter out the problematic examples?
I tried filtering with try/except blocks, but it doesn't work: the same exception is still raised despite adding the following to create_tfrecord:
dataset = iter(dataset)
while True:
    try:
        yield next(dataset)
    except InvalidArgumentError:
        pass
CodePudding user response:
Wrapping the transform_targets_for_output method with a try-except-raise clause and applying tf.data.experimental.ignore_errors to the dataset seems to actually work; ignore_errors silently drops any dataset element whose computation raises an error:
def transform_targets_for_output(y_true, grid_size, anchor_indices):
    try:
        n = tf.shape(y_true)[0]
        y_true_out = tf.zeros((n, grid_size, grid_size, tf.shape(anchor_indices)[0], 6))
        anchor_indices = tf.cast(anchor_indices, tf.int32)
        indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
        updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
        idx = 0
        for i in tf.range(n):
            for j in tf.range(tf.shape(y_true)[1]):
                if tf.equal(y_true[i][j][2], 0):
                    continue
                anchor_eq = tf.equal(anchor_indices, tf.cast(y_true[i][j][5], tf.int32))
                if tf.reduce_any(anchor_eq):
                    box = y_true[i][j][0:4]
                    box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2
                    anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
                    grid_xy = tf.cast(box_xy // (1 / grid_size), tf.int32)
                    indexes = indexes.write(
                        idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]]
                    )
                    updates = updates.write(
                        idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]]
                    )
                    idx += 1
        return tf.tensor_scatter_nd_update(y_true_out, indexes.stack(), updates.stack())
    except tf.errors.InvalidArgumentError:
        raise
Using a batch size of 8, I was able to iterate through the dataset successfully:
train_dataset = train_dataset.apply(tf.data.experimental.ignore_errors())
for i, _ in enumerate(train_dataset, 1):  # There should be around 11000 iterations
    print(f'\riteration: {i}', end='')
iteration: 11244
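As a side note, depending on your TensorFlow version the same transformation may also be available directly as a method on the dataset (tf.data.experimental.ignore_errors was later promoted into the core API); if yours has it, the apply line can be written as:
# Equivalent on newer TensorFlow versions that expose the method directly:
train_dataset = train_dataset.ignore_errors()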