I would like to use a model from sentence-transformers inside a larger Keras model.
Here is the full example:
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
MODEL_PATH = 'sentence-transformers/all-MiniLM-L6-v2'
# Tokenizer and TF model come from the same checkpoint; from_pt=True tells
# transformers to load the weights from a PyTorch checkpoint.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_PATH),
    TFAutoModel.from_pretrained(MODEL_PATH, from_pt=True),
)
class SBert(tf.keras.layers.Layer):
    """Keras layer meant to turn strings into mean-pooled, L2-normalised
    sentence embeddings using a Hugging Face tokenizer and TF model.

    NOTE(review): this is the failing reproduction from the question and it
    raises the InvalidArgumentError shown below it. The known problems are
    marked inline.
    """
    def __init__(self, tokenizer, model):
        super(SBert, self).__init__()
        # Plain Python tokenizer object and the TF transformer model.
        self.tokenizer = tokenizer
        self.model = model
    def tf_encode(self, inputs):
        """Tokenize ``inputs`` eagerly through ``tf.py_function``."""
        def encode(inputs):
            # NOTE(review): inside py_function ``inputs`` is a tf.Tensor of
            # bytes, not List[str], so the tokenizer raises the ValueError in
            # the traceback; the strings must be decoded to str first.
            return self.tokenizer(
                inputs, padding=True, truncation=True, return_tensors='tf'
            )
        # NOTE(review): encode() returns a dict-like encoding holding several
        # tensors, but Tout declares one int64 output. py_function returns a
        # list, yet call() indexes the result with 'attention_mask'.
        return tf.py_function(func=encode, inp=[inputs], Tout=[tf.int64])
    def mean_pooling(model_output, attention_mask):
        """Mask-weighted mean over the token axis, then L2 normalisation."""
        # NOTE(review): ``self`` is missing, so ``self.mean_pooling(a, b)`` in
        # call() passes three positional arguments to a two-parameter function.
        token_embeddings = model_output[0]
        # NOTE(review): the static ``.shape`` may contain None when traced;
        # ``tf.shape(token_embeddings)`` is the dynamic alternative.
        input_mask_expanded = tf.cast(
            tf.broadcast_to(tf.expand_dims(attention_mask, -1), token_embeddings.shape),
            tf.float32
        )
        a = tf.math.reduce_sum(token_embeddings * input_mask_expanded, axis=1)
        # Clip the token count away from zero to avoid dividing by zero.
        b = tf.clip_by_value(tf.math.reduce_sum(input_mask_expanded, axis=1), 1e-9, tf.float32.max)
        embeddings = a / b
        embeddings, _ = tf.linalg.normalize(embeddings, 2, axis=1)
        return embeddings
    def call(self, inputs):
        """Tokenize, run the transformer, and pool to one vector per input."""
        encoded_input = self.tf_encode(inputs)
        model_output = self.model(encoded_input)
        embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
        return embeddings
# Eager call on a Python list of two strings; this raises the error below.
sbert = SBert(tokenizer=tokenizer, model=model)
sbert(['some text', 'more text'])
I am able to use the model and tokenizer outside of TF / Keras with no problems. The issue seems to happen when TF builds the graph and passes a symbolic tensor to the tokenizer, which raises an error. This is why I tried wrapping the tokenizer call in `tf.py_function`,
but with no success...
The error:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-20-a0c4a906e456> in <module>
44
45 sbert = SBert(tokenizer, model)
---> 46 sbert(['some text', 'more text'])
~/.pyenv/versions/3.7.8/lib/python3.7/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
<ipython-input-20-a0c4a906e456> in call(self, inputs)
36 def call(self, inputs):
37 tf.print(inputs, output_stream=sys.stdout)
---> 38 encoded_input = self.tf_encode(inputs)
39 tf.print(encoded_input, output_stream=sys.stdout)
40 model_output = self.model(encoded_input)
<ipython-input-20-a0c4a906e456> in tf_encode(self, inputs)
20 inputs, padding=True, truncation=True, return_tensors='tf'
21 )
---> 22 return tf.py_function(func=encode, inp=[inputs], Tout=[tf.int64])
23
24 def mean_pooling(model_output, attention_mask):
InvalidArgumentError: Exception encountered when calling layer "s_bert_6" (type SBert).
ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).
Traceback (most recent call last):
File "/Users/dennisyurkevich/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py", line 269, in __call__
return func(device, token, args)
File "/Users/dennisyurkevich/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py", line 147, in __call__
outputs = self._call(device, args)
File "/Users/dennisyurkevich/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py", line 154, in _call
ret = self._func(*args)
File "/Users/dennisyurkevich/.pyenv/versions/3.7.8/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
return func(*args, **kwargs)
File "<ipython-input-20-a0c4a906e456>", line 20, in encode
inputs, padding=True, truncation=True, return_tensors='tf'
File "/Users/dennisyurkevich/.pyenv/versions/3.7.8/lib/python3.7/site-packages/transformers/tokenization_utils_base.py", line 2378, in __call__
"text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).
[Op:EagerPyFunc]
Call arguments received:
• inputs=["'some text'", "'more text'"]
Answer:
`tf.py_function` does not seem to work with a dict output, which is why you can try returning three separate tensors instead. Also, inside `encode` the strings arrive as a tensor of `bytes` (hence the `b`
prefix in front of each string), so I decode them to Python `str` before calling the tokenizer:
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
MODEL_PATH = 'sentence-transformers/all-MiniLM-L6-v2'
# Tokenizer and TF model come from the same checkpoint; from_pt=True tells
# transformers to load the weights from a PyTorch checkpoint.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_PATH),
    TFAutoModel.from_pretrained(MODEL_PATH, from_pt=True),
)
class SBert(tf.keras.layers.Layer):
    """Keras layer producing mean-pooled, L2-normalised sentence embeddings.

    Tokenisation runs eagerly inside ``tf.py_function`` because the Hugging
    Face tokenizer only accepts Python strings. The TF transformer and the
    pooling run as ordinary TF ops.
    """

    def __init__(self, tokenizer, model, **kwargs):
        # Extra keyword arguments (name, dtype, ...) are forwarded to Layer.
        super(SBert, self).__init__(**kwargs)
        self.tokenizer = tokenizer
        self.model = model

    def tf_encode(self, inputs):
        """Tokenize a batch of strings.

        Args:
            inputs: string tensor or list of str. It may be a scalar, a 1-D
                batch or a (batch, 1) column; it is flattened to 1-D before
                tokenisation.

        Returns:
            List ``[input_ids, token_type_ids, attention_mask]`` of int32
            tensors with static rank 2.
        """
        def encode(texts):
            # Inside py_function the strings arrive as bytes, but the tokenizer
            # needs Python str. Flattening first handles a scalar string or a
            # (batch, 1) column as well as a 1-D batch.
            texts = [t.decode("utf-8") for t in tf.reshape(texts, [-1]).numpy()]
            outputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors='tf')
            # py_function cannot return a dict, so return the parts separately.
            return outputs['input_ids'], outputs['token_type_ids'], outputs['attention_mask']

        encoded = tf.py_function(func=encode, inp=[inputs], Tout=[tf.int32, tf.int32, tf.int32])
        # py_function erases static shapes; restore the known rank so the
        # transformer and any downstream layers can be traced.
        for tensor in encoded:
            tensor.set_shape([None, None])
        return encoded

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over unmasked positions, then L2-normalise.

        Args:
            model_output: sequence whose first element holds the per-token
                embeddings (assumed ``(batch, seq_len, hidden)``).
            attention_mask: ``(batch, seq_len)`` mask; positions equal to 0
                are excluded from the mean.

        Returns:
            ``(batch, hidden)`` float32 embeddings with unit L2 norm.
        """
        token_embeddings = model_output[0]
        # (batch, seq_len, 1) mask broadcasts across the hidden dimension.
        mask = tf.cast(tf.expand_dims(attention_mask, -1), tf.float32)
        summed = tf.math.reduce_sum(token_embeddings * mask, axis=1)
        # Clip the token count so an all-padding row cannot divide by zero.
        counts = tf.clip_by_value(tf.math.reduce_sum(mask, axis=1), 1e-9, tf.float32.max)
        embeddings = summed / counts
        # l2_normalize clamps the squared norm with an epsilon, so an all-zero
        # embedding stays zero instead of becoming NaN (as tf.linalg.normalize
        # would produce).
        return tf.math.l2_normalize(embeddings, axis=1)

    def call(self, inputs):
        """Tokenize, run the transformer, and pool to one vector per input."""
        input_ids, token_type_ids, attention_mask = self.tf_encode(inputs)
        model_output = self.model({
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask,
        })
        return self.mean_pooling(model_output, attention_mask)
# Call the layer eagerly on a batch of two sentences.
sbert = SBert(tokenizer=tokenizer, model=model)
sbert(['some text', 'more text'])
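If you want to use the layer inside a Keras functional model (fed by a symbolic `tf.keras.layers.Input`),
you will have to run the transformer inside `tf.py_function` as well, something like this: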
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
MODEL_PATH = 'sentence-transformers/all-MiniLM-L6-v2'
# Tokenizer and TF model come from the same checkpoint; from_pt=True tells
# transformers to load the weights from a PyTorch checkpoint.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_PATH),
    TFAutoModel.from_pretrained(MODEL_PATH, from_pt=True),
)
class SBert(tf.keras.layers.Layer):
    """Keras layer producing mean-pooled, L2-normalised sentence embeddings.

    Both the Hugging Face tokenizer and the transformer run eagerly inside
    ``tf.py_function``, so the layer can be applied to a symbolic
    ``tf.keras.layers.Input`` in the functional API.
    """

    def __init__(self, tokenizer, model, **kwargs):
        # Extra keyword arguments (name, dtype, ...) are forwarded to Layer.
        super(SBert, self).__init__(**kwargs)
        self.tokenizer = tokenizer
        self.model = model

    def tf_encode(self, inputs):
        """Tokenize a batch of strings.

        Args:
            inputs: string tensor, e.g. ``(batch, 1)`` from ``Input((1,))``.
                A 1-D batch or a scalar string is accepted too; the input is
                flattened to 1-D before tokenisation.

        Returns:
            List ``[input_ids, token_type_ids, attention_mask]`` of int32
            tensors with static rank 2.
        """
        def encode(texts):
            # Strings arrive as bytes inside py_function; the tokenizer needs
            # str. Flattening handles (batch, 1), (batch,) and scalar inputs.
            texts = [t.decode("utf-8") for t in tf.reshape(texts, [-1]).numpy()]
            outputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors='tf')
            # py_function cannot return a dict, so return the parts separately.
            return outputs['input_ids'], outputs['token_type_ids'], outputs['attention_mask']

        encoded = tf.py_function(func=encode, inp=[inputs], Tout=[tf.int32, tf.int32, tf.int32])
        # py_function erases static shapes; restore the known rank.
        for tensor in encoded:
            tensor.set_shape([None, None])
        return encoded

    def process(self, i, t, a):
        """Run the transformer eagerly on tokenised inputs.

        Args:
            i: input ids. t: token type ids. a: attention mask.

        Returns:
            One-element list holding the float32 per-token embeddings (the
            model's first output). ``mean_pooling`` unpacks it.
        """
        def run_model(input_ids, token_type_ids, attention_mask):
            model_output = self.model({
                'input_ids': input_ids,
                'token_type_ids': token_type_ids,
                'attention_mask': attention_mask,
            })
            return model_output[0]

        outputs = tf.py_function(func=run_model, inp=[i, t, a], Tout=[tf.float32])
        # Restore the rank erased by py_function (assumed (batch, seq_len, hidden)).
        outputs[0].set_shape([None, None, None])
        return outputs

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over unmasked positions, then L2-normalise.

        Args:
            model_output: sequence whose first element holds the per-token
                embeddings, as returned by ``process``.
            attention_mask: ``(batch, seq_len)`` mask; positions equal to 0
                are excluded from the mean.

        Returns:
            ``(batch, hidden)`` float32 embeddings with unit L2 norm.
        """
        token_embeddings = model_output[0]
        # (batch, seq_len, 1) mask broadcasts across the hidden dimension.
        mask = tf.cast(tf.expand_dims(attention_mask, -1), tf.float32)
        summed = tf.math.reduce_sum(token_embeddings * mask, axis=1)
        # Clip the token count so an all-padding row cannot divide by zero.
        counts = tf.clip_by_value(tf.math.reduce_sum(mask, axis=1), 1e-9, tf.float32.max)
        embeddings = summed / counts
        # l2_normalize clamps the squared norm with an epsilon, so an all-zero
        # embedding stays zero instead of becoming NaN.
        return tf.math.l2_normalize(embeddings, axis=1)

    def call(self, inputs):
        """Tokenize, run the transformer, and pool to one vector per input."""
        input_ids, token_type_ids, attention_mask = self.tf_encode(inputs)
        model_output = self.process(input_ids, token_type_ids, attention_mask)
        return self.mean_pooling(model_output, attention_mask)
sbert = SBert(tokenizer, model)
# Symbolic string input: one sentence per row, shape (batch, 1).
inputs = tf.keras.layers.Input((1,), dtype=tf.string)
outputs = sbert(inputs)
# Use a distinct name so the Hugging Face `model` is not shadowed by the
# Keras wrapper.
keras_model = tf.keras.Model(inputs, outputs)
embeddings = keras_model(tf.constant(['some text', 'more text']))
print(embeddings.shape)
Checking `.shape` of the result gives `TensorShape([2, 384])`: two sentences, each embedded as a 384-dimensional vector.