Home > database >  How can I compute gradient with tensorflow/keras?
How can I compute gradient with tensorflow/keras?

Time:11-01

I am trying to draw saliency maps with TensorFlow 2.6.0. I defined the model with tensorflow.keras.models.Sequential and finished training. However, when I tried to compute the gradient like this:

# Attempt to get d(score of label 4)/d(image) for a saliency map.
with GradientTape(persistent=True) as tape:
    tape.watch(image)
    # NOTE: this is the line that raises the LookupError in the traceback
    # below. `model.predict` runs Keras' own predict loop (an iterator fed
    # into `predict_function`) and returns NumPy output, so the tape cannot
    # differentiate through it; calling `model(image)` keeps it differentiable.
    result = model.predict(image)[:, 4]
    # Calling tape.gradient inside a persistent tape's context also records
    # the gradient computation itself.
    gradient = tape.gradient(result, image)

where `image` is a tensor and `[:, 4]` selects the model output for this image's true label (class index 4).

image.shape = (1, 48, 48, 1)
result.shape = (1, 10)

Then I got error message:

---------------------------------------------------------------------------
LookupError                               Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/gradients_util.py in _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients, src_graph)
    605           try:
--> 606             grad_fn = ops.get_gradient_function(op)
    607           except LookupError:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in get_gradient_function(op)
   2731     op_type = op.type
-> 2732   return gradient_registry.lookup(op_type)
   2733 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/registry.py in lookup(self, name)
     99       raise LookupError(
--> 100           "%s registry has no entry for: %s" % (self._name, name))

LookupError: gradient registry has no entry for: IteratorGetNext
During handling of the above exception, another exception occurred:

LookupError                               Traceback (most recent call last)
/tmp/ipykernel_36/2425374110.py in <module>
      1 with GradientTape(persistent=True) as tape:
      2     tape.watch(image)
----> 3     result = model.predict(image)[:, 4]
      4 #     result = tf.convert_to_tensor(result)
      5 #     probs = tf.nn.softmax(result, axis=-1)[:, 4]

/opt/conda/lib/python3.7/site-packages/keras/engine/training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
   1749           for step in data_handler.steps():
   1750             callbacks.on_predict_batch_begin(step)
-> 1751             tmp_batch_outputs = self.predict_function(iterator)
   1752             if data_handler.should_sync:
   1753               context.async_wait()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    922       # In this case we have not created variables on the first call. So we can
    923       # run the first trace but we should fail if variables are created.
--> 924       results = self._stateful_fn(*args, **kwds)
    925       if self._created_variables and not ALLOW_DYNAMIC_VARIABLE_CREATION:
    926         raise ValueError("Creating variables on a non-first call to a function"

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   3038        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   3039     return graph_function._call_flat(
-> 3040         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   3041 
   3042   @property

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1967         possible_gradient_type,
   1968         executing_eagerly)
-> 1969     forward_function, args_with_tangents = forward_backward.forward()
   1970     if executing_eagerly:
   1971       flat_outputs = forward_function.call(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in forward(self)
   1493     """Builds or retrieves a forward function for this call."""
   1494     forward_function = self._functions.forward(
-> 1495         self._inference_args, self._input_tangents)
   1496     return forward_function, self._inference_args + self._input_tangents
   1497 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in forward(self, inference_args, input_tangents)
   1224       (self._forward, self._forward_graph, self._backward,
   1225        self._forwardprop_output_indices, self._num_forwardprop_outputs) = (
-> 1226            self._forward_and_backward_functions(inference_args, input_tangents))
   1227     return self._forward
   1228 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _forward_and_backward_functions(self, inference_args, input_tangents)
   1449       outputs = list(self._func_graph.outputs)
   1450       self._build_functions_for_outputs(
-> 1451           outputs, inference_args, input_tangents)
   1452 
   1453     (forward_function, forward_graph,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _build_functions_for_outputs(self, outputs, inference_args, input_tangents)
    946             self._func_graph.inputs,
    947             grad_ys=gradients_wrt_outputs,
--> 948             src_graph=self._func_graph)
    949 
    950       if input_tangents:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/gradients_util.py in _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients, src_graph)
    634               raise LookupError(
    635                   "No gradient defined for operation '%s' (op type: %s)" %
--> 636                   (op.name, op.type))
    637         if loop_state:
    638           loop_state.EnterGradWhileContext(op, before=False)

LookupError: No gradient defined for operation 'IteratorGetNext' (op type: IteratorGetNext)

What should I do to solve this problem? Thanks for answering.

CodePudding user response:

Use model(image)[:, 4] instead of model.predict(image)[:, 4]. The latter converts the output to NumPy arrays, which have no autograd functionality, so TensorFlow cannot compute the gradient. Also note that with a persistent tape it is noticeably less efficient to call tape.gradient inside the context, because the gradient operations themselves are recorded as well. The only situation where that is useful is when you want higher-order derivatives.

# Saliency gradient: d(score of class 4) / d(image).
# `image` is a tf.Tensor here; tape.watch expects a Tensor, not a NumPy array.
# persistent=True is only needed if tape.gradient is called more than once.
with tf.GradientTape(persistent=True) as tape:
    tape.watch(image)
    # Call the model directly (not model.predict) so the forward pass is
    # recorded on the tape and remains differentiable.
    result = model(image)[:, 4]
# Compute the gradient outside the `with` block so the gradient ops
# themselves are not recorded on the tape.
gradient = tape.gradient(result, image)
  • Related