I am using BruteForce from TensorFlow Recommenders:
index = tfrs.layers.factorized_top_k.BruteForce(model.customer_model, k=400)
The candidates dataset looks like this:
<ZipDataset element_spec=({'article_id': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'prod_name': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'product_type_name': TensorSpec(shape=(None,), dtype=tf.string, name=None)}, TensorSpec(shape=(None, 64), dtype=tf.float32, name=None))>
But when I try to build the retrieval index:
index.index_from_dataset(candidates)
I get the following error:
AttributeError Traceback (most recent call last)
Input In [28], in <cell line: 6>()
4 candidates = tf.data.Dataset.zip((articles.batch(128), articles.batch(128).map(model.article_model)))
5 print(candidates)
----> 6 index.index_from_dataset(candidates)
File ~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_recommenders/layers/factorized_top_k.py:197, in TopK.index_from_dataset(self, candidates)
174 def index_from_dataset(
175 self,
176 candidates: tf.data.Dataset
177 ) -> "TopK":
178 """Builds the retrieval index.
179
180 When called multiple times the existing index will be dropped and a new one
(...)
194 ValueError if the dataset does not have the correct structure.
195 """
--> 197 _check_candidates_with_identifiers(candidates)
199 spec = candidates.element_spec
201 if isinstance(spec, tuple):
File ~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_recommenders/layers/factorized_top_k.py:127, in _check_candidates_with_identifiers(candidates)
119 raise ValueError(
120 "The dataset must yield candidate embeddings or "
121 "tuples of (candidate identifiers, candidate embeddings). "
122 f"Got {spec} instead."
123 )
125 identifiers_spec, candidates_spec = spec
--> 127 if candidates_spec.shape[0] != identifiers_spec.shape[0]:
128 raise ValueError(
129 "Candidates and identifiers have to have the same batch dimension. "
130 f"Got {candidates_spec.shape[0]} and {identifiers_spec.shape[0]}."
131 )
AttributeError: 'dict' object has no attribute 'shape'
I assume the problem is with my dataset, which is created from a dictionary.
How should I pass the candidates dataset so that I don't get this error?
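The assumption about the dictionary can be checked directly on the dataset's element_spec, which is what the failing check in the traceback reads. Below is a small reproduction in plain TensorFlow (no TFRS needed). The toy articles data and the article_model function are hypothetical stand-ins for the real articles dataset and model.article_model; the "embedding" is faked as a one-hot of a hashed ID only so that it has the same (None, 64) float32 spec as in the question.

```python
import tensorflow as tf

# Hypothetical toy data standing in for the real `articles` dataset:
# one dict of string features per article.
articles = tf.data.Dataset.from_tensor_slices({
    "article_id": ["a1", "a2", "a3"],
    "prod_name": ["Shirt", "Dress", "Sock"],
    "product_type_name": ["Top", "Dress", "Socks"],
})


def article_model(features):
    """Hypothetical stand-in for model.article_model.

    Maps a batch of feature dicts to a (batch, 64) float32 matrix by
    one-hot encoding a hash of the article ID; only the shape matters here.
    """
    buckets = tf.strings.to_hash_bucket_fast(features["article_id"], 64)
    return tf.one_hot(buckets, depth=64)


# Same structure as the failing code: the whole feature dict ends up
# as the first ("identifiers") element of each pair.
candidates = tf.data.Dataset.zip(
    (articles.batch(2), articles.batch(2).map(article_model))
)

identifiers_spec, embeddings_spec = candidates.element_spec
print(type(identifiers_spec).__name__)     # dict
print(hasattr(identifiers_spec, "shape"))  # False, so `.shape[0]` raises AttributeError
print(embeddings_spec.shape)               # (None, 64)
```

The embeddings side is fine; it is the identifiers side, a dict of three TensorSpecs, that has no single shape to compare against.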
Answer:
I figured it out. I was building the candidates dataset like this:
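candidates = tf.data.Dataset.zip((articles.batch(128), articles.batch(128).map(model.article_model)))
index.index_from_dataset(candidates)
That makes the whole feature dictionary the first element of each pair, and index_from_dataset treats that first element as the candidate identifiers. The identifiers need to be a single tensor, such as the article_id column, paired with the candidate embeddings: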
candidates = tf.data.Dataset.zip((articles.batch(128).map(lambda x: x["article_id"]), articles.batch(128).map(model.article_model)))
index.index_from_dataset(candidates)
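A variant of the same fix builds each (identifier, embedding) pair in a single map over the batched dataset instead of zipping two separate passes over articles. This is a sketch using the same hypothetical toy articles data and article_model stand-in as a reproduction, not the real model; the actual TFRS call is left as a comment so the snippet runs with TensorFlow alone.

```python
import tensorflow as tf

# Hypothetical toy data standing in for the real `articles` dataset.
articles = tf.data.Dataset.from_tensor_slices({
    "article_id": ["a1", "a2", "a3"],
    "prod_name": ["Shirt", "Dress", "Sock"],
    "product_type_name": ["Top", "Dress", "Socks"],
})


def article_model(features):
    """Hypothetical stand-in for model.article_model: (batch,) -> (batch, 64)."""
    buckets = tf.strings.to_hash_bucket_fast(features["article_id"], 64)
    return tf.one_hot(buckets, depth=64)


# One pass over the batches: each element is (article_id batch, embedding batch),
# so the IDs and embeddings are always taken from the same rows.
candidates = articles.batch(2).map(
    lambda x: (x["article_id"], article_model(x))
)

identifiers_spec, embeddings_spec = candidates.element_spec
print(identifiers_spec)  # TensorSpec(shape=(None,), dtype=tf.string, name=None)
print(embeddings_spec)   # TensorSpec(shape=(None, 64), dtype=tf.float32, name=None)

# With TFRS and the real models, this dataset would be passed the same way:
# index.index_from_dataset(candidates)
```

Both versions yield a (None,) string tensor of IDs next to a (None, 64) embedding matrix. The single map has the small advantage that IDs and embeddings cannot drift out of alignment if the articles pipeline is later made non-deterministic (for example, by adding a shuffle), which could happen when two independent passes are zipped.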