Custom gradient with complex exponential in tensorflow


As an exercise I am building a custom operator in TensorFlow and checking its gradient against TensorFlow's autodiff of the same forward operation composed of TensorFlow API operations. However, the gradient of my custom operator is incorrect. It seems my complex analysis is not correct and needs some brushing up.

import tensorflow as tf

shape = (1, 16)
dtype = tf.complex64

x = tf.cast(tf.complex(tf.random.normal(shape), tf.random.normal(shape)), dtype=dtype)

def fun(x):
    phi = x * tf.math.conj(x)
    e = tf.exp(1j * phi)
    return e

def d_fun(x):
    d_phi = x + tf.math.conj(x)
    phi = x * tf.math.conj(x)
    d_e = 1j * d_phi * tf.exp(1j * phi)
    return d_e

@tf.custom_gradient
def tf_custom(x):    
    e = fun(x)
    def grad(dy):
        d_e = d_fun(x)
        return dy * d_e
    return e, grad

with tf.GradientTape() as g:
    g.watch(x)
    res = fun(x)
    
dy_dx = g.gradient(res, x)

with tf.GradientTape() as g:
    g.watch(x)
    res2 = tf_custom(x)
    
dy_dx2 = g.gradient(res2, x)

print(tf.reduce_sum(tf.abs(res - res2)).numpy())
print(tf.reduce_sum(tf.abs(dy_dx - dy_dx2)).numpy())

CodePudding user response:

TensorFlow 2 does not directly compute the derivative of a function of a complex variable. It seems that it treats the function as a function of the real part and the imaginary part and uses Wirtinger calculus. You can also find an explanation here.
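
In other words, for an op y = f(x) that depends on both x and conj(x), the gradient that matches TensorFlow's convention is effectively dy * conj(dy/dx) + conj(dy) * dy/dconj(x), where dy/dx and dy/dconj(x) are the Wirtinger derivatives. For y = exp(1j * x * conj(x)) those derivatives are 1j * conj(x) * y and 1j * x * y. Below is a sketch of how the custom gradient could be written under that convention, reusing the poster's fun; treat it as an illustration of the convention rather than a verified reference.

import tensorflow as tf

shape = (1, 16)
x = tf.complex(tf.random.normal(shape), tf.random.normal(shape))  # complex64

def fun(x):
    phi = x * tf.math.conj(x)      # |x|^2, real-valued but stored as complex
    return tf.exp(1j * phi)

@tf.custom_gradient
def tf_custom(x):
    y = fun(x)
    def grad(dy):
        # Wirtinger derivatives of y = exp(1j * x * conj(x))
        d_z = 1j * tf.math.conj(x) * y      # dy/dx
        d_zbar = 1j * x * y                 # dy/dconj(x)
        # Combine them following TensorFlow's convention for complex gradients
        return dy * tf.math.conj(d_z) + tf.math.conj(dy) * d_zbar
    return y, grad

with tf.GradientTape() as g:
    g.watch(x)
    res = fun(x)
dy_dx = g.gradient(res, x)

with tf.GradientTape() as g:
    g.watch(x)
    res2 = tf_custom(x)
dy_dx2 = g.gradient(res2, x)

print(tf.reduce_sum(tf.abs(dy_dx - dy_dx2)).numpy())  # should be ~0

With upstream gradient dy = 1 this expression reduces to -2 * x * sin(|x|^2), which is what the autodiff gradient of fun evaluates to.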
