Python: How to type hint tf.keras object in functions?

Time:08-12

This example function returns a dictionary of keras tensors:

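from typing import Dict

import pandas as pd
import tensorflow as tf

# The return annotation below is what raises the AttributeError quoted further down.
def create_input_tensors(data: pd.DataFrame) -> Dict[str, tf.keras.engine.keras_tensor.KerasTensor]:
    """Turns each dataframe column into a keras tensor and returns them as a dict"""
    tensors = {}
    for name, column in data.items():
        # One scalar float input per dataframe column, named after the column
        tensors[name] = tf.keras.Input(shape=(1,), name=name, dtype=tf.float32)
    return tensors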

I do not know how to correctly type hint the return value. Running the code snippet yields the following exception:

Exception has occurred: AttributeError
module 'keras.api._v2.keras' has no attribute 'engine'
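A likely reason the snippet fails before the function is even called (my reading, not something stated above): on Python versions before 3.14, annotations are evaluated when the def statement runs, so the attribute lookup tf.keras.engine happens at definition time. Python 3.14 defers this until the annotations are first read. The minimal stand-alone sketch below reproduces the behaviour with a hypothetical fake_keras namespace standing in for tf.keras:

```python
import types

# Hypothetical stand-in for the tf.keras namespace: like keras.api._v2.keras,
# it has no `engine` attribute.
fake_keras = types.SimpleNamespace(Input=None)

try:
    def f() -> fake_keras.engine.KerasTensor:
        return None

    # Before Python 3.14 the `def` above already raised. On 3.14+ annotations
    # are evaluated lazily, so reading them is what triggers the failing lookup.
    f.__annotations__
    error_message = None
except AttributeError as err:
    error_message = str(err)

print(error_message)
```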

Searching for this exception did not help. To see the type of the values in the returned dictionary, I ran type(tensors['year']) in the debugger at the end of the function (year is one of the columns in data). It prints <class 'keras.engine.keras_tensor.KerasTensor'>.
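The debugger check above can also be done in code. In this sketch the helper names qualified_name and resolve are hypothetical, not part of Keras: a class's __module__ and __qualname__ give the dotted import path to use in a type hint. For the tensor in the question this should be the same keras.engine.keras_tensor.KerasTensor path the debugger printed.

```python
import importlib


def qualified_name(obj: object) -> str:
    """Return the dotted import path of obj's class, e.g. 'builtins.int'."""
    cls = type(obj)
    return f"{cls.__module__}.{cls.__qualname__}"


def resolve(path: str) -> type:
    """Import a top-level class back from a path produced by qualified_name.

    Nested classes (qualnames containing dots) are not handled by this sketch.
    """
    module_name, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


print(qualified_name(1.5))                    # builtins.float
print(resolve(qualified_name(1.5)) is float)  # True
```

If resolve succeeds for the printed path, "from <module> import <class>" is an import that can be used in the annotation. Keep in mind that a module like keras.engine is internal, so the path may differ between Keras versions.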

I have this problem with this specific function, but also more generally whenever I try to type hint functions that handle any kind of Keras object. An answer that applies to these similar problems would be much appreciated.

Answer:

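This works like a charm for me. The tf.keras namespace (keras.api._v2.keras in the error message) does not expose Keras' internal engine package, so import KerasTensor directly from the module that defines it, i.e. the one reported by type() in the question: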

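import typing

import tensorflow as tf
# Import the class from the module that defines it, not through the tf.keras namespace
from keras.engine.keras_tensor import KerasTensor

def f() -> typing.Dict[str, KerasTensor]:
    return {"a": tf.keras.Input(shape=(1,))}

f()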
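You may not want to import Keras' internal keras.engine module at runtime, since it is not a public API and its location may change between versions. A common alternative is to import the class only for static type checkers via typing.TYPE_CHECKING and to quote the annotation so Python never evaluates it at runtime. The following is a sketch under that assumption. The helper input_names is hypothetical and stands in for a function like create_input_tensors, and the import path is the one shown in the question.

```python
from typing import TYPE_CHECKING, Dict, List

if TYPE_CHECKING:
    # Seen only by static type checkers such as mypy or pyright; never executed,
    # so a missing or moved internal Keras module cannot raise at runtime.
    from keras.engine.keras_tensor import KerasTensor


def input_names(tensors: "Dict[str, KerasTensor]") -> List[str]:
    """Hypothetical helper: return the sorted keys of a dict of Keras inputs.

    The quoted annotation is stored as a plain string and is not evaluated
    when the function is defined.
    """
    return sorted(tensors)


print(input_names({"year": None, "month": None}))  # ['month', 'year']
```

Putting "from __future__ import annotations" at the top of a module postpones every annotation in the same way without quoting (PEP 563). With this pattern, runtime tools such as typing.get_type_hints cannot resolve KerasTensor, because the name only exists for the type checker.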