Tensorflow gives "<tensor> is not an element of this graph" error when calling from


First, I've already gone through all the similar topics here and on other sites, but none of the suggestions worked in my case.

Let's say I have a class that handles loading a model and running prediction:

import tensorflow as tf

class MyModel():
  def __init__(self):
    pass

  def load_model(self, model_path):
    self.model = tf.keras.models.load_model(model_path)

  def predict(self, img):
    return self.model.predict(img)

Now, I have another class in another file which calls MyModel:

import numpy

from mymodel import MyModel

class MyDetector():
    def __init__(self):
        self.detector = MyModel()
        self.detector.load_model('mymodel.h5')

    def detect(self, img: numpy.ndarray):        
        return self.detector.predict(img)

This, however, throws an error saying <tensor> is not an element of this graph. I've tried all the tf.Graph.as_default()-related answers available, but nothing changed. The most common suggestion is to revise the model loading and prediction parts as follows:

def load_model(self, model_path):
    global model
    model = tf.keras.models.load_model(model_path)
    global graph
    graph = tf.get_default_graph() 

def predict(self, img):
    with graph.as_default():
      preds = model.predict(img)
    return preds

Still, this didn't help, and neither did any of the other suggestions collected at https://github.com/keras-team/keras/issues/6462

I think my case differs from the ones already solved because I'm calling the model class from a completely different class file. My TensorFlow version is 2.6.0. Does anyone have a better idea of how to solve this?

Update

The actual situation is that I'm using gRPC to communicate with a remote server for model inference, via a fairly simple client-server setup. My client code is defined as follows (client.py):

import cv2
import grpc 
import pybase64

import protos.mydetector_pb2 as mydetector_pb2
import protos.mydetector_pb2_grpc as mydetector_pb2_grpc 

# open a gRPC channel
channel = grpc.insecure_channel('[::]:50051')
stub = mydetector_pb2_grpc.MyDetectionServiceStub(channel)

img = cv2.imread('test.jpg')
retval, buffer = cv2.imencode('.jpg', img)
b64img = pybase64.b64encode(buffer)

print('\nSending single request to port 50051...')
request = mydetector_pb2.MyDetectionRequest(image=b64img)

response = stub.detect(request)

Then, on the receiving server side, the main server is implemented as follows (server.py):

import grpc
from concurrent import futures
import protos.mydetector_pb2_grpc as reid_grpc
from MyDetectionService import MyDetectionService

MAX_MESSAGE_IN_MB = 10

options = [
    ('grpc.max_send_message_length', MAX_MESSAGE_IN_MB * 1024 * 1024),
    ('grpc.max_receive_message_length', MAX_MESSAGE_IN_MB * 1024 * 1024)
]

server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), options=options)
reid_grpc.add_MyDetectionServiceServicer_to_server(MyDetectionService(), server)

print('Starting server. Listening on port 50051.')
server.add_insecure_port('[::]:50051')
server.start()
server.wait_for_termination()  # block so the process keeps serving

The MyDetectionService class is implemented as follows:

import protos.mydetector_pb2 as mydetector
import protos.mydetector_pb2_grpc as mydetector_grpc
from mydetector.service.detector import MyDetector
from mydetector.utils import img_converter

import cv2
import numpy as np

class MyDetectionService(mydetector_grpc.MyDetectionServiceServicer):
    def __init__(self):
        self.detector = MyDetector()

    def detect(self, request, context):
        print('detecting on received image...')
        encoded_img = request.image
        img = img_converter(encoded_img)
        img = cv2.resize(img, (240, 240))
        img2 = np.expand_dims(img, axis=0)
        result = self.detector.detect(img2)
        return mydetector.MyDetectionResponse(ans=result)

where the MyDetector class is implemented as shown at the top.

I found out that if I don't use gRPC-based server-client communication, but instead call MyDetector from any other regular outside class, everything works smoothly. However, when I send the image from the client side via gRPC, the model loads successfully in the MyDetector class (I can call model.summary() to get a full description of the model), but the detect function fails. A sketch of the kind of direct call that works is shown below.
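For reference, this is roughly what I mean by calling it from a regular outside class (a sketch, using the same import path as in MyDetectionService):

import cv2
import numpy as np
from mydetector.service.detector import MyDetector

detector = MyDetector()
img = cv2.resize(cv2.imread('test.jpg'), (240, 240))
print(detector.detect(np.expand_dims(img, axis=0)))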

Important: Based on the information available here, I believe that every incoming gRPC request is handled on a new thread with its own TensorFlow session, and that is the main problem here. However, I still couldn't make it work even after following all the instructions described on that site.
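One quick way to test that hypothesis (purely a diagnostic, not a fix) is to force the gRPC server onto a single worker thread, so every request is handled on the same thread:

# in server.py: serialize all requests onto one worker thread;
# if the error disappears, the per-thread graph/session theory is likely right
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1), options=options)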

CodePudding user response:

I have a working replication of the intent of the problem at https://colab.research.google.com/drive/1OaH7ZoAsY_V1sMUmc1NumWmJPnNr_54F?usp=sharing. The colab successfully uses MyDetector to predict MNIST images.

As part of this exercise, I saw a couple of things happening here:

  1. MyModel.model_path is never defined. Even though model_path is provided to MyModel.load_model as a parameter, it appears to go unused. Said differently, I am guessing there is just a typo in the load_model section or in the description of the problem.

In addition, here are a couple of thoughts:

  1. tf.get_default_graph() does not exist in TensorFlow 2.6.0; TF 2.6 has the comparable tf.compat.v1.get_default_graph(). I would strongly recommend printing tf.__version__ to confirm that the executing code is really using 2.6.0.

  2. If you can, move MyDetector into the MyModel file. If it works, then you know the problem has something to do with the code being split across separate files, and that may help with troubleshooting.

Based on the above, I recommend debugging the problem with eager execution enabled and seeing if you can get things working that way.
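For example, a minimal sanity check (just a sketch; it assumes nothing else in the process has disabled TF2 behavior):

import tensorflow as tf

print(tf.__version__)          # confirm the runtime really is 2.6.0
print(tf.executing_eagerly())  # True by default in TF 2.x; False means v1-style graph mode is active

# If v2 behavior was disabled somewhere (e.g. via tf.compat.v1.disable_eager_execution()),
# you are back in the graph/session world where this error typically shows up.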

CodePudding user response:

Your server loads the model in a different graph/session than the one it uses to serve client requests and responses. Revising your MyModel class as follows should work:

import tensorflow as tf

class MyModel():
  def __init__(self):
    pass

  def load_model(self, model_path):
    # Grab the graph that is current at load time...
    self.graph = tf.compat.v1.get_default_graph()

    with self.graph.as_default():
      self.model = tf.keras.models.load_model(model_path)

    # ...and the session, so other threads can be pointed back at it
    self.sess = tf.compat.v1.keras.backend.get_session()

  def predict(self, img):
    with self.graph.as_default():
      try:
        preds = self.model.predict(img)
      except tf.errors.FailedPreconditionError:
        # A fresh worker thread arrived with its own session; restore ours
        tf.compat.v1.keras.backend.set_session(self.sess)
        preds = self.model.predict(img)

    return preds
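As a sanity check outside gRPC, here is a minimal sketch that exercises predict from a worker thread, similar to what the server's ThreadPoolExecutor does (the input shape is assumed from the resize in MyDetectionService):

import threading
import numpy as np
from mymodel import MyModel

m = MyModel()
m.load_model('mymodel.h5')  # loaded on the main thread, as the server does

def worker():
    # runs on a separate thread, like a gRPC handler would
    dummy = np.zeros((1, 240, 240, 3), dtype=np.float32)  # shape assumed
    print(m.predict(dummy))

t = threading.Thread(target=worker)
t.start()
t.join()

If this runs cleanly but the gRPC path still fails, the remaining difference is in how the server creates its threads rather than in MyModel itself.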