Makes RetrievalModelV2 support item tower with transforms (e.g. pre-trained embeddings) (#1198)

* Making the retrieval model's to_top_k_encoder(), candidate_embeddings(), and batch_predict() support a Loader with transforms (e.g. pre-trained embeddings) in the item tower

* Fixing a test error and ensuring that every batch_predict() in the new API supports a Loader with transforms (including pre-trained embeddings)

* Fixing the retrieval example, which was using the wrong schema to export query and item embeddings

* Adding the missing importorskip on torch and pytorch_lightning for the torch integration tests

* Skipping a test when nvtabular is not available
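
A minimal end-to-end sketch of the new capability, modeled on the unit test added in this commit. The names `train_ds` and `model` (a trained `mm.TwoTowerModelV2`) and the random embedding table are illustrative assumptions, not part of the change itself:

```python
import numpy as np
import merlin.models.tf as mm
from merlin.dataloader.ops.embeddings import EmbeddingOperator
from merlin.schema import Tags

# Illustrative pre-trained table: one 12-dim vector per item_category value
pretrained = np.random.rand(101, 12)

# The Loader applies the transform per batch, adding a
# "pretrained_category_embeddings" feature to the item tower's inputs
loader = mm.Loader(
    train_ds,  # assumed merlin.io.Dataset
    batch_size=1024,
    transforms=[
        EmbeddingOperator(
            pretrained,
            lookup_key="item_category",
            embedding_name="pretrained_category_embeddings",
        )
    ],
)

# These entry points now accept a Loader (and its transforms) directly
item_embs = model.candidate_embeddings(loader, index=Tags.ITEM_ID, batch_size=1024)
topk_model = model.to_top_k_encoder(loader, k=20, batch_size=1024)
predictions = model.batch_predict(loader, batch_size=1024)
```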
gabrielspmoreira authored Jul 20, 2023
1 parent 52c89a4 commit a1d0be2
Showing 6 changed files with 160 additions and 27 deletions.
8 changes: 5 additions & 3 deletions examples/05-Retrieval-Model.ipynb
@@ -1616,7 +1616,8 @@
}
],
"source": [
"queries = model.query_embeddings(Dataset(user_features, schema=schema), batch_size=1024, index=Tags.USER_ID)\n",
"queries = model.query_embeddings(Dataset(user_features, schema=schema.select_by_tag(Tags.USER)), \n",
" batch_size=1024, index=Tags.USER_ID)\n",
"query_embs_df = queries.compute(scheduler=\"synchronous\").reset_index()"
]
},
@@ -1996,7 +1997,8 @@
}
],
"source": [
"item_embs = model.candidate_embeddings(Dataset(item_features, schema=schema), batch_size=1024, index=Tags.ITEM_ID)"
"item_embs = model.candidate_embeddings(Dataset(item_features, schema=schema.select_by_tag(Tags.ITEM)), \n",
" batch_size=1024, index=Tags.ITEM_ID)"
]
},
{
@@ -2460,7 +2462,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.8.10"
},
"merlin": {
"containers": [
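For context on the notebook fix above: `select_by_tag` narrows the schema to the columns tagged for one tower, so each export now runs on exactly the features that tower consumes (a short sketch):

```python
# Each tower should only see its own columns (sketch)
user_schema = schema.select_by_tag(Tags.USER)  # user_id plus user features
item_schema = schema.select_by_tag(Tags.ITEM)  # item_id plus item features
```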
53 changes: 40 additions & 13 deletions merlin/models/tf/core/encoder.py
@@ -27,6 +27,7 @@
from merlin.models.tf.core.prediction import TopKPrediction
from merlin.models.tf.inputs.base import InputBlockV2
from merlin.models.tf.inputs.embedding import CombinerType, EmbeddingTable
from merlin.models.tf.loader import Loader
from merlin.models.tf.models.base import BaseModel, get_output_schema
from merlin.models.tf.outputs.topk import TopKOutput
from merlin.models.tf.transforms.features import PrepareFeatures
@@ -84,7 +85,7 @@ def __init__(

def encode(
self,
dataset: merlin.io.Dataset,
dataset: Union[merlin.io.Dataset, Loader],
index: Union[str, ColumnSchema, Schema, Tags],
batch_size: int,
**kwargs,
@@ -93,7 +94,7 @@
Parameters
----------
dataset: merlin.io.Dataset
dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
The dataset to encode.
index: Union[str, ColumnSchema, Schema, Tags]
The index to use for encoding.
@@ -127,7 +128,7 @@

def batch_predict(
self,
dataset: merlin.io.Dataset,
dataset: Union[merlin.io.Dataset, Loader],
batch_size: int,
output_schema: Optional[Schema] = None,
index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None,
@@ -137,8 +138,8 @@
Parameters
----------
dataset: merlin.io.Dataset
Dataset to predict on.
dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
Dataset or Loader to predict on.
batch_size: int
Batch size to use for prediction.
@@ -161,18 +162,35 @@
raise ValueError("Only one column can be used as index")
index = index.first.name

dataset_schema = None
if hasattr(dataset, "schema"):
if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)):
dataset_schema = dataset.schema
data_output_schema = dataset_schema
if isinstance(dataset, Loader):
data_output_schema = dataset.output_schema
if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)):
raise ValueError(
f"Model schema {self.schema.column_names} does not match dataset schema"
+ f" {dataset.schema.column_names}"
+ f" {data_output_schema.column_names}"
)

loader_transforms = None
if isinstance(dataset, Loader):
loader_transforms = dataset.transforms
batch_size = dataset.batch_size
dataset = dataset.dataset

# Check if merlin-dataset is passed
if hasattr(dataset, "to_ddf"):
dataset = dataset.to_ddf()

model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)
model_encode = TFModelEncode(
self,
batch_size=batch_size,
loader_transforms=loader_transforms,
schema=dataset_schema,
**kwargs,
)

encode_kwargs = {}
if output_schema:
@@ -583,7 +601,7 @@ def encode_candidates(

def batch_predict(
self,
dataset: merlin.io.Dataset,
dataset: Union[merlin.io.Dataset, Loader],
batch_size: int,
output_schema: Optional[Schema] = None,
**kwargs,
@@ -592,8 +610,8 @@
Parameters
----------
dataset : merlin.io.Dataset
Raw queries features dataset
dataset : Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
Raw queries features dataset or Loader
batch_size : int
The number of queries to process at each prediction step
output_schema: Schema, optional
@@ -606,15 +624,24 @@
"""
from merlin.models.tf.utils.batch_utils import TFModelEncode

loader_transforms = None
if isinstance(dataset, Loader):
loader_transforms = dataset.transforms
batch_size = dataset.batch_size
dataset = dataset.dataset

dataset_schema = dataset.schema
dataset = dataset.to_ddf()

model_encode = TFModelEncode(
model=self,
batch_size=batch_size,
loader_transforms=loader_transforms,
schema=dataset_schema,
output_names=TopKPrediction.output_names(self.k),
**kwargs,
)

dataset = dataset.to_ddf()

encode_kwargs = {}
if output_schema:
encode_kwargs["filter_input_columns"] = output_schema.column_names
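With these changes, `Encoder.batch_predict` (and the top-k encoder's `batch_predict`) unpacks a `Loader`'s transforms, batch size, and wrapped dataset before building `TFModelEncode`. A hedged sketch of a call, reusing the `loader` from the sketch above; the explicit `batch_size` argument is superseded by the Loader's own when a Loader is passed:

```python
# Sketch: Dask-based batch prediction driven by a Loader with transforms
predictions = model.candidate_encoder.batch_predict(loader, batch_size=1024)
predictions_df = predictions.compute(scheduler="synchronous")  # mirrors the notebook usage
```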
34 changes: 26 additions & 8 deletions merlin/models/tf/models/base.py
@@ -1553,7 +1553,7 @@ def predict(
return out

def batch_predict(
self, dataset: merlin.io.Dataset, batch_size: int, **kwargs
self, dataset: Union[merlin.io.Dataset, Loader], batch_size: int, **kwargs
) -> merlin.io.Dataset:
"""Batched prediction using the Dask.
Parameters
@@ -1565,20 +1565,38 @@
Returns merlin.io.Dataset
-------
"""
dataset_schema = None
if hasattr(dataset, "schema"):
if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)):
dataset_schema = dataset.schema
data_output_schema = dataset_schema
if isinstance(dataset, Loader):
data_output_schema = dataset.output_schema

if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)):
raise ValueError(
f"Model schema {self.schema.column_names} does not match dataset schema"
+ f" {dataset.schema.column_names}"
+ f" {data_output_schema.column_names}"
)

loader_transforms = None
if isinstance(dataset, Loader):
loader_transforms = dataset.transforms
batch_size = dataset.batch_size
dataset = dataset.dataset

# Check if merlin-dataset is passed
if hasattr(dataset, "to_ddf"):
dataset = dataset.to_ddf()

from merlin.models.tf.utils.batch_utils import TFModelEncode

model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)
model_encode = TFModelEncode(
self,
batch_size=batch_size,
loader_transforms=loader_transforms,
schema=dataset_schema,
**kwargs,
)

# Processing a sample of the dataset with the model encoder
# to get the output dataframe dtypes
@@ -2510,20 +2528,20 @@ def query_embeddings(

def candidate_embeddings(
self,
dataset: Optional[merlin.io.Dataset] = None,
data: Optional[Union[merlin.io.Dataset, Loader]] = None,
index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None,
**kwargs,
) -> merlin.io.Dataset:
if self.has_candidate_encoder:
candidate = self.candidate_encoder

if dataset is not None and hasattr(candidate, "encode"):
return candidate.encode(dataset, index=index, **kwargs)
if data is not None and hasattr(candidate, "encode"):
return candidate.encode(data, index=index, **kwargs)

if hasattr(candidate, "to_dataset"):
return candidate.to_dataset(**kwargs)

return candidate.encode(dataset, index=index, **kwargs)
return candidate.encode(data, index=index, **kwargs)

if isinstance(self.last, (ContrastiveOutput, CategoricalOutput)):
return self.last.to_dataset()
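Note the parameter rename in `candidate_embeddings` (`dataset` to `data`), reflecting that it now takes either a `merlin.io.Dataset` or a `Loader`. A minimal sketch under the same assumptions as above (`item_features_ds` is a hypothetical dataset of unique item features):

```python
# Both input types are accepted after this change (sketch)
item_embs = model.candidate_embeddings(loader, index=Tags.ITEM_ID, batch_size=1024)
item_embs = model.candidate_embeddings(item_features_ds, index=Tags.ITEM_ID, batch_size=1024)
```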
8 changes: 6 additions & 2 deletions merlin/models/tf/utils/batch_utils.py
@@ -74,6 +74,7 @@ def __init__(
block_load_func: tp.Optional[tp.Callable[[str], Block]] = None,
schema: tp.Optional[Schema] = None,
output_concat_func=None,
loader_transforms=None,
):
save_path = save_path or tempfile.mkdtemp()
model.save(save_path)
@@ -95,7 +96,9 @@ def __init__(
super().__init__(
save_path,
output_names,
data_iterator_func=data_iterator_func(self.schema, batch_size=batch_size),
data_iterator_func=data_iterator_func(
self.schema, batch_size=batch_size, loader_transforms=loader_transforms
),
model_load_func=model_load_func,
model_encode_func=model_encode,
output_concat_func=output_concat_func,
@@ -172,14 +175,15 @@ def encode_output(output: tf.Tensor):
return output.numpy()


def data_iterator_func(schema, batch_size: int = 512):
def data_iterator_func(schema, batch_size: int = 512, loader_transforms=None):
import merlin.io.dataset

def data_iterator(dataset):
return Loader(
merlin.io.dataset.Dataset(dataset, schema=schema),
batch_size=batch_size,
shuffle=False,
transforms=loader_transforms,
)

return data_iterator
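
`TFModelEncode` now threads `loader_transforms` through to `data_iterator_func`, so every Dask partition processed during `batch_predict` is wrapped in a `Loader` that re-applies the same transforms used at training time. A sketch of what the closure produces, assuming an `embedding_op` like the `EmbeddingOperator` above and that iterating a `Loader` yields `(inputs, targets)` batches:

```python
# Build the per-partition iterator factory with transforms attached (sketch)
make_iterator = data_iterator_func(schema, batch_size=1024, loader_transforms=[embedding_op])

# During batch_predict, each partition is wrapped like this:
loader = make_iterator(partition_df)  # Loader over one partition, shuffle=False
for inputs, targets in loader:
    # inputs now include transform outputs, e.g. "pretrained_category_embeddings"
    ...
```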
19 changes: 19 additions & 0 deletions tests/integration/torch/__init__.py
@@ -0,0 +1,19 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest

pytest.importorskip("torch")
pytest.importorskip("pytorch_lightning")
65 changes: 64 additions & 1 deletion tests/unit/tf/models/test_retrieval.py
@@ -1,11 +1,12 @@
from pathlib import Path

import nvtabular as nvt
import numpy as np
import pytest
import tensorflow as tf

import merlin.models.tf as mm
from merlin.core.dispatch import make_df
from merlin.dataloader.ops.embeddings import EmbeddingOperator
from merlin.io import Dataset
from merlin.models.tf.metrics.topk import (
AvgPrecisionAt,
@@ -24,6 +25,8 @@


def test_two_tower_shared_embeddings():
nvt = pytest.importorskip("nvtabular")

train = make_df(
{
"user_id": [1, 3, 3, 4, 3, 1, 2, 4, 6, 7, 8, 9] * 100,
@@ -435,6 +438,66 @@ def test_two_tower_model_topk_evaluation(ecommerce_data: Dataset, run_eagerly):
assert all([metric >= 0 for metric in metrics.values()])


@pytest.mark.parametrize("run_eagerly", [True, False])
def test_two_tower_model_topk_evaluation_with_pretrained_emb(music_streaming_data, run_eagerly):
music_streaming_data.schema = music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM])

cardinality = music_streaming_data.schema["item_category"].int_domain.max + 1
pretrained_embedding = np.random.rand(cardinality, 12)

loader_transforms = [
EmbeddingOperator(
pretrained_embedding,
lookup_key="item_category",
embedding_name="pretrained_category_embeddings",
),
]
loader = mm.Loader(
music_streaming_data,
schema=music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM]),
batch_size=10,
transforms=loader_transforms,
)
schema = loader.output_schema

pretrained_embeddings = mm.PretrainedEmbeddings(
schema.select_by_tag(Tags.EMBEDDING),
output_dims=16,
)

schema = loader.output_schema

query_input = mm.InputBlockV2(schema.select_by_tag(Tags.USER))
query = mm.Encoder(query_input, mm.MLPBlock([4], no_activation_last_layer=True))
candidate_input = mm.InputBlockV2(
schema.select_by_tag(Tags.ITEM), pretrained_embeddings=pretrained_embeddings
)
candidate = mm.Encoder(candidate_input, mm.MLPBlock([4], no_activation_last_layer=True))
model = mm.TwoTowerModelV2(
query,
candidate,
negative_samplers=["in-batch"],
)
model.compile(optimizer="adam", run_eagerly=run_eagerly)
_ = testing_utils.model_test(model, loader)

# Top-K evaluation
candidate_features_data = unique_rows_by_features(music_streaming_data, Tags.ITEM, Tags.ITEM_ID)
loader_candidates = mm.Loader(
candidate_features_data,
batch_size=16,
transforms=loader_transforms,
)

topk_model = model.to_top_k_encoder(loader_candidates, k=20, batch_size=16)
topk_model.compile(run_eagerly=run_eagerly)

loader = mm.Loader(music_streaming_data, batch_size=32).map(mm.ToTarget(schema, "item_id"))

metrics = topk_model.evaluate(loader, return_dict=True)
assert all([metric >= 0 for metric in metrics.values()])


@pytest.mark.parametrize("run_eagerly", [True, False])
@pytest.mark.parametrize("logits_pop_logq_correction", [True, False])
@pytest.mark.parametrize("loss", ["categorical_crossentropy", "bpr-max", "binary_crossentropy"])
