TensorFlow: apply convolutions along one specific axis


I have two tensors A and B of shape (batch_size, height, width, 1) that I want to convolve along the width axis, i.e. convolve A[0, 0] with B[0, 0], A[0, 1] with B[0, 1], A[3, 6] with B[3, 6], and so on. I've tried to do this by combining tf.nn.conv1d with tf.map_fn, but I keep getting errors about input shapes, AutoGraph, etc.

How do I efficiently convolve these two tensors along a specific axis?
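For reference, here is a slow but explicit sketch of the intended operation. It assumes the semantics of tf.nn.conv1d with padding="SAME", which, like most deep-learning "convolutions", computes a cross-correlation (the kernel is not flipped). The helper name rowwise_conv_reference is hypothetical.

```python
import tensorflow as tf

def rowwise_conv_reference(a, b):
    """Hypothetical reference helper: for every (batch, height) index (i, j),
    run conv1d on the width-row a[i, j] using the width-row b[i, j] as kernel.

    a, b: float tensors of shape (batch, height, width, 1).
    Returns a tensor of shape (batch, height, width, 1).
    """
    batch, height = a.shape[0], a.shape[1]
    rows = []
    for i in range(batch):
        row = []
        for j in range(height):
            x = a[i, j][tf.newaxis]          # (1, width, 1): one sample, one channel
            k = b[i, j][:, :, tf.newaxis]    # (width, 1, 1): filter_width, in, out
            y = tf.nn.conv1d(x, k, stride=1, padding="SAME")  # (1, width, 1)
            row.append(y[0])
        rows.append(tf.stack(row))
    return tf.stack(rows)

A = tf.random.normal((2, 3, 16, 1))
B = tf.random.normal((2, 3, 16, 1))
print(rowwise_conv_reference(A, B).shape)  # (2, 3, 16, 1)
```

A Python double loop like this is only useful for checking faster versions against it.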

EDIT: added non-functional code to illustrate the idea

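import tensorflow as tf
from functools import partial

# Create two random tensors A and B, each of shape (1, 8, 512, 1)
A, B = tf.random.normal(shape=(2, 1, 8, 512, 1))

# Rearrange axes for `conv1d`: A -> (8, 1, 512, 1), B -> (8, 512, 1, 1)
A = tf.transpose(A, [1, 0, 2, 3])
B = tf.transpose(B, [1, 2, 0, 3])

# Define convolution function
conv_fn = partial(tf.nn.conv1d, padding="SAME", stride=1)

# Apply map (this is the call that fails)
AB = tf.map_fn(lambda x: conv_fn(x[0], x[1]), (A, B))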

Result (vector values omitted for brevity):

ValueError: The two structures don't have the same nested structure.

First structure: type=tuple str=(<tf.Tensor: shape=(8, 1, 512, 1), dtype=float32, numpy=
array([[omitted]], dtype=float32)>, <tf.Tensor: shape=(8, 512, 1, 1), dtype=float32, numpy=
array([[omitted]], dtype=float32)>)

Second structure: type=EagerTensor str=tf.Tensor(
[[omitted]], shape=(1, 512, 1), dtype=float32)

More specifically: Substructure "type=tuple str=(<tf.Tensor: shape=(8, 1, 512, 1), dtype=float32, numpy=
array([[omitted]], dtype=float32)>, <tf.Tensor: shape=(8, 512, 1, 1), dtype=float32, numpy=
array([[omitted]], dtype=float32)>)" is a sequence, while substructure "type=EagerTensor str=tf.Tensor(
[[omitted]], shape=(1, 512, 1), dtype=float32)" is not

Answer:

The error comes from the structure check in tf.map_fn. The documentation for its fn argument says:

The callable to be performed. It accepts one argument, which will have the same (possibly nested) structure as elems. Its output must have the same structure as fn_output_signature if one is provided; otherwise it must have the same structure as elems.
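A minimal sketch of that rule on toy data (the tensors xs and ys are made up for illustration): mapping over a tuple while returning a single tensor per step fails unless fn_output_signature is given, and a plain tf.DType is enough to declare the output.

```python
import tensorflow as tf

xs = tf.constant([1.0, 2.0, 3.0])
ys = tf.constant([10.0, 20.0, 30.0])

# elems is a tuple (xs, ys), but fn returns ONE tensor per step.
# Without fn_output_signature, map_fn expects the output to mirror elems
# (a tuple) and raises a structure-mismatch error.
try:
    tf.map_fn(lambda t: t[0] + t[1], (xs, ys))
    mismatch_raised = False
except (ValueError, TypeError):
    mismatch_raised = True

# Declaring the per-step output type resolves the mismatch.
sums = tf.map_fn(lambda t: t[0] + t[1], (xs, ys), fn_output_signature=tf.float32)
print(mismatch_raised)  # True
print(sums.numpy())     # [11. 22. 33.]
```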

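Here elems is the tuple (A, B), but fn returns a single tensor, so the two structures differ. Passing an explicit, more flexible fn_output_signature resolves this: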

import tensorflow as tf
from functools import partial

A, B = tf.random.normal(shape=(2, 1, 8, 512, 1))

# Reshape tensors suitable for `conv1d`
A = tf.transpose(A, [1, 0, 2, 3])
B = tf.transpose(B, [1, 2, 0, 3])

# Define convolution function
conv_fn = partial(tf.nn.conv1d, padding="SAME", stride=1)

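# Apply map; a TensorSpec with unknown shape accepts fn's single-tensor output
AB = tf.map_fn(lambda x: conv_fn(x[0], x[1]), (A, B), fn_output_signature=tf.TensorSpec(shape=None, dtype=tf.float32))
print(AB.shape)  # (8, 1, 512, 1)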
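If the per-row loop becomes a bottleneck, one possible vectorized alternative (not part of the original answer; the helper name rowwise_conv_depthwise is hypothetical) folds every (batch, height) row into its own channel and runs a single tf.nn.depthwise_conv2d, which correlates each channel only with its own filter slice. This sketch assumes static shapes and the same SAME-padding cross-correlation that tf.nn.conv1d computes.

```python
import tensorflow as tf

def rowwise_conv_depthwise(a, b):
    """Hypothetical helper: per-row conv1d of a[i, j] with kernel b[i, j],
    vectorized by treating each (batch, height) row as a separate channel.

    a, b: float tensors of shape (batch, height, width, 1), static shapes assumed.
    Returns a tensor of shape (batch, height, width, 1).
    """
    batch, height, width, _ = a.shape
    n = batch * height
    # Input as NHWC with N=1, H=1, W=width and one channel per row: (1, 1, width, n)
    x = tf.reshape(tf.transpose(tf.reshape(a, [n, width])), [1, 1, width, n])
    # Depthwise filter [filter_h, filter_w, in_channels, multiplier] = (1, width, n, 1)
    f = tf.reshape(tf.transpose(tf.reshape(b, [n, width])), [1, width, n, 1])
    # Each channel is correlated only with its own filter slice.
    y = tf.nn.depthwise_conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")
    # Undo the layout change: (1, 1, width, n) -> (n, width) -> (batch, height, width, 1)
    return tf.reshape(tf.transpose(tf.reshape(y, [width, n])), [batch, height, width, 1])

A, B = tf.random.normal(shape=(2, 1, 8, 512, 1))
AB = rowwise_conv_depthwise(A, B)
print(AB.shape)  # (1, 8, 512, 1)
```

This keeps the original (batch, height, width, 1) layout, so the result is (1, 8, 512, 1) rather than the transposed (8, 1, 512, 1) returned by the map_fn version.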

Alternatively, you could write the loop explicitly with tf.while_loop.
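A rough sketch of the tf.while_loop route, collecting results in a tf.TensorArray. The helper name rowwise_conv_while is hypothetical. It works directly on (batch, height, width, 1) inputs by flattening all rows first, so it does not rely on the batch dimension being 1 the way the transposed layout above does.

```python
import tensorflow as tf

def rowwise_conv_while(a, b):
    """Hypothetical helper: per-row conv1d of a[i, j] with kernel b[i, j],
    looping over all batch*height rows with tf.while_loop.

    a, b: float tensors of shape (batch, height, width, 1).
    Returns a tensor of shape (batch, height, width, 1).
    """
    shape = tf.shape(a)
    rows_a = tf.reshape(a, [-1, shape[2], 1])  # (batch*height, width, 1)
    rows_b = tf.reshape(b, [-1, shape[2], 1])
    n = tf.shape(rows_a)[0]

    def body(i, acc):
        x = rows_a[i][tf.newaxis]          # (1, width, 1): conv1d input
        k = rows_b[i][:, :, tf.newaxis]    # (width, 1, 1): conv1d filter
        y = tf.nn.conv1d(x, k, stride=1, padding="SAME")  # (1, width, 1)
        return i + 1, acc.write(i, y[0])

    acc = tf.TensorArray(a.dtype, size=n)
    _, acc = tf.while_loop(lambda i, acc: i < n, body, (tf.constant(0), acc))
    # acc.stack() is (batch*height, width, 1); restore the original layout.
    return tf.reshape(acc.stack(), shape)

A, B = tf.random.normal(shape=(2, 1, 8, 512, 1))
print(rowwise_conv_while(A, B).shape)  # (1, 8, 512, 1)
```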
