From 40e082c2be84bfcf02887339bc98f7ea3c58e908 Mon Sep 17 00:00:00 2001 From: markstur Date: Wed, 4 Oct 2023 21:45:08 -0700 Subject: [PATCH 01/17] Adding reranker Signed-off-by: markstur --- caikit_nlp/data_model/__init__.py | 1 + caikit_nlp/data_model/reranker.py | 57 ++++++++ caikit_nlp/modules/reranker/__init__.py | 15 ++ caikit_nlp/modules/reranker/rerank.py | 160 +++++++++++++++++++++ caikit_nlp/modules/reranker/rerank_task.py | 33 +++++ tests/data_model/test_reranker.py | 89 ++++++++++++ 6 files changed, 355 insertions(+) create mode 100644 caikit_nlp/data_model/reranker.py create mode 100644 caikit_nlp/modules/reranker/__init__.py create mode 100644 caikit_nlp/modules/reranker/rerank.py create mode 100644 caikit_nlp/modules/reranker/rerank_task.py create mode 100644 tests/data_model/test_reranker.py diff --git a/caikit_nlp/data_model/__init__.py b/caikit_nlp/data_model/__init__.py index dcbf0585..9354d875 100644 --- a/caikit_nlp/data_model/__init__.py +++ b/caikit_nlp/data_model/__init__.py @@ -18,3 +18,4 @@ from . import embedding_vectors, generation from .embedding_vectors import * from .generation import * +from .reranker import * diff --git a/caikit_nlp/data_model/reranker.py b/caikit_nlp/data_model/reranker.py new file mode 100644 index 00000000..2ca11b22 --- /dev/null +++ b/caikit_nlp/data_model/reranker.py @@ -0,0 +1,57 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from caikit.core.data_model.json_dict import JsonDict +from caikit.core import ( + dataobject, + DataObjectBase, +) + +from typing import List, Dict + + +@dataobject() +class RerankDocument(DataObjectBase): + """An input document with key of text else _text else empty string used for comparison""" + document: Dict[str, str] # TODO: get any JsonDict working for input + + +@dataobject() +class RerankDocuments(DataObjectBase): + """An input list of documents""" + documents: List[RerankDocument] + + @classmethod + def from_proto(cls, proto): + return cls([{"document": dict(d.document.items())} for d in proto.documents]) + + +@dataobject() +class RerankScore(DataObjectBase): + """The score for one document (one query)""" + document: JsonDict + corpus_id: int + score: float + + +@dataobject() +class RerankQueryResult(DataObjectBase): + """Result for one query in a rerank task""" + scores: List[RerankScore] + + +@dataobject() +class RerankPrediction(DataObjectBase): + """Result for a rerank task""" + results: List[RerankQueryResult] diff --git a/caikit_nlp/modules/reranker/__init__.py b/caikit_nlp/modules/reranker/__init__.py new file mode 100644 index 00000000..a3375577 --- /dev/null +++ b/caikit_nlp/modules/reranker/__init__.py @@ -0,0 +1,15 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .rerank import Rerank diff --git a/caikit_nlp/modules/reranker/rerank.py b/caikit_nlp/modules/reranker/rerank.py new file mode 100644 index 00000000..d8f5d2d1 --- /dev/null +++ b/caikit_nlp/modules/reranker/rerank.py @@ -0,0 +1,160 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from pathlib import Path +from typing import List +import os + +# Third Party +from sentence_transformers import SentenceTransformer +from sentence_transformers.util import semantic_search, normalize_embeddings, dot_score + +# First Party +from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module +from caikit.core.exceptions import error_handler +import alog + +from caikit_nlp.data_model.reranker import RerankPrediction, RerankQueryResult, RerankScore, RerankDocuments +from .rerank_task import RerankTask + +logger = alog.use_channel("") +error = error_handler.get(logger) + +HOME = Path.home() +DEFAULT_HF_MODEL = "sentence-transformers/all-MiniLM-L6-v2" + + +@module( + "00110203-0405-0607-0809-0a0b02dd0e0f", + "RerankerModule", + "0.0.1", + RerankTask, +) +class Rerank(ModuleBase): + + def __init__( + self, + model: SentenceTransformer, + ): + """Initialize + This function gets called by `.load` and `.train` function + which initializes this module. + """ + super().__init__() + self.model = model + + @classmethod + def load(cls, model_path: str) -> "Rerank": + """Load a model + + Args: + model_path: str + Path to the config dir of the model to be loaded. + + Returns: + Rerank + Instance of this class built from the model. + """ + + config = ModuleConfig.load(model_path) + load_path = config.get("model_artifacts") + + artifact_path = False + if load_path: + if os.path.isabs(load_path) and os.path.isdir(load_path): + artifact_path = load_path + else: + full_path = os.path.join(model_path, load_path) + if os.path.isdir(full_path): + artifact_path = full_path + + if not artifact_path: + artifact_path = config.get("hf_model", DEFAULT_HF_MODEL) + + return cls.bootstrap(artifact_path) + + def run(self, queries: List[str], documents: RerankDocuments, top_n: int = 10) -> RerankPrediction: + """Run inference on model. 
+ Args: + queries: List[str] + documents: RerankDocuments + top_n: int + Returns: + RerankPrediction + """ + + if len(queries) < 1: + return RerankPrediction() + + if len(documents.documents) < 1: + return RerankPrediction() + + if top_n < 1: + top_n = 10 # Default to 10 (instead of JSON default 0) + + # Using input document dicts so get "text" else "_text" else default to "" + doc_texts = [srd.document.get("text") or srd.document.get("_text", "") for srd in documents.documents] + + doc_embeddings = self.model.encode(doc_texts, convert_to_tensor=True) + doc_embeddings = doc_embeddings.to(self.model.device) + doc_embeddings = normalize_embeddings(doc_embeddings) + + query_embeddings = self.model.encode(queries, convert_to_tensor=True) + query_embeddings = query_embeddings.to(self.model.device) + query_embeddings = normalize_embeddings(query_embeddings) + + res = semantic_search(query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score) + + for r in res: + for x in r: + x['document'] = documents.documents[x['corpus_id']].document + + results = [RerankQueryResult([RerankScore(**x) for x in r]) for r in res] + + return RerankPrediction(results=results) + + @classmethod + def bootstrap(cls, base_model_path: str) -> "Rerank": + """Bootstrap a sentence-transformers model + + Args: + base_model_path: str + Path to the model to be loaded. + """ + model = SentenceTransformer( + base_model_path, + cache_folder=f"{HOME}/.cache/huggingface/sentence_transformers", + ) + return cls( + model=model, + ) + + def save(self, model_path: str): + """Save model in target path + + Args: + model_path: str + Path to store model artifact(s) + """ + saver = ModuleSaver( + self, + model_path=model_path, + ) + + # Extract object to be saved + with saver: + saver.update_config({"model_artifacts": "."}) + if self.model: # This condition allows for empty placeholders + self.model.save(model_path) diff --git a/caikit_nlp/modules/reranker/rerank_task.py b/caikit_nlp/modules/reranker/rerank_task.py new file mode 100644 index 00000000..9dbefd42 --- /dev/null +++ b/caikit_nlp/modules/reranker/rerank_task.py @@ -0,0 +1,33 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import alog +from caikit.core import TaskBase, task +from caikit.core.exceptions import error_handler +from caikit_nlp.data_model.reranker import RerankPrediction, RerankDocuments + +from typing import List + +logger = alog.use_channel("") +error = error_handler.get(logger) + +@task( + required_parameters={ + "documents": RerankDocuments, + "queries": List[str], + }, + output_type=RerankPrediction, +) +class RerankTask(TaskBase): + pass diff --git a/tests/data_model/test_reranker.py b/tests/data_model/test_reranker.py new file mode 100644 index 00000000..2e971403 --- /dev/null +++ b/tests/data_model/test_reranker.py @@ -0,0 +1,89 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test for reranker +""" + +# Local +from caikit_nlp import data_model as dm + +from typing import Dict +import random +import string + +## Setup ######################################################################### + +input_document = { + "text": "this is the input text", + "title": "some title attribute here", + "anything": "another string attribute", + "_text": "alternate _text here", +} + +# TODO: don't use random (could flake). Make sure to use a working document. +key = ''.join(random.choices(string.ascii_letters, k=20)) +value = ''.join(random.choices(string.printable, k=100)) +print(key) +print(value) +input_random_document = { + key: value +} + +input_documents = [input_document, input_random_document] +input_queries = [] + +## Tests ######################################################################## + + +def _compare_rerank_document_and_dict(rerank_doc: dm.RerankDocument, d: Dict): + assert isinstance(rerank_doc, dm.RerankDocument) + assert isinstance(rerank_doc.document, Dict) + assert isinstance(rerank_doc.document["text"], str) + assert rerank_doc.document["text"] == d["text"] + assert rerank_doc.document == d + + +def test_rerank_document(): + new_dm_from_init = dm.RerankDocument(document=input_document) + + _compare_rerank_document_and_dict(new_dm_from_init, input_document) + + # Test proto + proto_from_dm = new_dm_from_init.to_proto() + new_dm_from_proto = dm.RerankDocument.from_proto(proto_from_dm) + _compare_rerank_document_and_dict(new_dm_from_proto, input_document) + + # Test json + json_from_dm = new_dm_from_init.to_json() + new_dm_from_json = dm.RerankDocument.from_json(json_from_dm) + _compare_rerank_document_and_dict(new_dm_from_json, input_document) + + +def test_rerank_documents(): + in_docs = [dm.RerankDocument(doc) for doc in input_documents] + new_dm_from_init = dm.RerankDocuments(documents=in_docs) + assert isinstance(new_dm_from_init, dm.RerankDocuments) + out_docs = [d.document for d in new_dm_from_init.documents] + assert out_docs == input_documents + + # Test proto + proto_from_dm = new_dm_from_init.to_proto() + new_dm_from_proto = dm.RerankDocuments.from_proto(proto_from_dm) + out_docs = [d["document"] for d in new_dm_from_proto.documents] + assert out_docs == input_documents + + # Test json + json_from_dm = new_dm_from_init.to_json() + new_dm_from_json = dm.RerankDocuments.from_json(json_from_dm) + out_docs = [d["document"] for d in new_dm_from_json.documents] + assert out_docs == input_documents From 0acde21803ac20320af2d818c5f34f29a5301cd9 Mon Sep 17 00:00:00 2001 From: markstur Date: Wed, 11 Oct 2023 08:56:44 -0700 Subject: [PATCH 02/17] Simpler input model. Expand to any JsonDict. 
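This patch reworks the reranker input to take plain JSON dictionaries rather than wrapped data objects. A rough usage sketch of the resulting API, based on the diff below (the model name and document values here are illustrative assumptions, not part of the patch):

    # Illustrative only; any JSON-compatible dict can be used as a document.
    from caikit_nlp.modules.reranker import Rerank

    model = Rerank.bootstrap("sentence-transformers/all-MiniLM-L6-v2")
    prediction = model.run(
        queries=["Who is foo?"],
        documents=[
            {"text": "foo", "title": "a title", "int_test": 1},
            {"_text": "bar"},  # "_text" is the fallback field when "text" is absent
        ],
        top_n=2,
    )
    for result in prediction.results:   # one RerankQueryResult per query
        for score in result.scores:     # RerankScore: document, corpus_id, score
            print(score.corpus_id, score.score, score.document)
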
* Less data objects and more primitives * Fixes str,str limitation in the input JSON * Add tests * More ready for review changes Signed-off-by: markstur --- caikit_nlp/data_model/reranker.py | 40 ++++--- caikit_nlp/modules/reranker/__init__.py | 1 + caikit_nlp/modules/reranker/rerank.py | 115 +++++++++++-------- caikit_nlp/modules/reranker/rerank_task.py | 14 ++- tests/data_model/test_reranker.py | 59 +++------- tests/modules/reranker/test_rerank.py | 124 +++++++++++++++++++++ 6 files changed, 239 insertions(+), 114 deletions(-) create mode 100644 tests/modules/reranker/test_rerank.py diff --git a/caikit_nlp/data_model/reranker.py b/caikit_nlp/data_model/reranker.py index 2ca11b22..48692afd 100644 --- a/caikit_nlp/data_model/reranker.py +++ b/caikit_nlp/data_model/reranker.py @@ -12,46 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. -from caikit.core.data_model.json_dict import JsonDict -from caikit.core import ( - dataobject, - DataObjectBase, -) - -from typing import List, Dict +# Standard +from dataclasses import dataclass +from typing import List - -@dataobject() -class RerankDocument(DataObjectBase): - """An input document with key of text else _text else empty string used for comparison""" - document: Dict[str, str] # TODO: get any JsonDict working for input +# First Party +from caikit.core import DataObjectBase, dataobject +from caikit.core.data_model.json_dict import JsonDict -@dataobject() +@dataobject(package="caikit_data_model.caikit_nlp") +@dataclass class RerankDocuments(DataObjectBase): - """An input list of documents""" - documents: List[RerankDocument] + """An input list of JSON documents""" - @classmethod - def from_proto(cls, proto): - return cls([{"document": dict(d.document.items())} for d in proto.documents]) + documents: List[JsonDict] -@dataobject() +@dataobject(package="caikit_data_model.caikit_nlp") +@dataclass class RerankScore(DataObjectBase): """The score for one document (one query)""" + document: JsonDict corpus_id: int score: float -@dataobject() +@dataobject(package="caikit_data_model.caikit_nlp") +@dataclass class RerankQueryResult(DataObjectBase): """Result for one query in a rerank task""" + scores: List[RerankScore] -@dataobject() +@dataobject(package="caikit_data_model.caikit_nlp") +@dataclass class RerankPrediction(DataObjectBase): """Result for a rerank task""" + results: List[RerankQueryResult] diff --git a/caikit_nlp/modules/reranker/__init__.py b/caikit_nlp/modules/reranker/__init__.py index a3375577..e5a6d2bd 100644 --- a/caikit_nlp/modules/reranker/__init__.py +++ b/caikit_nlp/modules/reranker/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Local from .rerank import Rerank diff --git a/caikit_nlp/modules/reranker/rerank.py b/caikit_nlp/modules/reranker/rerank.py index d8f5d2d1..08843ef0 100644 --- a/caikit_nlp/modules/reranker/rerank.py +++ b/caikit_nlp/modules/reranker/rerank.py @@ -13,28 +13,30 @@ # limitations under the License. 
# Standard -from pathlib import Path -from typing import List +from typing import List, Type, Optional import os # Third Party from sentence_transformers import SentenceTransformer -from sentence_transformers.util import semantic_search, normalize_embeddings, dot_score +from sentence_transformers.util import dot_score, normalize_embeddings, semantic_search # First Party from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module from caikit.core.exceptions import error_handler +from caikit.core.data_model.json_dict import JsonDict import alog -from caikit_nlp.data_model.reranker import RerankPrediction, RerankQueryResult, RerankScore, RerankDocuments +# Local from .rerank_task import RerankTask +from caikit_nlp.data_model.reranker import ( + RerankPrediction, + RerankQueryResult, + RerankScore, +) logger = alog.use_channel("") error = error_handler.get(logger) -HOME = Path.home() -DEFAULT_HF_MODEL = "sentence-transformers/all-MiniLM-L6-v2" - @module( "00110203-0405-0607-0809-0a0b02dd0e0f", @@ -44,9 +46,13 @@ ) class Rerank(ModuleBase): + _MODEL_ARTIFACTS_CONFIG_KEY = "artifacts_path" + _MODEL_ARTIFACTS_CONFIG_DEFAULT = "artifacts" + _MODEL_HF_HUB_KEY = "hf_model" + def __init__( - self, - model: SentenceTransformer, + self, + model: SentenceTransformer, ): """Initialize This function gets called by `.load` and `.train` function @@ -56,7 +62,7 @@ def __init__( self.model = model @classmethod - def load(cls, model_path: str) -> "Rerank": + def load(cls, model_path: str) -> Type["Rerank"]: """Load a model Args: @@ -69,43 +75,56 @@ def load(cls, model_path: str) -> "Rerank": """ config = ModuleConfig.load(model_path) - load_path = config.get("model_artifacts") - - artifact_path = False - if load_path: - if os.path.isabs(load_path) and os.path.isdir(load_path): - artifact_path = load_path - else: - full_path = os.path.join(model_path, load_path) - if os.path.isdir(full_path): - artifact_path = full_path - - if not artifact_path: - artifact_path = config.get("hf_model", DEFAULT_HF_MODEL) - return cls.bootstrap(artifact_path) - - def run(self, queries: List[str], documents: RerankDocuments, top_n: int = 10) -> RerankPrediction: + artifacts_path = config.get(cls._MODEL_ARTIFACTS_CONFIG_KEY) + if artifacts_path: + # artifacts_path is used to find the model artifacts (can be absolute or relative to model config) + model_name_or_path = os.path.abspath(os.path.join(model_path, artifacts_path)) + error.dir_check("", model_name_or_path) + else: + # If no artifacts_path, look for hf_model Hugging Face model by name (or path) + model_name_or_path = config.get(cls._MODEL_HF_HUB_KEY) + error.value_check( + "", + model_name_or_path, + ValueError(f"Model config missing '{cls._MODEL_ARTIFACTS_CONFIG_KEY}' or '{cls._MODEL_HF_HUB_KEY}'") + ) + + return cls.bootstrap(model_name_or_path=model_name_or_path) + + def run( + self, queries: List[str], + documents: List[JsonDict], + top_n: Optional[int] = None, + ) -> RerankPrediction: """Run inference on model. 
Args: queries: List[str] - documents: RerankDocuments - top_n: int + documents: List[JsonDict] + top_n: Optional[int] Returns: RerankPrediction """ + error.type_check( + "", + list, + queries=queries, documents=documents, + ) + if len(queries) < 1: return RerankPrediction() - if len(documents.documents) < 1: + if len(documents) < 1: return RerankPrediction() - if top_n < 1: - top_n = 10 # Default to 10 (instead of JSON default 0) + if top_n is None or top_n < 1: + top_n = len(documents) # Using input document dicts so get "text" else "_text" else default to "" - doc_texts = [srd.document.get("text") or srd.document.get("_text", "") for srd in documents.documents] + doc_texts = [ + srd.get("text") or srd.get("_text", "") for srd in documents + ] doc_embeddings = self.model.encode(doc_texts, convert_to_tensor=True) doc_embeddings = doc_embeddings.to(self.model.device) @@ -115,33 +134,29 @@ def run(self, queries: List[str], documents: RerankDocuments, top_n: int = 10) - query_embeddings = query_embeddings.to(self.model.device) query_embeddings = normalize_embeddings(query_embeddings) - res = semantic_search(query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score) + res = semantic_search( + query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score + ) for r in res: for x in r: - x['document'] = documents.documents[x['corpus_id']].document + x["document"] = documents[x["corpus_id"]] # TODO: .document results = [RerankQueryResult([RerankScore(**x) for x in r]) for r in res] return RerankPrediction(results=results) @classmethod - def bootstrap(cls, base_model_path: str) -> "Rerank": + def bootstrap(cls, model_name_or_path: str) -> "Rerank": """Bootstrap a sentence-transformers model Args: - base_model_path: str - Path to the model to be loaded. + model_name_or_path: str + Model name (Hugging Face hub) or path to model to load. """ - model = SentenceTransformer( - base_model_path, - cache_folder=f"{HOME}/.cache/huggingface/sentence_transformers", - ) - return cls( - model=model, - ) + return cls(model=(SentenceTransformer(model_name_or_path=model_name_or_path))) - def save(self, model_path: str): + def save(self, model_path: str, *args, **kwargs): """Save model in target path Args: @@ -153,8 +168,12 @@ def save(self, model_path: str): model_path=model_path, ) - # Extract object to be saved + # Save artifacts with saver: - saver.update_config({"model_artifacts": "."}) + artifacts_path = saver.config.get(self._MODEL_ARTIFACTS_CONFIG_KEY) + if not artifacts_path: + artifacts_path = self._MODEL_ARTIFACTS_CONFIG_DEFAULT + saver.update_config({self._MODEL_ARTIFACTS_CONFIG_KEY: artifacts_path}) if self.model: # This condition allows for empty placeholders - self.model.save(model_path) + artifacts_abspath = os.path.abspath(os.path.join(model_path, artifacts_path)) + self.model.save(artifacts_abspath, create_model_card=True) diff --git a/caikit_nlp/modules/reranker/rerank_task.py b/caikit_nlp/modules/reranker/rerank_task.py index 9dbefd42..2f2e5f93 100644 --- a/caikit_nlp/modules/reranker/rerank_task.py +++ b/caikit_nlp/modules/reranker/rerank_task.py @@ -12,19 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import alog +# Standard +from typing import List + +# First Party from caikit.core import TaskBase, task from caikit.core.exceptions import error_handler -from caikit_nlp.data_model.reranker import RerankPrediction, RerankDocuments +import alog -from typing import List +# Local +from caikit_nlp.data_model.reranker import RerankPrediction +from caikit.core.data_model.json_dict import JsonDict logger = alog.use_channel("") error = error_handler.get(logger) + @task( required_parameters={ - "documents": RerankDocuments, + "documents": List[JsonDict], "queries": List[str], }, output_type=RerankPrediction, diff --git a/tests/data_model/test_reranker.py b/tests/data_model/test_reranker.py index 2e971403..bf502d46 100644 --- a/tests/data_model/test_reranker.py +++ b/tests/data_model/test_reranker.py @@ -14,29 +14,33 @@ """Test for reranker """ -# Local -from caikit_nlp import data_model as dm - -from typing import Dict +# Standard import random import string +# Local +from caikit_nlp import data_model as dm + ## Setup ######################################################################### input_document = { "text": "this is the input text", + "_text": "alternate _text here", "title": "some title attribute here", "anything": "another string attribute", - "_text": "alternate _text here", + "str_test": "test string", + "int_test": 1234, + "float_test": 9876.4321, } # TODO: don't use random (could flake). Make sure to use a working document. -key = ''.join(random.choices(string.ascii_letters, k=20)) -value = ''.join(random.choices(string.printable, k=100)) -print(key) -print(value) +key = "".join(random.choices(string.ascii_letters, k=20)) +value = "".join(random.choices(string.printable, k=100)) input_random_document = { - key: value + "text": "".join(random.choices(string.printable, k=100)), + "random_str": "".join(random.choices(string.printable, k=100)), + "random_int": random.randint(-99999, 99999), + "random_float": random.uniform(-99999, 99999), } input_documents = [input_document, input_random_document] @@ -45,45 +49,18 @@ ## Tests ######################################################################## -def _compare_rerank_document_and_dict(rerank_doc: dm.RerankDocument, d: Dict): - assert isinstance(rerank_doc, dm.RerankDocument) - assert isinstance(rerank_doc.document, Dict) - assert isinstance(rerank_doc.document["text"], str) - assert rerank_doc.document["text"] == d["text"] - assert rerank_doc.document == d - - -def test_rerank_document(): - new_dm_from_init = dm.RerankDocument(document=input_document) - - _compare_rerank_document_and_dict(new_dm_from_init, input_document) - - # Test proto - proto_from_dm = new_dm_from_init.to_proto() - new_dm_from_proto = dm.RerankDocument.from_proto(proto_from_dm) - _compare_rerank_document_and_dict(new_dm_from_proto, input_document) - - # Test json - json_from_dm = new_dm_from_init.to_json() - new_dm_from_json = dm.RerankDocument.from_json(json_from_dm) - _compare_rerank_document_and_dict(new_dm_from_json, input_document) - - def test_rerank_documents(): - in_docs = [dm.RerankDocument(doc) for doc in input_documents] + in_docs = input_documents new_dm_from_init = dm.RerankDocuments(documents=in_docs) assert isinstance(new_dm_from_init, dm.RerankDocuments) - out_docs = [d.document for d in new_dm_from_init.documents] - assert out_docs == input_documents + assert new_dm_from_init.documents == input_documents # Test proto proto_from_dm = new_dm_from_init.to_proto() new_dm_from_proto = dm.RerankDocuments.from_proto(proto_from_dm) - out_docs = 
[d["document"] for d in new_dm_from_proto.documents] - assert out_docs == input_documents + assert new_dm_from_proto.documents == input_documents # Test json json_from_dm = new_dm_from_init.to_json() new_dm_from_json = dm.RerankDocuments.from_json(json_from_dm) - out_docs = [d["document"] for d in new_dm_from_json.documents] - assert out_docs == input_documents + assert new_dm_from_json.documents == input_documents diff --git a/tests/modules/reranker/test_rerank.py b/tests/modules/reranker/test_rerank.py new file mode 100644 index 00000000..2b473196 --- /dev/null +++ b/tests/modules/reranker/test_rerank.py @@ -0,0 +1,124 @@ +"""Tests for sequence classification module +""" +# Standard +import tempfile +from typing import List + +import pytest + +from caikit_nlp import RerankQueryResult, RerankScore +# Local +from caikit_nlp.data_model import RerankPrediction +from caikit_nlp.modules.reranker import Rerank +from tests.fixtures import SEQ_CLASS_MODEL + +## Setup ######################################################################## + +# Bootstrapped sequence classification model for reusability across tests +# .bootstrap is tested separately in the first test +BOOTSTRAPPED_MODEL = Rerank.bootstrap(SEQ_CLASS_MODEL) + +QUERIES: List[str] = [ + "Who is foo?", + "Where is the bar?", +] + +DOCS = [ + { + "text": "foo", + "title": "title or whatever", + "str_test": "test string", + "int_test": 1, + "float_test": 1.11, + "score": 99999, + }, + { + "_text": "bar", + "title": "title 2", + }, + { + "text": "foo and bar", + }, + { + "_text": "Where is the bar", + "another": "something else", + }, +] + +## Tests ######################################################################## + + +def test_bootstrap(): + assert isinstance(BOOTSTRAPPED_MODEL, Rerank), "bootstrap error" + + +@pytest.mark.parametrize("queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})]) +def test_run_type_error(queries, docs): + """type error check ensures params are lists and not just 1 string or just one doc (for example)""" + with pytest.raises(TypeError): + BOOTSTRAPPED_MODEL.run(queries=queries, documents=docs) + pytest.fail("Should not reach here.") + + +def test_run_no_type_error(): + """no type error with list of string queries and list of dict documents""" + BOOTSTRAPPED_MODEL.run(queries=QUERIES, documents=DOCS) + + +@pytest.mark.parametrize("top_n, expected", [(1, 1), (2, 2), (None, len(DOCS)), (-1, len(DOCS)), (0, len(DOCS)), (9999, len(DOCS))]) +def test_run_top_n(top_n, expected): + """no type error with list of string queries and list of dict documents""" + res = BOOTSTRAPPED_MODEL.run(queries=QUERIES, documents=DOCS, top_n=top_n) + assert isinstance(res, RerankPrediction) + assert len(res.results) == len(QUERIES) + for result in res.results: + assert len(result.scores) == expected + + +def test_save_and_load_and_run_model(): + """Save and load and run a model""" + with tempfile.TemporaryDirectory() as model_dir: + BOOTSTRAPPED_MODEL.save(model_dir) + new_model = Rerank.load(model_dir) + + assert isinstance(new_model, Rerank), "save and load error" + assert new_model != BOOTSTRAPPED_MODEL, "did not load a new model" + + top_n = 2 + rerank_result = new_model.run(queries=QUERIES, documents=DOCS, top_n=top_n) + assert isinstance(rerank_result, RerankPrediction) + + results = rerank_result.results + assert isinstance(results, list) + assert ( + len(results) == 2 == len(QUERIES) + ) # 2 queries yields 2 result(s) + + # Collect some of the pass-through extras to verify we can do some types + 
str_test = None + int_test = None + float_test = None + + for result in results: + assert isinstance(result, RerankQueryResult) + scores = result.scores + assert isinstance(scores, list) + assert len(scores) == top_n + for score in scores: + assert isinstance(score, RerankScore) + assert isinstance(score.score, float) + assert isinstance(score.corpus_id, int) + assert score.document == DOCS[score.corpus_id] + + # Test pass-through score (None or 9999) is independent of the result score + assert score.score != score.document.get("score"), "unexpected passthru score same as result score" + + # Gather various type test values + str_test = score.document.get("str_test", str_test) + int_test = score.document.get("int_test", int_test) + float_test = score.document.get("float_test", float_test) + + assert type(str_test) == str, "passthru str value type check" + assert type(int_test) == int, "passthru int value type check" + assert type(float_test) == float, "passthru float value type check" + From 956e3942a936a0b5a2f154c811171771df96c4f4 Mon Sep 17 00:00:00 2001 From: markstur Date: Fri, 13 Oct 2023 23:35:07 -0700 Subject: [PATCH 03/17] More tests. Updates from embeddings feedback. * Tests * Work on save() Signed-off-by: markstur --- caikit_nlp/modules/reranker/rerank.py | 87 +++++++++++++--------- caikit_nlp/modules/reranker/rerank_task.py | 2 +- tests/data_model/test_reranker.py | 79 ++++++++++++++++---- tests/modules/reranker/test_rerank.py | 40 +++++++--- 4 files changed, 146 insertions(+), 62 deletions(-) diff --git a/caikit_nlp/modules/reranker/rerank.py b/caikit_nlp/modules/reranker/rerank.py index 08843ef0..a40676e1 100644 --- a/caikit_nlp/modules/reranker/rerank.py +++ b/caikit_nlp/modules/reranker/rerank.py @@ -13,7 +13,7 @@ # limitations under the License. 
# Standard -from typing import List, Type, Optional +from typing import List, Optional import os # Third Party @@ -22,8 +22,8 @@ # First Party from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module -from caikit.core.exceptions import error_handler from caikit.core.data_model.json_dict import JsonDict +from caikit.core.exceptions import error_handler import alog # Local @@ -46,9 +46,9 @@ ) class Rerank(ModuleBase): - _MODEL_ARTIFACTS_CONFIG_KEY = "artifacts_path" - _MODEL_ARTIFACTS_CONFIG_DEFAULT = "artifacts" - _MODEL_HF_HUB_KEY = "hf_model" + _ARTIFACTS_PATH_KEY = "artifacts_path" + _ARTIFACTS_PATH_DEFAULT = "artifacts" + _HF_HUB_KEY = "hf_model" def __init__( self, @@ -62,7 +62,7 @@ def __init__( self.model = model @classmethod - def load(cls, model_path: str) -> Type["Rerank"]: + def load(cls, model_path: str) -> "Rerank": """Load a model Args: @@ -76,26 +76,30 @@ def load(cls, model_path: str) -> Type["Rerank"]: config = ModuleConfig.load(model_path) - artifacts_path = config.get(cls._MODEL_ARTIFACTS_CONFIG_KEY) + artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY) if artifacts_path: - # artifacts_path is used to find the model artifacts (can be absolute or relative to model config) - model_name_or_path = os.path.abspath(os.path.join(model_path, artifacts_path)) + model_name_or_path = os.path.abspath( + os.path.join(model_path, artifacts_path) + ) error.dir_check("", model_name_or_path) else: # If no artifacts_path, look for hf_model Hugging Face model by name (or path) - model_name_or_path = config.get(cls._MODEL_HF_HUB_KEY) + model_name_or_path = config.get(cls._HF_HUB_KEY) error.value_check( "", model_name_or_path, - ValueError(f"Model config missing '{cls._MODEL_ARTIFACTS_CONFIG_KEY}' or '{cls._MODEL_HF_HUB_KEY}'") + ValueError( + f"Model config missing '{cls._ARTIFACTS_CONFIG_KEY}' or '{cls._HF_HUB_KEY}'" + ), ) return cls.bootstrap(model_name_or_path=model_name_or_path) def run( - self, queries: List[str], - documents: List[JsonDict], - top_n: Optional[int] = None, + self, + queries: List[str], + documents: List[JsonDict], + top_n: Optional[int] = None, ) -> RerankPrediction: """Run inference on model. Args: @@ -109,22 +113,18 @@ def run( error.type_check( "", list, - queries=queries, documents=documents, + queries=queries, + documents=documents, ) - if len(queries) < 1: - return RerankPrediction() - - if len(documents) < 1: - return RerankPrediction() + if len(queries) < 1 or len(documents) < 1: + return RerankPrediction([]) if top_n is None or top_n < 1: top_n = len(documents) # Using input document dicts so get "text" else "_text" else default to "" - doc_texts = [ - srd.get("text") or srd.get("_text", "") for srd in documents - ] + doc_texts = [srd.get("text") or srd.get("_text", "") for srd in documents] doc_embeddings = self.model.encode(doc_texts, convert_to_tensor=True) doc_embeddings = doc_embeddings.to(self.model.device) @@ -140,7 +140,7 @@ def run( for r in res: for x in r: - x["document"] = documents[x["corpus_id"]] # TODO: .document + x["document"] = documents[x["corpus_id"]] results = [RerankQueryResult([RerankScore(**x) for x in r]) for r in res] @@ -154,26 +154,45 @@ def bootstrap(cls, model_name_or_path: str) -> "Rerank": model_name_or_path: str Model name (Hugging Face hub) or path to model to load. 
""" - return cls(model=(SentenceTransformer(model_name_or_path=model_name_or_path))) + return cls(model=SentenceTransformer(model_name_or_path=model_name_or_path)) def save(self, model_path: str, *args, **kwargs): - """Save model in target path + """Save model using config in model_path Args: model_path: str - Path to store model artifact(s) + Path to model config """ + error.type_check("", str, model_path=model_path) + error.value_check( + "", + model_path is not None and model_path.strip(), + f"model_path '{model_path}' is invalid", + ) + + model_path = os.path.abspath( + model_path.strip() + ) # No leading/trailing spaces sneaky weirdness + + if os.path.exists(model_path): + error( + "", + FileExistsError(f"model_path '{model_path}' already exists"), + ) + saver = ModuleSaver( - self, + module=self, model_path=model_path, ) - # Save artifacts + # Save update config (artifacts_path) and save artifacts with saver: - artifacts_path = saver.config.get(self._MODEL_ARTIFACTS_CONFIG_KEY) + artifacts_path = saver.config.get(self._ARTIFACTS_PATH_KEY) if not artifacts_path: - artifacts_path = self._MODEL_ARTIFACTS_CONFIG_DEFAULT - saver.update_config({self._MODEL_ARTIFACTS_CONFIG_KEY: artifacts_path}) + artifacts_path = self._ARTIFACTS_PATH_DEFAULT + saver.update_config({self._ARTIFACTS_PATH_KEY: artifacts_path}) if self.model: # This condition allows for empty placeholders - artifacts_abspath = os.path.abspath(os.path.join(model_path, artifacts_path)) - self.model.save(artifacts_abspath, create_model_card=True) + artifacts_path = os.path.abspath( + os.path.join(model_path, artifacts_path) + ) + self.model.save(artifacts_path, create_model_card=True) diff --git a/caikit_nlp/modules/reranker/rerank_task.py b/caikit_nlp/modules/reranker/rerank_task.py index 2f2e5f93..9c7379b0 100644 --- a/caikit_nlp/modules/reranker/rerank_task.py +++ b/caikit_nlp/modules/reranker/rerank_task.py @@ -17,12 +17,12 @@ # First Party from caikit.core import TaskBase, task +from caikit.core.data_model.json_dict import JsonDict from caikit.core.exceptions import error_handler import alog # Local from caikit_nlp.data_model.reranker import RerankPrediction -from caikit.core.data_model.json_dict import JsonDict logger = alog.use_channel("") error = error_handler.get(logger) diff --git a/tests/data_model/test_reranker.py b/tests/data_model/test_reranker.py index bf502d46..f20df204 100644 --- a/tests/data_model/test_reranker.py +++ b/tests/data_model/test_reranker.py @@ -18,6 +18,9 @@ import random import string +# Third Party +import pytest + # Local from caikit_nlp import data_model as dm @@ -33,7 +36,6 @@ "float_test": 9876.4321, } -# TODO: don't use random (could flake). Make sure to use a working document. 
key = "".join(random.choices(string.ascii_letters, k=20)) value = "".join(random.choices(string.printable, k=100)) input_random_document = { @@ -44,23 +46,70 @@ } input_documents = [input_document, input_random_document] -input_queries = [] + +input_score = { + "document": input_document, + "corpus_id": 1234, + "score": 9876.54321, +} + +input_random_score = { + "document": input_random_document, + "corpus_id": random.randint(-99999, 99999), + "score": random.uniform(-99999, 99999), +} + +input_random_score_3 = { + "document": {"text": "random foo3"}, + "corpus_id": random.randint(-99999, 99999), + "score": random.uniform(-99999, 99999), +} + +input_scores = [dm.RerankScore(**input_score), dm.RerankScore(**input_random_score)] +input_scores2 = [ + dm.RerankScore(**input_random_score), + dm.RerankScore(**input_random_score_3), +] +input_results = [ + dm.RerankQueryResult(scores=input_scores), + dm.RerankQueryResult(scores=input_scores2), +] + ## Tests ######################################################################## -def test_rerank_documents(): - in_docs = input_documents - new_dm_from_init = dm.RerankDocuments(documents=in_docs) - assert isinstance(new_dm_from_init, dm.RerankDocuments) - assert new_dm_from_init.documents == input_documents +@pytest.mark.parametrize( + "data_object, inputs", + [ + (dm.RerankDocuments, {"documents": input_documents}), + (dm.RerankScore, input_score), + (dm.RerankScore, input_random_score), + (dm.RerankQueryResult, {"scores": input_scores}), + (dm.RerankPrediction, {"results": input_results}), + ], +) +def test_data_object(data_object, inputs): + # Init data object + new_do_from_init = data_object(**inputs) + assert isinstance(new_do_from_init, data_object) + assert_fields_match(new_do_from_init, inputs) + + # Test to/from proto + proto_from_dm = new_do_from_init.to_proto() + new_do_from_proto = data_object.from_proto(proto_from_dm) + assert isinstance(new_do_from_proto, data_object) + assert_fields_match(new_do_from_proto, inputs) + assert new_do_from_init == new_do_from_proto + + # Test to/from json + json_from_dm = new_do_from_init.to_json() + new_do_from_json = data_object.from_json(json_from_dm) + assert isinstance(new_do_from_json, data_object) + assert_fields_match(new_do_from_json, inputs) + assert new_do_from_init == new_do_from_json - # Test proto - proto_from_dm = new_dm_from_init.to_proto() - new_dm_from_proto = dm.RerankDocuments.from_proto(proto_from_dm) - assert new_dm_from_proto.documents == input_documents - # Test json - json_from_dm = new_dm_from_init.to_json() - new_dm_from_json = dm.RerankDocuments.from_json(json_from_dm) - assert new_dm_from_json.documents == input_documents +def assert_fields_match(data_object, inputs): + for k, v in inputs.items(): + assert getattr(data_object, k) == inputs[k] diff --git a/tests/modules/reranker/test_rerank.py b/tests/modules/reranker/test_rerank.py index 2b473196..8a2cb4bf 100644 --- a/tests/modules/reranker/test_rerank.py +++ b/tests/modules/reranker/test_rerank.py @@ -1,13 +1,15 @@ """Tests for sequence classification module """ # Standard -import tempfile from typing import List +import os +import tempfile +# Third Party import pytest -from caikit_nlp import RerankQueryResult, RerankScore # Local +from caikit_nlp import RerankQueryResult, RerankScore from caikit_nlp.data_model import RerankPrediction from caikit_nlp.modules.reranker import Rerank from tests.fixtures import SEQ_CLASS_MODEL @@ -52,7 +54,9 @@ def test_bootstrap(): assert isinstance(BOOTSTRAPPED_MODEL, Rerank), "bootstrap 
error" -@pytest.mark.parametrize("queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})]) +@pytest.mark.parametrize( + "queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})] +) def test_run_type_error(queries, docs): """type error check ensures params are lists and not just 1 string or just one doc (for example)""" with pytest.raises(TypeError): @@ -65,7 +69,17 @@ def test_run_no_type_error(): BOOTSTRAPPED_MODEL.run(queries=QUERIES, documents=DOCS) -@pytest.mark.parametrize("top_n, expected", [(1, 1), (2, 2), (None, len(DOCS)), (-1, len(DOCS)), (0, len(DOCS)), (9999, len(DOCS))]) +@pytest.mark.parametrize( + "top_n, expected", + [ + (1, 1), + (2, 2), + (None, len(DOCS)), + (-1, len(DOCS)), + (0, len(DOCS)), + (9999, len(DOCS)), + ], +) def test_run_top_n(top_n, expected): """no type error with list of string queries and list of dict documents""" res = BOOTSTRAPPED_MODEL.run(queries=QUERIES, documents=DOCS, top_n=top_n) @@ -77,9 +91,12 @@ def test_run_top_n(top_n, expected): def test_save_and_load_and_run_model(): """Save and load and run a model""" - with tempfile.TemporaryDirectory() as model_dir: - BOOTSTRAPPED_MODEL.save(model_dir) - new_model = Rerank.load(model_dir) + + model_id = "model_id" + with tempfile.TemporaryDirectory(suffix="-1st") as model_dir: + model_path = os.path.join(model_dir, model_id) + BOOTSTRAPPED_MODEL.save(model_path) + new_model = Rerank.load(model_path) assert isinstance(new_model, Rerank), "save and load error" assert new_model != BOOTSTRAPPED_MODEL, "did not load a new model" @@ -90,9 +107,7 @@ def test_save_and_load_and_run_model(): results = rerank_result.results assert isinstance(results, list) - assert ( - len(results) == 2 == len(QUERIES) - ) # 2 queries yields 2 result(s) + assert len(results) == 2 == len(QUERIES) # 2 queries yields 2 result(s) # Collect some of the pass-through extras to verify we can do some types str_test = None @@ -111,7 +126,9 @@ def test_save_and_load_and_run_model(): assert score.document == DOCS[score.corpus_id] # Test pass-through score (None or 9999) is independent of the result score - assert score.score != score.document.get("score"), "unexpected passthru score same as result score" + assert score.score != score.document.get( + "score" + ), "unexpected passthru score same as result score" # Gather various type test values str_test = score.document.get("str_test", str_test) @@ -121,4 +138,3 @@ def test_save_and_load_and_run_model(): assert type(str_test) == str, "passthru str value type check" assert type(int_test) == int, "passthru int value type check" assert type(float_test) == float, "passthru float value type check" - From e7e012ba650d22e34f197cf0972db58d0d324a23 Mon Sep 17 00:00:00 2001 From: markstur Date: Sun, 15 Oct 2023 23:09:02 -0700 Subject: [PATCH 04/17] Fix wrong var in error message. Add tests for coverage. 
* Error message had wrong var in f-string message * Added test to catch that mistake * Added save tests and empty queries/docs test to complete coverage Signed-off-by: markstur --- caikit_nlp/modules/reranker/rerank.py | 2 +- tests/modules/reranker/test_rerank.py | 59 +++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/caikit_nlp/modules/reranker/rerank.py b/caikit_nlp/modules/reranker/rerank.py index a40676e1..e35a12d6 100644 --- a/caikit_nlp/modules/reranker/rerank.py +++ b/caikit_nlp/modules/reranker/rerank.py @@ -89,7 +89,7 @@ def load(cls, model_path: str) -> "Rerank": "", model_name_or_path, ValueError( - f"Model config missing '{cls._ARTIFACTS_CONFIG_KEY}' or '{cls._HF_HUB_KEY}'" + f"Model config missing '{cls._ARTIFACTS_PATH_KEY}' or '{cls._HF_HUB_KEY}'" ), ) diff --git a/tests/modules/reranker/test_rerank.py b/tests/modules/reranker/test_rerank.py index 8a2cb4bf..a02ff8dc 100644 --- a/tests/modules/reranker/test_rerank.py +++ b/tests/modules/reranker/test_rerank.py @@ -8,6 +8,9 @@ # Third Party import pytest +# First Party +from caikit.core import ModuleConfig + # Local from caikit_nlp import RerankQueryResult, RerankScore from caikit_nlp.data_model import RerankPrediction @@ -89,6 +92,22 @@ def test_run_top_n(top_n, expected): assert len(result.scores) == expected +@pytest.mark.parametrize( + "queries, docs", + [ + ([], DOCS), + (QUERIES, []), + ([], []), + ], + ids=["no queries", "no docs", "no queries and no docs"], +) +def test_run_no_queries_or_no_docs(queries, docs): + """No queries and/or no docs therefore result is zero results""" + res = BOOTSTRAPPED_MODEL.run(queries=queries, documents=docs, top_n=9) + assert isinstance(res, RerankPrediction) + assert len(res.results) == 0 + + def test_save_and_load_and_run_model(): """Save and load and run a model""" @@ -138,3 +157,43 @@ def test_save_and_load_and_run_model(): assert type(str_test) == str, "passthru str value type check" assert type(int_test) == int, "passthru int value type check" assert type(float_test) == float, "passthru float value type check" + + +@pytest.mark.parametrize( + "model_path", ["", " ", " " * 100], ids=["empty", "space", "spaces"] +) +def test_save_value_checks(model_path): + with pytest.raises(ValueError): + BOOTSTRAPPED_MODEL.save(model_path) + + +@pytest.mark.parametrize( + "model_path", + ["..", "../" * 100, "/", ".", " / ", " . 
"], +) +def test_save_exists_checks(model_path): + """Tests for model paths are always existing dirs that should not be clobbered""" + with pytest.raises(FileExistsError): + BOOTSTRAPPED_MODEL.save(model_path) + + +def test_second_save_hits_exists_check(): + """Using a new path the first save should succeed but second fails""" + model_id = "model_id" + with tempfile.TemporaryDirectory(suffix="-2nd") as model_dir: + model_path = os.path.join(model_dir, model_id) + BOOTSTRAPPED_MODEL.save(model_path) + with pytest.raises(FileExistsError): + BOOTSTRAPPED_MODEL.save(model_path) + + +@pytest.mark.parametrize("model_path", [None, {}, object(), 1], ids=type) +def test_save_type_checks(model_path): + with pytest.raises(TypeError): + BOOTSTRAPPED_MODEL.save(model_path) + + +def test_load_without_artifacts_or_hf_model(): + """Test coverage for the error message when config has no artifacts and no hf_model to load""" + with pytest.raises(ValueError): + Rerank.load(ModuleConfig({})) From d2689b150bce95298dc8b57e7e9bd7d23e2397ea Mon Sep 17 00:00:00 2001 From: markstur Date: Mon, 16 Oct 2023 16:57:23 -0700 Subject: [PATCH 05/17] Make rerank run() only single query * rerank run() will only do one query * adding reranks run_queries() for multiple queries with multi-task (coming soon) Signed-off-by: markstur --- caikit_nlp/modules/reranker/rerank.py | 51 +++++- caikit_nlp/modules/reranker/rerank_task.py | 15 +- tests/modules/reranker/test_rerank.py | 106 +++++------ tests/modules/reranker/test_reranks.py | 199 +++++++++++++++++++++ 4 files changed, 308 insertions(+), 63 deletions(-) create mode 100644 tests/modules/reranker/test_reranks.py diff --git a/caikit_nlp/modules/reranker/rerank.py b/caikit_nlp/modules/reranker/rerank.py index e35a12d6..0f93a7c7 100644 --- a/caikit_nlp/modules/reranker/rerank.py +++ b/caikit_nlp/modules/reranker/rerank.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # Standard from typing import List, Optional import os @@ -96,6 +95,47 @@ def load(cls, model_path: str) -> "Rerank": return cls.bootstrap(model_name_or_path=model_name_or_path) def run( + self, + query: str, + documents: List[JsonDict], + top_n: Optional[int] = None, + ) -> RerankQueryResult: + """Run inference on model. 
+ Args: + query: str + documents: List[JsonDict] + top_n: Optional[int] + Returns: + RerankQueryResult + """ + + error.type_check( + "", + str, + query=query, + ) + + error.type_check( + "", + list, + documents=documents, + ) + + batch_results = self.run_queries( + queries=[query], documents=documents, top_n=top_n + ) + results = batch_results.results + + if len(results) > 1: + error( + "", + ValueError(f"expected single query result, but got {len(results)}"), + ) + + # Return the single result, creating an empty one if results was empty + return results[0] if len(results) == 1 else RerankQueryResult([]) + + def run_queries( self, queries: List[str], documents: List[JsonDict], @@ -117,10 +157,17 @@ def run( documents=documents, ) + error.type_check( + "", + int, + allow_none=True, + top_n=top_n, + ) + if len(queries) < 1 or len(documents) < 1: return RerankPrediction([]) - if top_n is None or top_n < 1: + if top_n is None or int(top_n) < 1: top_n = len(documents) # Using input document dicts so get "text" else "_text" else default to "" diff --git a/caikit_nlp/modules/reranker/rerank_task.py b/caikit_nlp/modules/reranker/rerank_task.py index 9c7379b0..263c13d7 100644 --- a/caikit_nlp/modules/reranker/rerank_task.py +++ b/caikit_nlp/modules/reranker/rerank_task.py @@ -22,12 +22,23 @@ import alog # Local -from caikit_nlp.data_model.reranker import RerankPrediction +from caikit_nlp.data_model.reranker import RerankPrediction, RerankQueryResult logger = alog.use_channel("") error = error_handler.get(logger) +@task( + required_parameters={ + "documents": List[JsonDict], + "query": str, + }, + output_type=RerankQueryResult, +) +class RerankTask(TaskBase): + pass + + @task( required_parameters={ "documents": List[JsonDict], @@ -35,5 +46,5 @@ }, output_type=RerankPrediction, ) -class RerankTask(TaskBase): +class ReranksTask(TaskBase): pass diff --git a/tests/modules/reranker/test_rerank.py b/tests/modules/reranker/test_rerank.py index a02ff8dc..009687e8 100644 --- a/tests/modules/reranker/test_rerank.py +++ b/tests/modules/reranker/test_rerank.py @@ -12,8 +12,7 @@ from caikit.core import ModuleConfig # Local -from caikit_nlp import RerankQueryResult, RerankScore -from caikit_nlp.data_model import RerankPrediction +from caikit_nlp.data_model import RerankQueryResult, RerankScore from caikit_nlp.modules.reranker import Rerank from tests.fixtures import SEQ_CLASS_MODEL @@ -23,10 +22,7 @@ # .bootstrap is tested separately in the first test BOOTSTRAPPED_MODEL = Rerank.bootstrap(SEQ_CLASS_MODEL) -QUERIES: List[str] = [ - "Who is foo?", - "Where is the bar?", -] +QUERY = "What is foo bar?" 
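# Annotation (not part of the original patch): only each document's "text" key,
# with "_text" as the fallback, is embedded for scoring; all other keys are passed
# through unchanged on RerankScore.document. The "score": 99999 entry below is a
# deliberate decoy so the tests can confirm that pass-through fields never replace
# the computed result score.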
DOCS = [ { @@ -34,7 +30,7 @@ "title": "title or whatever", "str_test": "test string", "int_test": 1, - "float_test": 1.11, + "float_test": 1.234, "score": 99999, }, { @@ -58,18 +54,25 @@ def test_bootstrap(): @pytest.mark.parametrize( - "queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})] + "query,docs,top_n", + [ + (["test list"], DOCS, None), + (None, DOCS, 1234), + (False, DOCS, 1234), + (QUERY, {"testdict": "not list"}, 1234), + (QUERY, DOCS, "topN string is not an integer or None"), + ], ) -def test_run_type_error(queries, docs): - """type error check ensures params are lists and not just 1 string or just one doc (for example)""" +def test_run_type_error(query, docs, top_n): + """test for type checks matching task/run signature""" with pytest.raises(TypeError): - BOOTSTRAPPED_MODEL.run(queries=queries, documents=docs) + BOOTSTRAPPED_MODEL.run(query=query, documents=docs, top_n=top_n) pytest.fail("Should not reach here.") def test_run_no_type_error(): """no type error with list of string queries and list of dict documents""" - BOOTSTRAPPED_MODEL.run(queries=QUERIES, documents=DOCS) + BOOTSTRAPPED_MODEL.run(query=QUERY, documents=DOCS, top_n=1) @pytest.mark.parametrize( @@ -84,28 +87,20 @@ def test_run_no_type_error(): ], ) def test_run_top_n(top_n, expected): - """no type error with list of string queries and list of dict documents""" - res = BOOTSTRAPPED_MODEL.run(queries=QUERIES, documents=DOCS, top_n=top_n) - assert isinstance(res, RerankPrediction) - assert len(res.results) == len(QUERIES) - for result in res.results: - assert len(result.scores) == expected + res = BOOTSTRAPPED_MODEL.run(query=QUERY, documents=DOCS, top_n=top_n) + assert isinstance(res, RerankQueryResult) + assert len(res.scores) == expected -@pytest.mark.parametrize( - "queries, docs", - [ - ([], DOCS), - (QUERIES, []), - ([], []), - ], - ids=["no queries", "no docs", "no queries and no docs"], -) -def test_run_no_queries_or_no_docs(queries, docs): - """No queries and/or no docs therefore result is zero results""" - res = BOOTSTRAPPED_MODEL.run(queries=queries, documents=docs, top_n=9) - assert isinstance(res, RerankPrediction) - assert len(res.results) == 0 +def test_run_no_query(): + with pytest.raises(TypeError): + BOOTSTRAPPED_MODEL.run(query=None, documents=DOCS, top_n=99) + + +def test_run_zero_docs(): + """No empty doc list therefore result is zero result scores""" + result = BOOTSTRAPPED_MODEL.run(query=QUERY, documents=[], top_n=99) + assert len(result.scores) == 0 def test_save_and_load_and_run_model(): @@ -120,39 +115,32 @@ def test_save_and_load_and_run_model(): assert isinstance(new_model, Rerank), "save and load error" assert new_model != BOOTSTRAPPED_MODEL, "did not load a new model" - top_n = 2 - rerank_result = new_model.run(queries=QUERIES, documents=DOCS, top_n=top_n) - assert isinstance(rerank_result, RerankPrediction) - - results = rerank_result.results - assert isinstance(results, list) - assert len(results) == 2 == len(QUERIES) # 2 queries yields 2 result(s) + result = new_model.run(query=QUERY, documents=DOCS) + assert isinstance(result, RerankQueryResult) # Collect some of the pass-through extras to verify we can do some types str_test = None int_test = None float_test = None - for result in results: - assert isinstance(result, RerankQueryResult) - scores = result.scores - assert isinstance(scores, list) - assert len(scores) == top_n - for score in scores: - assert isinstance(score, RerankScore) - assert isinstance(score.score, float) - assert 
isinstance(score.corpus_id, int) - assert score.document == DOCS[score.corpus_id] - - # Test pass-through score (None or 9999) is independent of the result score - assert score.score != score.document.get( - "score" - ), "unexpected passthru score same as result score" - - # Gather various type test values - str_test = score.document.get("str_test", str_test) - int_test = score.document.get("int_test", int_test) - float_test = score.document.get("float_test", float_test) + scores = result.scores + assert isinstance(scores, list) + assert len(scores) == len(DOCS) + for score in scores: + assert isinstance(score, RerankScore) + assert isinstance(score.score, float) + assert isinstance(score.corpus_id, int) + assert score.document == DOCS[score.corpus_id] + + # Test pass-through score (None or 9999) is independent of the result score + assert score.score != score.document.get( + "score" + ), "unexpected passthru score same as result score" + + # Gather various type test values + str_test = score.document.get("str_test", str_test) + int_test = score.document.get("int_test", int_test) + float_test = score.document.get("float_test", float_test) assert type(str_test) == str, "passthru str value type check" assert type(int_test) == int, "passthru int value type check" diff --git a/tests/modules/reranker/test_reranks.py b/tests/modules/reranker/test_reranks.py new file mode 100644 index 00000000..6e945108 --- /dev/null +++ b/tests/modules/reranker/test_reranks.py @@ -0,0 +1,199 @@ +"""Tests for sequence classification module +""" +# Standard +from typing import List +import os +import tempfile + +# Third Party +import pytest + +# First Party +from caikit.core import ModuleConfig + +# Local +from caikit_nlp import RerankQueryResult, RerankScore +from caikit_nlp.data_model import RerankPrediction +from caikit_nlp.modules.reranker import Rerank +from tests.fixtures import SEQ_CLASS_MODEL + +## Setup ######################################################################## + +# Bootstrapped sequence classification model for reusability across tests +# .bootstrap is tested separately in the first test +BOOTSTRAPPED_MODEL = Rerank.bootstrap(SEQ_CLASS_MODEL) + +QUERIES: List[str] = [ + "Who is foo?", + "Where is the bar?", +] + +DOCS = [ + { + "text": "foo", + "title": "title or whatever", + "str_test": "test string", + "int_test": 1, + "float_test": 1.11, + "score": 99999, + }, + { + "_text": "bar", + "title": "title 2", + }, + { + "text": "foo and bar", + }, + { + "_text": "Where is the bar", + "another": "something else", + }, +] + +## Tests ######################################################################## + + +def test_bootstrap(): + assert isinstance(BOOTSTRAPPED_MODEL, Rerank), "bootstrap error" + + +@pytest.mark.parametrize( + "queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})] +) +def test_run_batch_type_error(queries, docs): + """type error check ensures params are lists and not just 1 string or just one doc (for example)""" + with pytest.raises(TypeError): + BOOTSTRAPPED_MODEL.run_queries(queries=queries, documents=docs) + pytest.fail("Should not reach here.") + + +def test_run_batch_no_type_error(): + """no type error with list of string queries and list of dict documents""" + BOOTSTRAPPED_MODEL.run_queries(queries=QUERIES, documents=DOCS, top_n=99) + + +@pytest.mark.parametrize( + "top_n, expected", + [ + (1, 1), + (2, 2), + (None, len(DOCS)), + (-1, len(DOCS)), + (0, len(DOCS)), + (9999, len(DOCS)), + ], +) +def test_run_batch_top_n(top_n, expected): + """no 
type error with list of string queries and list of dict documents""" + res = BOOTSTRAPPED_MODEL.run_queries(queries=QUERIES, documents=DOCS, top_n=top_n) + assert isinstance(res, RerankPrediction) + assert len(res.results) == len(QUERIES) + for result in res.results: + assert len(result.scores) == expected + + +@pytest.mark.parametrize( + "queries, docs", + [ + ([], DOCS), + (QUERIES, []), + ([], []), + ], + ids=["no queries", "no docs", "no queries and no docs"], +) +def test_run_batch_no_queries_or_no_docs(queries, docs): + """No queries and/or no docs therefore result is zero results""" + res = BOOTSTRAPPED_MODEL.run_queries(queries=queries, documents=docs, top_n=9) + assert isinstance(res, RerankPrediction) + assert len(res.results) == 0 + + +def test_save_and_load_and_run_batch_model(): + """Save and load and run a model""" + + model_id = "model_id" + with tempfile.TemporaryDirectory(suffix="-1st") as model_dir: + model_path = os.path.join(model_dir, model_id) + BOOTSTRAPPED_MODEL.save(model_path) + new_model = Rerank.load(model_path) + + assert isinstance(new_model, Rerank), "save and load error" + assert new_model != BOOTSTRAPPED_MODEL, "did not load a new model" + + top_n = 2 + rerank_result = new_model.run_queries(queries=QUERIES, documents=DOCS, top_n=top_n) + assert isinstance(rerank_result, RerankPrediction) + + results = rerank_result.results + assert isinstance(results, list) + assert len(results) == 2 == len(QUERIES) # 2 queries yields 2 result(s) + + # Collect some of the pass-through extras to verify we can do some types + str_test = None + int_test = None + float_test = None + + for result in results: + assert isinstance(result, RerankQueryResult) + scores = result.scores + assert isinstance(scores, list) + assert len(scores) == top_n + for score in scores: + assert isinstance(score, RerankScore) + assert isinstance(score.score, float) + assert isinstance(score.corpus_id, int) + assert score.document == DOCS[score.corpus_id] + + # Test pass-through score (None or 9999) is independent of the result score + assert score.score != score.document.get( + "score" + ), "unexpected passthru score same as result score" + + # Gather various type test values + str_test = score.document.get("str_test", str_test) + int_test = score.document.get("int_test", int_test) + float_test = score.document.get("float_test", float_test) + + assert type(str_test) == str, "passthru str value type check" + assert type(int_test) == int, "passthru int value type check" + assert type(float_test) == float, "passthru float value type check" + + +@pytest.mark.parametrize( + "model_path", ["", " ", " " * 100], ids=["empty", "space", "spaces"] +) +def test_save_value_checks(model_path): + with pytest.raises(ValueError): + BOOTSTRAPPED_MODEL.save(model_path) + + +@pytest.mark.parametrize( + "model_path", + ["..", "../" * 100, "/", ".", " / ", " . 
"], +) +def test_save_exists_checks(model_path): + """Tests for model paths are always existing dirs that should not be clobbered""" + with pytest.raises(FileExistsError): + BOOTSTRAPPED_MODEL.save(model_path) + + +def test_second_save_hits_exists_check(): + """Using a new path the first save should succeed but second fails""" + model_id = "model_id" + with tempfile.TemporaryDirectory(suffix="-2nd") as model_dir: + model_path = os.path.join(model_dir, model_id) + BOOTSTRAPPED_MODEL.save(model_path) + with pytest.raises(FileExistsError): + BOOTSTRAPPED_MODEL.save(model_path) + + +@pytest.mark.parametrize("model_path", [None, {}, object(), 1], ids=type) +def test_save_type_checks(model_path): + with pytest.raises(TypeError): + BOOTSTRAPPED_MODEL.save(model_path) + + +def test_load_without_artifacts_or_hf_model(): + """Test coverage for the error message when config has no artifacts and no hf_model to load""" + with pytest.raises(ValueError): + Rerank.load(ModuleConfig({})) From ad56fa8e5971105c708effb6936100ef95bea3d0 Mon Sep 17 00:00:00 2001 From: markstur Date: Fri, 20 Oct 2023 00:15:41 -0700 Subject: [PATCH 06/17] Rerank and sentence-similarity added to embedding service using multi-task * The EmbeddingModule now does all 3 tasks (same loaded model) * An additional 3 tasks allow multiple texts, source_sentences, or queries. - the documents or sentences compared to are the same for each * Added more docs Signed-off-by: markstur --- README.md | 19 +- caikit_nlp/data_model/__init__.py | 1 + caikit_nlp/data_model/embedding_vectors.py | 33 ++ caikit_nlp/data_model/reranker.py | 26 +- .../sentence_similarity.py} | 36 +-- caikit_nlp/modules/reranker/__init__.py | 16 - caikit_nlp/modules/reranker/rerank.py | 245 --------------- caikit_nlp/modules/text_embedding/__init__.py | 20 +- .../modules/text_embedding/embedding.py | 66 +++- .../modules/text_embedding/embedding_tasks.py | 14 +- .../modules/text_embedding/rerank_task.py | 73 +++++ .../sentence_similarity_task.py | 40 +++ tests/data_model/test_embedding_vectors.py | 15 + tests/data_model/test_reranker.py | 33 +- tests/modules/reranker/test_rerank.py | 187 ----------- tests/modules/reranker/test_reranks.py | 199 ------------ .../modules/text_embedding/test_embedding.py | 292 +++++++++++++++++- 17 files changed, 589 insertions(+), 726 deletions(-) rename caikit_nlp/{modules/reranker/rerank_task.py => data_model/sentence_similarity.py} (53%) delete mode 100644 caikit_nlp/modules/reranker/__init__.py delete mode 100644 caikit_nlp/modules/reranker/rerank.py create mode 100644 caikit_nlp/modules/text_embedding/rerank_task.py create mode 100644 caikit_nlp/modules/text_embedding/sentence_similarity_task.py delete mode 100644 tests/modules/reranker/test_rerank.py delete mode 100644 tests/modules/reranker/test_reranks.py diff --git a/README.md b/README.md index aab33851..6ed215b2 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,18 @@ Caikit-NLP implements concept of "task" from `caikit` framework to define (and c Capabilities provided by `caikit-nlp`: -| Task | Module(s) | Salient Feature(s) | -|----------------------|-------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Text Generation | 1. `PeftPromptTuning`
2. `TextGeneration` | 1. Prompt Tuning, Multi-task Prompt tuning
2. Fine-tuning Both modules above provide optimized inference capability using Text Generation Inference Server | -| Text Classification | 1. `SequenceClassification` | 1. (Work in progress..) | -| Token Classification | 1. `FilteredSpanClassification` | 1. (Work in progress..) | -| Tokenization | 1. `RegexSentenceSplitter` | 1. Demo purposes only | -| Embedding | [COMING SOON] | [COMING SOON] | +| Task | Module(s) | Salient Feature(s) | +|-------------------------|------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| TextGenerationTask | 1. `PeftPromptTuning`
2. `TextGeneration` | 1. Prompt Tuning, Multi-task Prompt tuning
2. Fine-tuning Both modules above provide optimized inference capability using Text Generation Inference Server | +| TextClassificationTask | 1. `SequenceClassification` | 1. (Work in progress..) | +| TokenClassificationTask | 1. `FilteredSpanClassification` | 1. (Work in progress..) | +| TokenizationTask | 1. `RegexSentenceSplitter` | 1. Demo purposes only | +| EmbeddingTask | 1. `TextEmbedding` | 1. text/embedding from a local sentence-transformers model +| EmbeddingTasks | 1. `TextEmbedding` | 1. Same as EmbeddingTask but multiple sentences (texts) as input and corresponding list of outputs. +| SentenceSimilarityTask | 1. `TextEmbedding` | 1. text/sentence-similarity from a local sentence-transformers model (Hugging Face style API returns scores only in order of input sentences) | +| SentenceSimilarityTasks | 1. `TextEmbedding` | 1. Same as SentenceSimilarityTask but multiple source_sentences (each to be compared to same list of sentences) as input and corresponding lists of outputs. | +| RerankTask | 1. `TextEmbedding` | 1. text/rerank from a local sentence-transformers model (Cohere style API returns top_n scores in order of relevance with index to source and optionally returning inputs) | +| RerankTasks | 1. `TextEmbedding` | 1. Same as RerankTask but multiple queries as input and corresponding lists of outputs. Same list of documents for all queries. | ## Getting Started diff --git a/caikit_nlp/data_model/__init__.py b/caikit_nlp/data_model/__init__.py index 9354d875..98f1d06b 100644 --- a/caikit_nlp/data_model/__init__.py +++ b/caikit_nlp/data_model/__init__.py @@ -19,3 +19,4 @@ from .embedding_vectors import * from .generation import * from .reranker import * +from .sentence_similarity import * diff --git a/caikit_nlp/data_model/embedding_vectors.py b/caikit_nlp/data_model/embedding_vectors.py index 67f3d5b3..22fa8277 100644 --- a/caikit_nlp/data_model/embedding_vectors.py +++ b/caikit_nlp/data_model/embedding_vectors.py @@ -155,6 +155,39 @@ def fill_proto(self, proto): return proto +@dataobject(package="caikit_data_model.caikit_nlp") +class ListOfVector1D(DataObjectBase): + """Data representation for an embedding matrix holding 2D vectors""" + + results: List[Vector1D] + + def __post_init__(self): + error.type_check("", list, results=self.results) + error.type_check_all("", Vector1D, results=self.results) + + @classmethod + def from_json(cls, json_str): + """Fill in the vector data in an appropriate data_""" + + json_obj = json.loads(json_str) if isinstance(json_str, str) else json_str + for v in json_obj["results"]: + data = v.pop("data") + if data is not None: + v["data_pyfloatsequence"] = data + json_str = json.dumps(json_obj) + try: + # Parse given JSON into google.protobufs.pyext.cpp_message.GeneratedProtocolMessageType + parsed_proto = json_format.Parse( + json_str, cls.get_proto_class()(), ignore_unknown_fields=False + ) + + # Use from_proto to return the DataBase object from the parsed proto + return cls.from_proto(parsed_proto) + + except json_format.ParseError as ex: + error("", ValueError(ex)) + + @dataobject(package="caikit_data_model.caikit_nlp") @dataclass class EmbeddingResult(DataObjectBase): diff --git a/caikit_nlp/data_model/reranker.py b/caikit_nlp/data_model/reranker.py index 48692afd..b0ff7be9 100644 --- a/caikit_nlp/data_model/reranker.py +++ b/caikit_nlp/data_model/reranker.py @@ -13,42 +13,32 @@ # limitations under the License. 
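The `ListOfVector1D` type introduced above is what the list-in/list-out embedding task returns: one `Vector1D` per input text. A minimal sketch of building it the same way `run_embeddings` does, assuming this patch is applied and that vector values are read via `data.values` as in the existing `Vector1D` tests; the `rows` list below is an invented stand-in for `model.encode(texts)`:

```python
# Minimal sketch (assumes this patch is applied and caikit_nlp is importable).
# "rows" stands in for the embedding rows returned by model.encode(texts).
from caikit_nlp.data_model import ListOfVector1D, Vector1D

rows = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

# One Vector1D per embedding row, kept in input order.
vectors = ListOfVector1D(results=[Vector1D.from_embeddings(r) for r in rows])

assert len(vectors.results) == len(rows)
print(list(vectors.results[0].data.values))  # plain Python floats for the first row
```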
# Standard -from dataclasses import dataclass -from typing import List +from typing import List, Optional # First Party from caikit.core import DataObjectBase, dataobject from caikit.core.data_model.json_dict import JsonDict -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass -class RerankDocuments(DataObjectBase): - """An input list of JSON documents""" - - documents: List[JsonDict] - - -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass +@dataobject() class RerankScore(DataObjectBase): """The score for one document (one query)""" - document: JsonDict - corpus_id: int + document: Optional[JsonDict] + index: int score: float + text: Optional[str] -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass +@dataobject() class RerankQueryResult(DataObjectBase): """Result for one query in a rerank task""" + query: Optional[str] scores: List[RerankScore] -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass +@dataobject() class RerankPrediction(DataObjectBase): """Result for a rerank task""" diff --git a/caikit_nlp/modules/reranker/rerank_task.py b/caikit_nlp/data_model/sentence_similarity.py similarity index 53% rename from caikit_nlp/modules/reranker/rerank_task.py rename to caikit_nlp/data_model/sentence_similarity.py index 263c13d7..c24097f7 100644 --- a/caikit_nlp/modules/reranker/rerank_task.py +++ b/caikit_nlp/data_model/sentence_similarity.py @@ -11,40 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +"""Data structures for embedding vector representations +""" # Standard from typing import List # First Party -from caikit.core import TaskBase, task -from caikit.core.data_model.json_dict import JsonDict +from caikit.core import DataObjectBase, dataobject from caikit.core.exceptions import error_handler import alog -# Local -from caikit_nlp.data_model.reranker import RerankPrediction, RerankQueryResult +log = alog.use_channel("DATAM") +error = error_handler.get(log) -logger = alog.use_channel("") -error = error_handler.get(logger) +@dataobject(package="caikit_data_model.caikit_nlp") +class SentenceScores(DataObjectBase): + scores: List[float] -@task( - required_parameters={ - "documents": List[JsonDict], - "query": str, - }, - output_type=RerankQueryResult, -) -class RerankTask(TaskBase): - pass +@dataobject(package="caikit_data_model.caikit_nlp") +class SentenceListScores(DataObjectBase): -@task( - required_parameters={ - "documents": List[JsonDict], - "queries": List[str], - }, - output_type=RerankPrediction, -) -class ReranksTask(TaskBase): - pass + results: List[SentenceScores] diff --git a/caikit_nlp/modules/reranker/__init__.py b/caikit_nlp/modules/reranker/__init__.py deleted file mode 100644 index e5a6d2bd..00000000 --- a/caikit_nlp/modules/reranker/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
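The reranker data model is reworked above: `corpus_id` becomes `index`, `document` and `text` are optional, and each per-query result may carry its `query`. A hedged construction sketch using those field names; the scores and documents here are invented for illustration:

```python
# Illustrative only: builds the revised rerank data objects directly,
# using the field names from the data model diff above (index/text/query).
from caikit_nlp.data_model import RerankPrediction, RerankQueryResult, RerankScore

scores = [
    RerankScore(document={"text": "foo"}, index=0, score=0.87, text="foo"),
    RerankScore(document={"_text": "bar"}, index=1, score=0.12, text="bar"),
]
result = RerankQueryResult(query="What is foo?", scores=scores)
prediction = RerankPrediction(results=[result])

print(prediction.results[0].scores[0].index)  # -> 0, the document's input position
```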
- -# Local -from .rerank import Rerank diff --git a/caikit_nlp/modules/reranker/rerank.py b/caikit_nlp/modules/reranker/rerank.py deleted file mode 100644 index 0f93a7c7..00000000 --- a/caikit_nlp/modules/reranker/rerank.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Standard -from typing import List, Optional -import os - -# Third Party -from sentence_transformers import SentenceTransformer -from sentence_transformers.util import dot_score, normalize_embeddings, semantic_search - -# First Party -from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module -from caikit.core.data_model.json_dict import JsonDict -from caikit.core.exceptions import error_handler -import alog - -# Local -from .rerank_task import RerankTask -from caikit_nlp.data_model.reranker import ( - RerankPrediction, - RerankQueryResult, - RerankScore, -) - -logger = alog.use_channel("") -error = error_handler.get(logger) - - -@module( - "00110203-0405-0607-0809-0a0b02dd0e0f", - "RerankerModule", - "0.0.1", - RerankTask, -) -class Rerank(ModuleBase): - - _ARTIFACTS_PATH_KEY = "artifacts_path" - _ARTIFACTS_PATH_DEFAULT = "artifacts" - _HF_HUB_KEY = "hf_model" - - def __init__( - self, - model: SentenceTransformer, - ): - """Initialize - This function gets called by `.load` and `.train` function - which initializes this module. - """ - super().__init__() - self.model = model - - @classmethod - def load(cls, model_path: str) -> "Rerank": - """Load a model - - Args: - model_path: str - Path to the config dir of the model to be loaded. - - Returns: - Rerank - Instance of this class built from the model. - """ - - config = ModuleConfig.load(model_path) - - artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY) - if artifacts_path: - model_name_or_path = os.path.abspath( - os.path.join(model_path, artifacts_path) - ) - error.dir_check("", model_name_or_path) - else: - # If no artifacts_path, look for hf_model Hugging Face model by name (or path) - model_name_or_path = config.get(cls._HF_HUB_KEY) - error.value_check( - "", - model_name_or_path, - ValueError( - f"Model config missing '{cls._ARTIFACTS_PATH_KEY}' or '{cls._HF_HUB_KEY}'" - ), - ) - - return cls.bootstrap(model_name_or_path=model_name_or_path) - - def run( - self, - query: str, - documents: List[JsonDict], - top_n: Optional[int] = None, - ) -> RerankQueryResult: - """Run inference on model. 
- Args: - query: str - documents: List[JsonDict] - top_n: Optional[int] - Returns: - RerankQueryResult - """ - - error.type_check( - "", - str, - query=query, - ) - - error.type_check( - "", - list, - documents=documents, - ) - - batch_results = self.run_queries( - queries=[query], documents=documents, top_n=top_n - ) - results = batch_results.results - - if len(results) > 1: - error( - "", - ValueError(f"expected single query result, but got {len(results)}"), - ) - - # Return the single result, creating an empty one if results was empty - return results[0] if len(results) == 1 else RerankQueryResult([]) - - def run_queries( - self, - queries: List[str], - documents: List[JsonDict], - top_n: Optional[int] = None, - ) -> RerankPrediction: - """Run inference on model. - Args: - queries: List[str] - documents: List[JsonDict] - top_n: Optional[int] - Returns: - RerankPrediction - """ - - error.type_check( - "", - list, - queries=queries, - documents=documents, - ) - - error.type_check( - "", - int, - allow_none=True, - top_n=top_n, - ) - - if len(queries) < 1 or len(documents) < 1: - return RerankPrediction([]) - - if top_n is None or int(top_n) < 1: - top_n = len(documents) - - # Using input document dicts so get "text" else "_text" else default to "" - doc_texts = [srd.get("text") or srd.get("_text", "") for srd in documents] - - doc_embeddings = self.model.encode(doc_texts, convert_to_tensor=True) - doc_embeddings = doc_embeddings.to(self.model.device) - doc_embeddings = normalize_embeddings(doc_embeddings) - - query_embeddings = self.model.encode(queries, convert_to_tensor=True) - query_embeddings = query_embeddings.to(self.model.device) - query_embeddings = normalize_embeddings(query_embeddings) - - res = semantic_search( - query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score - ) - - for r in res: - for x in r: - x["document"] = documents[x["corpus_id"]] - - results = [RerankQueryResult([RerankScore(**x) for x in r]) for r in res] - - return RerankPrediction(results=results) - - @classmethod - def bootstrap(cls, model_name_or_path: str) -> "Rerank": - """Bootstrap a sentence-transformers model - - Args: - model_name_or_path: str - Model name (Hugging Face hub) or path to model to load. 
- """ - return cls(model=SentenceTransformer(model_name_or_path=model_name_or_path)) - - def save(self, model_path: str, *args, **kwargs): - """Save model using config in model_path - - Args: - model_path: str - Path to model config - """ - error.type_check("", str, model_path=model_path) - error.value_check( - "", - model_path is not None and model_path.strip(), - f"model_path '{model_path}' is invalid", - ) - - model_path = os.path.abspath( - model_path.strip() - ) # No leading/trailing spaces sneaky weirdness - - if os.path.exists(model_path): - error( - "", - FileExistsError(f"model_path '{model_path}' already exists"), - ) - - saver = ModuleSaver( - module=self, - model_path=model_path, - ) - - # Save update config (artifacts_path) and save artifacts - with saver: - artifacts_path = saver.config.get(self._ARTIFACTS_PATH_KEY) - if not artifacts_path: - artifacts_path = self._ARTIFACTS_PATH_DEFAULT - saver.update_config({self._ARTIFACTS_PATH_KEY: artifacts_path}) - if self.model: # This condition allows for empty placeholders - artifacts_path = os.path.abspath( - os.path.join(model_path, artifacts_path) - ) - self.model.save(artifacts_path, create_model_card=True) diff --git a/caikit_nlp/modules/text_embedding/__init__.py b/caikit_nlp/modules/text_embedding/__init__.py index f56694f6..d72e55be 100644 --- a/caikit_nlp/modules/text_embedding/__init__.py +++ b/caikit_nlp/modules/text_embedding/__init__.py @@ -12,6 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Text Embedding Module +===================== + +Implements the following tasks: + + 1. EmbeddingTask: Returns an embedding from an input text string + 2. EmbeddingsTasks: EmbeddingTask but with a list of inputs producing a list of outputs + 3. SentenceSimilarityTask: Compare one source sentence to a list of sentences + 4. SentenceSimilarityTasks: SentenceSimilarityTask but with a list of source sentences producing + a list of outputs + 5. RerankTask: Return top_n documents ordered by relevance given a query + 6. RerankTasks: RerankTask but with a list of queries producing a list of outputs + +""" + # Local from .embedding import EmbeddingModule -from .embedding_tasks import EmbeddingTask +from .embedding_tasks import EmbeddingTask, EmbeddingTasks +from .rerank_task import RerankTask, RerankTasks +from .sentence_similarity_task import SentenceSimilarityTask, SentenceSimilarityTasks diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index ea415048..e8c05cca 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -13,19 +13,38 @@ # limitations under the License. 
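Per the module docstring above, a single loaded sentence-transformers model now serves all six tasks. A rough usage sketch of exercising several of them from one bootstrapped module; the model path is a placeholder for any local sentence-transformers model, not an artifact shipped with this patch:

```python
# Rough sketch: one bootstrapped module serving embedding, similarity, and rerank.
# "path/to/sentence-transformers-model" is a placeholder path.
from caikit_nlp.modules.text_embedding import EmbeddingModule

model = EmbeddingModule.bootstrap("path/to/sentence-transformers-model")

embedding = model.run_embedding(text="What is foo bar?")  # single-text EmbeddingTask

similarity = model.run_sentence_similarity(
    source_sentence="What is foo bar?", sentences=["foo", "bar"]
)
print(similarity.scores)  # one float per input sentence, in order

top = model.run_rerank_query(
    query="What is foo bar?", documents=[{"text": "foo"}, {"_text": "bar"}], top_n=1
)
print(top.scores[0].index, top.scores[0].score)  # best match and its score
```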
# Standard +from typing import List, Optional import os # Third Party from sentence_transformers import SentenceTransformer +from sentence_transformers.util import ( + cos_sim, + dot_score, + normalize_embeddings, + semantic_search, +) # First Party from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module +from caikit.core.data_model.json_dict import JsonDict from caikit.core.exceptions import error_handler import alog # Local -from .embedding_tasks import EmbeddingTask -from caikit_nlp.data_model.embedding_vectors import EmbeddingResult, Vector1D +from .embedding_tasks import EmbeddingTask, EmbeddingTasks +from .rerank_task import RerankTask, RerankTasks +from .sentence_similarity_task import SentenceSimilarityTask, SentenceSimilarityTasks +from caikit_nlp.data_model import ( + EmbeddingResult, + ListOfVector1D, + RerankPrediction, + RerankQueryResult, + RerankScore, + SentenceListScores, + SentenceScores, + Vector1D, +) logger = alog.use_channel("TXT_EMB") error = error_handler.get(logger) @@ -35,7 +54,14 @@ "eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f", "EmbeddingModule", "0.0.1", - EmbeddingTask, + tasks=[ + EmbeddingTask, + EmbeddingTasks, + SentenceSimilarityTask, + SentenceSimilarityTasks, + RerankTask, + RerankTasks, + ], ) class EmbeddingModule(ModuleBase): @@ -76,19 +102,39 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": return cls.bootstrap(model_name_or_path=artifacts_path) - def run( - self, input: str, **kwargs # pylint: disable=redefined-builtin - ) -> EmbeddingResult: - """Run inference on model. + @EmbeddingTask.taskmethod() + def run_embedding(self, text: str) -> EmbeddingResult: # pylint: disable=redefined-builtin + """Get embedding for a string. Args: - input: str + text: str Input text to be processed Returns: EmbeddingResult: the result vector nicely wrapped up """ - error.type_check("", str, input=input) + error.type_check("", str, text=text) + + return EmbeddingResult(Vector1D.from_vector(self.model.encode(text))) - return EmbeddingResult(Vector1D.from_vector(self.model.encode(input))) + @EmbeddingTasks.taskmethod() + def run_embeddings( + self, texts: List[str] # pylint: disable=redefined-builtin + ) -> ListOfVector1D: + """Run inference on model. + Args: + texts: List[str] + List of input texts to be processed + Returns: + List[Vector1D]: List vectors. One for each input text (in order). + Each vector is a list of floats (supports various float types). + """ + if isinstance( + texts, str + ): # encode allows str, but the result would lack a dimension + texts = [texts] + + embeddings = self.model.encode(texts) + results = [Vector1D.from_embeddings(e) for e in embeddings] + return ListOfVector1D(results=results) @classmethod def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule": diff --git a/caikit_nlp/modules/text_embedding/embedding_tasks.py b/caikit_nlp/modules/text_embedding/embedding_tasks.py index 07b31dba..a9a4ebf9 100644 --- a/caikit_nlp/modules/text_embedding/embedding_tasks.py +++ b/caikit_nlp/modules/text_embedding/embedding_tasks.py @@ -13,17 +13,27 @@ # limitations under the License. 
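`EmbeddingTasks` above is the list-in/list-out counterpart of `EmbeddingTask`: `run_embeddings` returns one vector per input text and quietly wraps a bare string into a one-element list. A small sketch under the same placeholder-model assumption as before:

```python
# Sketch of the list-in/list-out embedding task added in this patch.
# The model path below is a placeholder for any local sentence-transformers model.
from caikit_nlp.modules.text_embedding import EmbeddingModule

model = EmbeddingModule.bootstrap("path/to/sentence-transformers-model")

res = model.run_embeddings(texts=["The quick brown fox", "jumps over the lazy dog"])
print(len(res.results))  # one Vector1D per input text, in order

# A bare string is tolerated and treated as a one-element list:
assert len(model.run_embeddings(texts="just one sentence").results) == 1
```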
# Standard +from typing import List # First Party from caikit.core import TaskBase, task # Local from ...data_model import EmbeddingResult +from ...data_model import ListOfVector1D @task( - required_parameters={"input": str}, + required_parameters={"text": str}, output_type=EmbeddingResult, ) class EmbeddingTask(TaskBase): - pass + """Return a text embedding for the input text string""" + + +@task( + required_parameters={"texts": List[str]}, + output_type=ListOfVector1D, +) +class EmbeddingTasks(TaskBase): + """Return a text embedding for each text string in the input list""" diff --git a/caikit_nlp/modules/text_embedding/rerank_task.py b/caikit_nlp/modules/text_embedding/rerank_task.py new file mode 100644 index 00000000..caa9d24f --- /dev/null +++ b/caikit_nlp/modules/text_embedding/rerank_task.py @@ -0,0 +1,73 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from typing import List + +# First Party +from caikit.core import TaskBase, task +from caikit.core.data_model.json_dict import JsonDict +from caikit.core.toolkit.errors import error_handler +import alog + +# Local +from caikit_nlp.data_model.reranker import RerankPrediction, RerankQueryResult + +logger = alog.use_channel("") +error = error_handler.get(logger) + + +@task( + required_parameters={ + "documents": List[JsonDict], + "query": str, + }, + output_type=RerankQueryResult, +) +class RerankTask(TaskBase): + """Returns an ordered list ranking the most relevant documents for the query + + Required parameters: + query: The search query + documents: JSON documents containing "text" or alternative "_text" to search + Returns: + The top_n documents in order of relevance (most relevant first). + For each, a score and document index (position in input) is returned. + The original document JSON is returned depending on optional args. + The top_n optional parameter limits the results when used. + """ + + +@task( + required_parameters={ + "documents": List[JsonDict], + "queries": List[str], + }, + output_type=RerankPrediction, +) +class RerankTasks(TaskBase): + """Returns an ordered list for each query ranking the most relevant documents for the query + + Required parameters: + queries: The search queries + documents: JSON documents containing "text" or alternative "_text" to search + Returns: + Results in order of the queries. + In each query result: + The query text is optionally included for visual convenience. + The top_n documents in order of relevance (most relevant first). + For each, a score and document index (position in input) is returned. + The original document JSON is returned depending on optional args. + The top_n optional parameter limits the results when used. 
+ """ diff --git a/caikit_nlp/modules/text_embedding/sentence_similarity_task.py b/caikit_nlp/modules/text_embedding/sentence_similarity_task.py new file mode 100644 index 00000000..458fa72a --- /dev/null +++ b/caikit_nlp/modules/text_embedding/sentence_similarity_task.py @@ -0,0 +1,40 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# First Party +from caikit.core import TaskBase, task +from typing import List + +from ...data_model import SentenceListScores, SentenceScores + + +@task( + required_parameters={"source_sentence": str, "sentences": List[str]}, + output_type=SentenceScores +) +class SentenceSimilarityTask(TaskBase): + """Compare the source_sentence to each of the sentences. + Result contains a list of scores in the order of the input sentences. + """ + + +@task( + required_parameters={"source_sentences": List[str], "sentences": List[str]}, + output_type=SentenceListScores +) +class SentenceSimilarityTasks(TaskBase): + """Compare each of the source_sentences to each of the sentences. + Returns a list of results in the order of the source_sentences. + Each result contains a list of scores in the order of the input sentences. + """ diff --git a/tests/data_model/test_embedding_vectors.py b/tests/data_model/test_embedding_vectors.py index 010eac10..6d65cb4a 100644 --- a/tests/data_model/test_embedding_vectors.py +++ b/tests/data_model/test_embedding_vectors.py @@ -113,3 +113,18 @@ def test_vector1d_dm(float_seq_class, random_values, float_type): _assert_array_check( dm_from_json, random_values, float ) # NOTE: always float after json + + +@pytest.mark.parametrize( + "float_seq_class, random_values, float_type", + [ + (dm.PyFloatSequence, random_python_vector1d_float, float), + (dm.NpFloat32Sequence, random_numpy_vector1d_float32, np.float32), + (dm.NpFloat64Sequence, random_numpy_vector1d_float64, np.float64), + ], +) +def test_vector1d_dm_from_embeddings(float_seq_class, random_values, float_type): + v = dm.Vector1D.from_embeddings(random_values) + assert isinstance(v.data, float_seq_class) + assert isinstance(v.data.values[0], float_type) + _assert_array_check(v, random_values, float_type) diff --git a/tests/data_model/test_reranker.py b/tests/data_model/test_reranker.py index f20df204..853a0151 100644 --- a/tests/data_model/test_reranker.py +++ b/tests/data_model/test_reranker.py @@ -49,20 +49,23 @@ input_score = { "document": input_document, - "corpus_id": 1234, + "index": 1234, "score": 9876.54321, + "text": "this is the input text", } input_random_score = { "document": input_random_document, - "corpus_id": random.randint(-99999, 99999), + "index": random.randint(-99999, 99999), "score": random.uniform(-99999, 99999), + "text": "".join(random.choices(string.printable, k=100)), } input_random_score_3 = { "document": {"text": "random foo3"}, - "corpus_id": random.randint(-99999, 99999), + "index": random.randint(-99999, 99999), "score": random.uniform(-99999, 99999), + "text": "".join(random.choices(string.printable, k=100)), } input_scores = 
[dm.RerankScore(**input_score), dm.RerankScore(**input_random_score)] @@ -70,9 +73,24 @@ dm.RerankScore(**input_random_score), dm.RerankScore(**input_random_score_3), ] + +input_result_1 = {"query": "foo", "scores": input_scores} +input_result_2 = {"query": "bar", "scores": input_scores2} input_results = [ - dm.RerankQueryResult(scores=input_scores), - dm.RerankQueryResult(scores=input_scores2), + dm.RerankQueryResult(**input_result_1), + dm.RerankQueryResult(**input_result_2), +] + +input_sentence_similarity_scores_1 = { + "scores": [random.uniform(-99999, 99999) for _ in range(10)] +} +input_sentence_similarity_scores_2 = { + "scores": [random.uniform(-99999, 99999) for _ in range(10)] +} + +input_sentence_similarities_scores = [ + dm.SentenceScores(**input_sentence_similarity_scores_1), + dm.SentenceScores(**input_sentence_similarity_scores_2), ] @@ -82,11 +100,12 @@ @pytest.mark.parametrize( "data_object, inputs", [ - (dm.RerankDocuments, {"documents": input_documents}), (dm.RerankScore, input_score), (dm.RerankScore, input_random_score), - (dm.RerankQueryResult, {"scores": input_scores}), + (dm.RerankQueryResult, input_result_1), (dm.RerankPrediction, {"results": input_results}), + (dm.SentenceScores, input_sentence_similarity_scores_1), + (dm.SentenceListScores, {"results": input_sentence_similarities_scores}), ], ) def test_data_object(data_object, inputs): diff --git a/tests/modules/reranker/test_rerank.py b/tests/modules/reranker/test_rerank.py deleted file mode 100644 index 009687e8..00000000 --- a/tests/modules/reranker/test_rerank.py +++ /dev/null @@ -1,187 +0,0 @@ -"""Tests for sequence classification module -""" -# Standard -from typing import List -import os -import tempfile - -# Third Party -import pytest - -# First Party -from caikit.core import ModuleConfig - -# Local -from caikit_nlp.data_model import RerankQueryResult, RerankScore -from caikit_nlp.modules.reranker import Rerank -from tests.fixtures import SEQ_CLASS_MODEL - -## Setup ######################################################################## - -# Bootstrapped sequence classification model for reusability across tests -# .bootstrap is tested separately in the first test -BOOTSTRAPPED_MODEL = Rerank.bootstrap(SEQ_CLASS_MODEL) - -QUERY = "What is foo bar?" 
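The fixtures above also cover the new sentence-similarity data objects. A minimal construction sketch with invented scores, mirroring how those fixtures are built:

```python
# Illustrative only: builds the sentence-similarity result objects directly,
# matching the fixtures above (the score values are made up).
from caikit_nlp.data_model import SentenceListScores, SentenceScores

per_source = SentenceScores(scores=[0.91, 0.15, 0.42])
all_sources = SentenceListScores(
    results=[per_source, SentenceScores(scores=[0.20, 0.80, 0.50])]
)

print(all_sources.results[0].scores)  # scores follow the order of the input sentences
```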
- -DOCS = [ - { - "text": "foo", - "title": "title or whatever", - "str_test": "test string", - "int_test": 1, - "float_test": 1.234, - "score": 99999, - }, - { - "_text": "bar", - "title": "title 2", - }, - { - "text": "foo and bar", - }, - { - "_text": "Where is the bar", - "another": "something else", - }, -] - -## Tests ######################################################################## - - -def test_bootstrap(): - assert isinstance(BOOTSTRAPPED_MODEL, Rerank), "bootstrap error" - - -@pytest.mark.parametrize( - "query,docs,top_n", - [ - (["test list"], DOCS, None), - (None, DOCS, 1234), - (False, DOCS, 1234), - (QUERY, {"testdict": "not list"}, 1234), - (QUERY, DOCS, "topN string is not an integer or None"), - ], -) -def test_run_type_error(query, docs, top_n): - """test for type checks matching task/run signature""" - with pytest.raises(TypeError): - BOOTSTRAPPED_MODEL.run(query=query, documents=docs, top_n=top_n) - pytest.fail("Should not reach here.") - - -def test_run_no_type_error(): - """no type error with list of string queries and list of dict documents""" - BOOTSTRAPPED_MODEL.run(query=QUERY, documents=DOCS, top_n=1) - - -@pytest.mark.parametrize( - "top_n, expected", - [ - (1, 1), - (2, 2), - (None, len(DOCS)), - (-1, len(DOCS)), - (0, len(DOCS)), - (9999, len(DOCS)), - ], -) -def test_run_top_n(top_n, expected): - res = BOOTSTRAPPED_MODEL.run(query=QUERY, documents=DOCS, top_n=top_n) - assert isinstance(res, RerankQueryResult) - assert len(res.scores) == expected - - -def test_run_no_query(): - with pytest.raises(TypeError): - BOOTSTRAPPED_MODEL.run(query=None, documents=DOCS, top_n=99) - - -def test_run_zero_docs(): - """No empty doc list therefore result is zero result scores""" - result = BOOTSTRAPPED_MODEL.run(query=QUERY, documents=[], top_n=99) - assert len(result.scores) == 0 - - -def test_save_and_load_and_run_model(): - """Save and load and run a model""" - - model_id = "model_id" - with tempfile.TemporaryDirectory(suffix="-1st") as model_dir: - model_path = os.path.join(model_dir, model_id) - BOOTSTRAPPED_MODEL.save(model_path) - new_model = Rerank.load(model_path) - - assert isinstance(new_model, Rerank), "save and load error" - assert new_model != BOOTSTRAPPED_MODEL, "did not load a new model" - - result = new_model.run(query=QUERY, documents=DOCS) - assert isinstance(result, RerankQueryResult) - - # Collect some of the pass-through extras to verify we can do some types - str_test = None - int_test = None - float_test = None - - scores = result.scores - assert isinstance(scores, list) - assert len(scores) == len(DOCS) - for score in scores: - assert isinstance(score, RerankScore) - assert isinstance(score.score, float) - assert isinstance(score.corpus_id, int) - assert score.document == DOCS[score.corpus_id] - - # Test pass-through score (None or 9999) is independent of the result score - assert score.score != score.document.get( - "score" - ), "unexpected passthru score same as result score" - - # Gather various type test values - str_test = score.document.get("str_test", str_test) - int_test = score.document.get("int_test", int_test) - float_test = score.document.get("float_test", float_test) - - assert type(str_test) == str, "passthru str value type check" - assert type(int_test) == int, "passthru int value type check" - assert type(float_test) == float, "passthru float value type check" - - -@pytest.mark.parametrize( - "model_path", ["", " ", " " * 100], ids=["empty", "space", "spaces"] -) -def test_save_value_checks(model_path): - with 
pytest.raises(ValueError): - BOOTSTRAPPED_MODEL.save(model_path) - - -@pytest.mark.parametrize( - "model_path", - ["..", "../" * 100, "/", ".", " / ", " . "], -) -def test_save_exists_checks(model_path): - """Tests for model paths are always existing dirs that should not be clobbered""" - with pytest.raises(FileExistsError): - BOOTSTRAPPED_MODEL.save(model_path) - - -def test_second_save_hits_exists_check(): - """Using a new path the first save should succeed but second fails""" - model_id = "model_id" - with tempfile.TemporaryDirectory(suffix="-2nd") as model_dir: - model_path = os.path.join(model_dir, model_id) - BOOTSTRAPPED_MODEL.save(model_path) - with pytest.raises(FileExistsError): - BOOTSTRAPPED_MODEL.save(model_path) - - -@pytest.mark.parametrize("model_path", [None, {}, object(), 1], ids=type) -def test_save_type_checks(model_path): - with pytest.raises(TypeError): - BOOTSTRAPPED_MODEL.save(model_path) - - -def test_load_without_artifacts_or_hf_model(): - """Test coverage for the error message when config has no artifacts and no hf_model to load""" - with pytest.raises(ValueError): - Rerank.load(ModuleConfig({})) diff --git a/tests/modules/reranker/test_reranks.py b/tests/modules/reranker/test_reranks.py deleted file mode 100644 index 6e945108..00000000 --- a/tests/modules/reranker/test_reranks.py +++ /dev/null @@ -1,199 +0,0 @@ -"""Tests for sequence classification module -""" -# Standard -from typing import List -import os -import tempfile - -# Third Party -import pytest - -# First Party -from caikit.core import ModuleConfig - -# Local -from caikit_nlp import RerankQueryResult, RerankScore -from caikit_nlp.data_model import RerankPrediction -from caikit_nlp.modules.reranker import Rerank -from tests.fixtures import SEQ_CLASS_MODEL - -## Setup ######################################################################## - -# Bootstrapped sequence classification model for reusability across tests -# .bootstrap is tested separately in the first test -BOOTSTRAPPED_MODEL = Rerank.bootstrap(SEQ_CLASS_MODEL) - -QUERIES: List[str] = [ - "Who is foo?", - "Where is the bar?", -] - -DOCS = [ - { - "text": "foo", - "title": "title or whatever", - "str_test": "test string", - "int_test": 1, - "float_test": 1.11, - "score": 99999, - }, - { - "_text": "bar", - "title": "title 2", - }, - { - "text": "foo and bar", - }, - { - "_text": "Where is the bar", - "another": "something else", - }, -] - -## Tests ######################################################################## - - -def test_bootstrap(): - assert isinstance(BOOTSTRAPPED_MODEL, Rerank), "bootstrap error" - - -@pytest.mark.parametrize( - "queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})] -) -def test_run_batch_type_error(queries, docs): - """type error check ensures params are lists and not just 1 string or just one doc (for example)""" - with pytest.raises(TypeError): - BOOTSTRAPPED_MODEL.run_queries(queries=queries, documents=docs) - pytest.fail("Should not reach here.") - - -def test_run_batch_no_type_error(): - """no type error with list of string queries and list of dict documents""" - BOOTSTRAPPED_MODEL.run_queries(queries=QUERIES, documents=DOCS, top_n=99) - - -@pytest.mark.parametrize( - "top_n, expected", - [ - (1, 1), - (2, 2), - (None, len(DOCS)), - (-1, len(DOCS)), - (0, len(DOCS)), - (9999, len(DOCS)), - ], -) -def test_run_batch_top_n(top_n, expected): - """no type error with list of string queries and list of dict documents""" - res = BOOTSTRAPPED_MODEL.run_queries(queries=QUERIES, 
documents=DOCS, top_n=top_n) - assert isinstance(res, RerankPrediction) - assert len(res.results) == len(QUERIES) - for result in res.results: - assert len(result.scores) == expected - - -@pytest.mark.parametrize( - "queries, docs", - [ - ([], DOCS), - (QUERIES, []), - ([], []), - ], - ids=["no queries", "no docs", "no queries and no docs"], -) -def test_run_batch_no_queries_or_no_docs(queries, docs): - """No queries and/or no docs therefore result is zero results""" - res = BOOTSTRAPPED_MODEL.run_queries(queries=queries, documents=docs, top_n=9) - assert isinstance(res, RerankPrediction) - assert len(res.results) == 0 - - -def test_save_and_load_and_run_batch_model(): - """Save and load and run a model""" - - model_id = "model_id" - with tempfile.TemporaryDirectory(suffix="-1st") as model_dir: - model_path = os.path.join(model_dir, model_id) - BOOTSTRAPPED_MODEL.save(model_path) - new_model = Rerank.load(model_path) - - assert isinstance(new_model, Rerank), "save and load error" - assert new_model != BOOTSTRAPPED_MODEL, "did not load a new model" - - top_n = 2 - rerank_result = new_model.run_queries(queries=QUERIES, documents=DOCS, top_n=top_n) - assert isinstance(rerank_result, RerankPrediction) - - results = rerank_result.results - assert isinstance(results, list) - assert len(results) == 2 == len(QUERIES) # 2 queries yields 2 result(s) - - # Collect some of the pass-through extras to verify we can do some types - str_test = None - int_test = None - float_test = None - - for result in results: - assert isinstance(result, RerankQueryResult) - scores = result.scores - assert isinstance(scores, list) - assert len(scores) == top_n - for score in scores: - assert isinstance(score, RerankScore) - assert isinstance(score.score, float) - assert isinstance(score.corpus_id, int) - assert score.document == DOCS[score.corpus_id] - - # Test pass-through score (None or 9999) is independent of the result score - assert score.score != score.document.get( - "score" - ), "unexpected passthru score same as result score" - - # Gather various type test values - str_test = score.document.get("str_test", str_test) - int_test = score.document.get("int_test", int_test) - float_test = score.document.get("float_test", float_test) - - assert type(str_test) == str, "passthru str value type check" - assert type(int_test) == int, "passthru int value type check" - assert type(float_test) == float, "passthru float value type check" - - -@pytest.mark.parametrize( - "model_path", ["", " ", " " * 100], ids=["empty", "space", "spaces"] -) -def test_save_value_checks(model_path): - with pytest.raises(ValueError): - BOOTSTRAPPED_MODEL.save(model_path) - - -@pytest.mark.parametrize( - "model_path", - ["..", "../" * 100, "/", ".", " / ", " . 
"], -) -def test_save_exists_checks(model_path): - """Tests for model paths are always existing dirs that should not be clobbered""" - with pytest.raises(FileExistsError): - BOOTSTRAPPED_MODEL.save(model_path) - - -def test_second_save_hits_exists_check(): - """Using a new path the first save should succeed but second fails""" - model_id = "model_id" - with tempfile.TemporaryDirectory(suffix="-2nd") as model_dir: - model_path = os.path.join(model_dir, model_id) - BOOTSTRAPPED_MODEL.save(model_path) - with pytest.raises(FileExistsError): - BOOTSTRAPPED_MODEL.save(model_path) - - -@pytest.mark.parametrize("model_path", [None, {}, object(), 1], ids=type) -def test_save_type_checks(model_path): - with pytest.raises(TypeError): - BOOTSTRAPPED_MODEL.save(model_path) - - -def test_load_without_artifacts_or_hf_model(): - """Test coverage for the error message when config has no artifacts and no hf_model to load""" - with pytest.raises(ValueError): - Rerank.load(ModuleConfig({})) diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index af4d0556..c92f114e 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -1,6 +1,7 @@ """Tests for text embedding module """ # Standard +from typing import List import os import tempfile @@ -9,8 +10,17 @@ import numpy as np import pytest +# First Party +from caikit.core import ModuleConfig + # Local -from caikit_nlp.data_model import EmbeddingResult +from caikit_nlp.data_model import ( + EmbeddingResult, + RerankPrediction, + RerankQueryResult, + RerankScore, + Vector1D, +) from caikit_nlp.modules.text_embedding import EmbeddingModule from tests.fixtures import SEQ_CLASS_MODEL @@ -22,6 +32,42 @@ INPUT = "The quick brown fox jumps over the lazy dog." +QUERY = "What is foo bar?" 
+ +QUERIES: List[str] = [ + "Who is foo?", + "Where is the bar?", +] + +# These are used to test that documents can handle different types in and out +TYPE_KEYS = "str_test", "int_test", "float_test", "nested_dict_test" + +DOCS = [ + { + "text": "foo", + "title": "title or whatever", + "str_test": "test string", + "int_test": 1, + "float_test": 1.234, + "score": 99999, + "nested_dict_test": {"deep1": 1, "deep string": "just testing"}, + }, + { + "_text": "bar", + "title": "title 2", + }, + { + "text": "foo and bar", + }, + { + "_text": "Where is the bar", + "another": "something else", + }, +] + +# Use text or _text from DOCS for our test sentences +SENTENCES = [d.get("text", d.get("_text")) for d in DOCS] + ## Tests ######################################################################## @@ -40,17 +86,44 @@ def test_bootstrap_and_run(): model = EmbeddingModule.bootstrap(SEQ_CLASS_MODEL) result = model.run(INPUT) _assert_is_expected_embedding_result(result) +def _assert_types_found(types_found): + assert type(types_found["str_test"]) == str, "passthru str value type check" + assert type(types_found["int_test"]) == int, "passthru int value type check" + assert type(types_found["float_test"]) == float, "passthru float value type check" + assert ( + type(types_found["nested_dict_test"]) == dict + ), "passthru nested dict value type check" -def test_run_type_check(): - """Input cannot be a list""" - model = BOOTSTRAPPED_MODEL - with pytest.raises(TypeError): - model.run([INPUT]) - pytest.fail("Should not reach here") +def _assert_valid_scores(scores, type_tests={}): + for score in scores: + assert isinstance(score, RerankScore) + assert isinstance(score.score, float) + assert isinstance(score.index, int) + assert isinstance(score.text, str) + document = score.document + assert isinstance(document, dict) + assert document == DOCS[score.index] -def test_save_load_and_run_model(): + # Test document key named score (None or 9999) is independent of the result score + assert score.score != document.get( + "score" + ), "unexpected passthru score same as result score" + + # Gather various type test values when we have them + for k, v in document.items(): + if k in TYPE_KEYS: + type_tests[k] = v + + return type_tests + + +def test_bootstrap(): + assert isinstance(BOOTSTRAPPED_MODEL, EmbeddingModule), "bootstrap error" + + +def test_save_load_and_run(): """Check if we can load and run a saved model successfully""" model_id = "model_id" with tempfile.TemporaryDirectory(suffix="-1st") as model_dir: @@ -58,7 +131,11 @@ def test_save_load_and_run_model(): BOOTSTRAPPED_MODEL.save(model_path) new_model = EmbeddingModule.load(model_path) - result = new_model.run(input=INPUT) + assert isinstance(new_model, EmbeddingModule), "save and load error" + assert new_model != BOOTSTRAPPED_MODEL, "did not load a new model" + + # Use run_embedding just to make sure this new model is usable + result = new_model.run_embedding(text=INPUT) _assert_is_expected_embedding_result(result) @@ -94,3 +171,200 @@ def test_second_save_hits_exists_check(): def test_save_type_checks(model_path): with pytest.raises(TypeError): BOOTSTRAPPED_MODEL.save(model_path) + + +def test_load_without_artifacts(): + """Test coverage for the error message when config has no artifacts to load""" + with pytest.raises(ValueError): + EmbeddingModule.load(ModuleConfig({})) + + +def test_run_embedding_type_check(): + """Input cannot be a list""" + model = BOOTSTRAPPED_MODEL + with pytest.raises(TypeError): + model.run_embedding([INPUT]) + pytest.fail("Should 
not reach here") + + +def test_run_embedding(): + model = BOOTSTRAPPED_MODEL + res = model.run_embedding(text=INPUT) + _assert_is_expected_embedding_result(res) + + +def test_run_embeddings_str_type(): + """Supposed to be a list, gets fixed automatically.""" + model = BOOTSTRAPPED_MODEL + res = model.run_embeddings(texts=INPUT) + assert isinstance(res.results, list) + assert len(res.results) == 1 + + +def test_run_embeddings(): + model = BOOTSTRAPPED_MODEL + res = model.run_embeddings(texts=[INPUT]) + assert isinstance(res.results, list) + _assert_is_expected_embedding_result(res) + + +@pytest.mark.parametrize( + "query,docs,top_n", + [ + (["test list"], DOCS, None), + (None, DOCS, 1234), + (False, DOCS, 1234), + (QUERY, {"testdict": "not list"}, 1234), + (QUERY, DOCS, "topN string is not an integer or None"), + ], +) +def test_run_rerank_query_type_error(query, docs, top_n): + """test for type checks matching task/run signature""" + with pytest.raises(TypeError): + BOOTSTRAPPED_MODEL.run_rerank_query(query=query, documents=docs, top_n=top_n) + pytest.fail("Should not reach here.") + + +def test_run_rerank_query_no_type_error(): + """no type error with list of string queries and list of dict documents""" + BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS, top_n=1) + + +@pytest.mark.parametrize( + "top_n, expected", + [ + (1, 1), + (2, 2), + (None, len(DOCS)), + (-1, len(DOCS)), + (0, len(DOCS)), + (9999, len(DOCS)), + ], +) +def test_run_rerank_query_top_n(top_n, expected): + res = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n) + assert isinstance(res, RerankQueryResult) + assert len(res.scores) == expected + + +def test_run_rerank_query_no_query(): + with pytest.raises(TypeError): + BOOTSTRAPPED_MODEL.run_rerank_query(query=None, documents=DOCS, top_n=99) + + +def test_run_rerank_query_zero_docs(): + """No empty doc list therefore result is zero result scores""" + result = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=[], top_n=99) + assert len(result.scores) == 0 + + +def test_run_rerank_query(): + result = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS) + assert isinstance(result, RerankQueryResult) + + scores = result.scores + assert isinstance(scores, list) + assert len(scores) == len(DOCS) + + types_found = _assert_valid_scores(scores) + _assert_types_found(types_found) + + +@pytest.mark.parametrize( + "queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})] +) +def test_run_rerank_queries_type_error(queries, docs): + """type error check ensures params are lists and not just 1 string or just one doc (for example)""" + with pytest.raises(TypeError): + BOOTSTRAPPED_MODEL.run_rerank_queries(queries=queries, documents=docs) + pytest.fail("Should not reach here.") + + +def test_run_rerank_queries_no_type_error(): + """no type error with list of string queries and list of dict documents""" + BOOTSTRAPPED_MODEL.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=99) + + +@pytest.mark.parametrize( + "top_n, expected", + [ + (1, 1), + (2, 2), + (None, len(DOCS)), + (-1, len(DOCS)), + (0, len(DOCS)), + (9999, len(DOCS)), + ], +) +def test_run_rerank_queries_top_n(top_n, expected): + """no type error with list of string queries and list of dict documents""" + res = BOOTSTRAPPED_MODEL.run_rerank_queries( + queries=QUERIES, documents=DOCS, top_n=top_n + ) + assert isinstance(res, RerankPrediction) + assert len(res.results) == len(QUERIES) + for result in res.results: + assert len(result.scores) 
== expected + + +@pytest.mark.parametrize( + "queries, docs", + [ + ([], DOCS), + (QUERIES, []), + ([], []), + ], + ids=["no queries", "no docs", "no queries and no docs"], +) +def test_run_rerank_queries_no_queries_or_no_docs(queries, docs): + """No queries and/or no docs therefore result is zero results""" + res = BOOTSTRAPPED_MODEL.run_rerank_queries( + queries=queries, documents=docs, top_n=9 + ) + assert isinstance(res, RerankPrediction) + assert len(res.results) == 0 + + +def test_run_rerank_queries(): + top_n = 2 + rerank_result = BOOTSTRAPPED_MODEL.run_rerank_queries( + queries=QUERIES, documents=DOCS, top_n=top_n + ) + assert isinstance(rerank_result, RerankPrediction) + + results = rerank_result.results + assert isinstance(results, list) + assert len(results) == 2 == len(QUERIES) # 2 queries yields 2 result(s) + + types_found = {} # Gather the type tests from any of the results + + for result in results: + assert isinstance(result, RerankQueryResult) + scores = result.scores + assert isinstance(scores, list) + assert len(scores) == top_n + types_found = _assert_valid_scores(scores, types_found) + + # Make sure our document fields of different types made it in/out ok + _assert_types_found(types_found) + + +def test_run_sentence_similarity(): + model = BOOTSTRAPPED_MODEL + result = model.run_sentence_similarity(source_sentence=QUERY, sentences=SENTENCES) + scores = result.scores + assert len(scores) == len(SENTENCES) + for score in scores: + assert isinstance(score, float) + + +def test_run_sentence_similarities(): + model = BOOTSTRAPPED_MODEL + res = model.run_sentence_similarities(source_sentences=QUERIES, sentences=SENTENCES) + results = res.results + assert len(results) == len(QUERIES) + for result in results: + scores = result.scores + assert len(scores) == len(SENTENCES) + for score in scores: + assert isinstance(score, float) From 967f4d88224da9c6766d0743d5e4f521ede3d9ce Mon Sep 17 00:00:00 2001 From: markstur Date: Fri, 20 Oct 2023 00:35:24 -0700 Subject: [PATCH 07/17] Formatting Signed-off-by: markstur --- .../modules/embedding_retrieval/embedding.py | 345 ++++++++++++++++++ .../sentence_similarity_task.py | 9 +- 2 files changed, 351 insertions(+), 3 deletions(-) create mode 100644 caikit_nlp/modules/embedding_retrieval/embedding.py diff --git a/caikit_nlp/modules/embedding_retrieval/embedding.py b/caikit_nlp/modules/embedding_retrieval/embedding.py new file mode 100644 index 00000000..2f247522 --- /dev/null +++ b/caikit_nlp/modules/embedding_retrieval/embedding.py @@ -0,0 +1,345 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
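The batch rerank tests above pin down the result shape: one `RerankQueryResult` per query, each holding up to `top_n` scores that carry the document's original index, its text (taken from `text` or the `_text` fallback), and optionally the document itself. A consumption sketch; the model path is a placeholder:

```python
# Sketch of consuming a batch rerank result; the model path is a placeholder.
from caikit_nlp.modules.text_embedding import EmbeddingModule

model = EmbeddingModule.bootstrap("path/to/sentence-transformers-model")

docs = [{"text": "foo"}, {"_text": "bar"}, {"text": "foo and bar"}]  # "_text" is the fallback key
prediction = model.run_rerank_queries(
    queries=["Who is foo?", "Where is the bar?"], documents=docs, top_n=2
)

for query_result in prediction.results:   # one entry per query, in query order
    for score in query_result.scores:     # most relevant first, at most top_n long
        print(score.index, round(score.score, 3), score.text, score.document)
```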
+ +# Standard +from typing import List, Optional +import os + +# Third Party +from sentence_transformers import SentenceTransformer +from sentence_transformers.util import ( + cos_sim, + dot_score, + normalize_embeddings, + semantic_search, +) + +# First Party +from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module +from caikit.core.data_model.json_dict import JsonDict +from caikit.core.exceptions import error_handler +import alog + +# Local +from .embedding_retrieval_task import EmbeddingTask, EmbeddingTasks +from .rerank_task import RerankTask, RerankTasks +from .sentence_similarity_task import SentenceSimilarityTask, SentenceSimilarityTasks +from caikit_nlp.data_model import ( + ListOfVector1D, + RerankPrediction, + RerankQueryResult, + RerankScore, + SentenceListScores, + SentenceScores, + Vector1D, +) + +logger = alog.use_channel("") +error = error_handler.get(logger) + + +@module( + "EEB12558-B4FA-4F34-A9FD-3F5890E9CD3F", + "Text Embedding", + "0.0.1", + tasks=[ + EmbeddingTask, + EmbeddingTasks, + SentenceSimilarityTask, + SentenceSimilarityTasks, + RerankTask, + RerankTasks, + ], +) +class TextEmbedding(ModuleBase): + + _ARTIFACTS_PATH_KEY = "artifacts_path" + _ARTIFACTS_PATH_DEFAULT = "artifacts" + + def __init__( + self, + model: SentenceTransformer, + ): + super().__init__() + self.model = model + + @classmethod + def load(cls, model_path: str, *args, **kwargs) -> "TextEmbedding": + """Load model + + Args: + model_path: str + Path to the config dir under the model_id (where the config.yml lives) + + Returns: + EmbeddingModule + Instance of this class built from the model. + """ + + config = ModuleConfig.load(model_path) + artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY) + + error.value_check( + "", + artifacts_path, + ValueError(f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"), + ) + + artifacts_path = os.path.abspath(os.path.join(model_path, artifacts_path)) + error.dir_check("", artifacts_path) + + return cls.bootstrap(model_name_or_path=artifacts_path) + + @EmbeddingTask.taskmethod() + def run_embedding(self, text: str) -> Vector1D: # pylint: disable=redefined-builtin + """Get embedding for a string. + Args: + text: str + Input text to be processed + Returns: + Vector1D: the output + """ + error.type_check("", str, text=text) + + return Vector1D.from_embeddings(self.model.encode(text)) + + @EmbeddingTasks.taskmethod() + def run_embeddings( + self, texts: List[str] # pylint: disable=redefined-builtin + ) -> ListOfVector1D: + """Run inference on model. + Args: + texts: List[str] + List of input texts to be processed + Returns: + List[Vector1D]: List vectors. One for each input text (in order). + Each vector is a list of floats (supports various float types). + """ + if isinstance( + texts, str + ): # encode allows str, but the result would lack a dimension + texts = [texts] + + embeddings = self.model.encode(texts) + results = [Vector1D.from_embeddings(e) for e in embeddings] + return ListOfVector1D(results=results) + + @SentenceSimilarityTask.taskmethod() + def run_sentence_similarity( + self, source_sentence: str, sentences: List[str] + ) -> SentenceScores: # pylint: disable=arguments-differ + """Run inference on model. 
+ Args: + source_sentence: str + sentences: List[str] + Sentences to compare to source_sentence + Returns: + SentenceScores + """ + + source_embedding = self.model.encode(source_sentence) + embeddings = self.model.encode(sentences) + + res = cos_sim(source_embedding, embeddings) + return SentenceScores(res.tolist()[0]) + + @SentenceSimilarityTasks.taskmethod() + def run_sentence_similarities( + self, source_sentences: List[str], sentences: List[str] + ) -> SentenceListScores: # pylint: disable=arguments-differ + """Run inference on model. + Args: + source_sentences: List[str] + sentences: List[str] + Sentences to compare to source_sentences + Returns: + SentenceListScores Similarity scores for each source sentence in order. + each SentenceScores contains the source-sentence's score for each sentence in order. + """ + + source_embedding = self.model.encode(source_sentences) + embeddings = self.model.encode(sentences) + + res = cos_sim(source_embedding, embeddings) + float_list_list = res.tolist() + return SentenceListScores( + results=[SentenceScores(fl) for fl in float_list_list] + ) + + @RerankTask.taskmethod() + def run_rerank_query( + self, + query: str, + documents: List[JsonDict], + top_n: Optional[int] = None, + return_documents: bool = True, + return_query: bool = True, + return_text: bool = True, + ) -> RerankQueryResult: + """Run inference on model. + Args: + query: str + documents: List[JsonDict] + top_n: Optional[int] + Returns: + RerankQueryResult + """ + + error.type_check( + "", + str, + query=query, + ) + + results = self.run_rerank_queries( + queries=[query], + documents=documents, + top_n=top_n, + return_documents=return_documents, + return_queries=return_query, + return_text=return_text, + ).results + + return ( + results[0] + if len(results) > 0 + else RerankQueryResult(scores=[], query=query if return_query else None) + ) + + @RerankTasks.taskmethod() + def run_rerank_queries( + self, + queries: List[str], + documents: List[JsonDict], + top_n: Optional[int] = None, + return_documents: bool = True, + return_queries: bool = True, + return_text: bool = True, + ) -> RerankPrediction: + """Run inference on model. 
+ Args: + queries: List[str] + documents: List[JsonDict] + top_n: Optional[int] + return_documents: bool + return_queries: bool + return_text: bool + Returns: + RerankPrediction + """ + + error.type_check( + "", + list, + queries=queries, + documents=documents, + ) + + if len(queries) < 1 or len(documents) < 1: + return RerankPrediction([]) + + if top_n is None or top_n < 1: + top_n = len(documents) + + # Using input document dicts so get "text" else "_text" else default to "" + def get_text(srd): + return srd.get("text") or srd.get("_text", "") + + doc_texts = [get_text(srd) for srd in documents] + + doc_embeddings = self.model.encode(doc_texts, convert_to_tensor=True) + doc_embeddings = doc_embeddings.to(self.model.device) + doc_embeddings = normalize_embeddings(doc_embeddings) + + query_embeddings = self.model.encode(queries, convert_to_tensor=True) + query_embeddings = query_embeddings.to(self.model.device) + query_embeddings = normalize_embeddings(query_embeddings) + + res = semantic_search( + query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score + ) + + # Fixup result dicts + for r in res: + for x in r: + # Renaming corpus_id to index + corpus_id = x.pop("corpus_id") + x["index"] = corpus_id + # Optionally adding the original document and/or just the text that was used + if return_documents: + x["document"] = documents[corpus_id] + if return_text: + x["text"] = get_text(documents[corpus_id]) + + def add_query(q): + return queries[q] if return_queries else None + + results = [ + RerankQueryResult(query=add_query(q), scores=[RerankScore(**x) for x in r]) + for q, r in enumerate(res) + ] + + return RerankPrediction(results=results) + + @classmethod + def bootstrap(cls, model_name_or_path: str) -> "TextEmbedding": + """Bootstrap a sentence-transformers model + + Args: + model_name_or_path: str + Model name (Hugging Face hub) or path to model to load. 
+ """ + return cls(model=SentenceTransformer(model_name_or_path=model_name_or_path)) + + def save(self, model_path: str, *args, **kwargs): + """Save model using config in model_path + + Args: + model_path: str + Path to model config + """ + + model_config_path = model_path # because the param name is misleading + + error.type_check("", str, model_path=model_config_path) + error.value_check( + "", + model_config_path is not None and model_config_path.strip(), + f"model_path '{model_config_path}' is invalid", + ) + + model_config_path = os.path.abspath( + model_config_path.strip() + ) # No leading/trailing spaces sneaky weirdness + + os.makedirs(model_config_path, exist_ok=False) + saver = ModuleSaver( + module=self, + model_path=model_config_path, + ) + + # Get and update config (artifacts_path) + artifacts_path = saver.config.get(self._ARTIFACTS_PATH_KEY) + if not artifacts_path: + artifacts_path = self._ARTIFACTS_PATH_DEFAULT + saver.update_config({self._ARTIFACTS_PATH_KEY: artifacts_path}) + + # Save the model + artifacts_path = os.path.abspath( + os.path.join(model_config_path, artifacts_path) + ) + self.model.save(artifacts_path, create_model_card=True) + + # Save the config + ModuleConfig(saver.config).save(model_config_path) diff --git a/caikit_nlp/modules/text_embedding/sentence_similarity_task.py b/caikit_nlp/modules/text_embedding/sentence_similarity_task.py index 458fa72a..021aec23 100644 --- a/caikit_nlp/modules/text_embedding/sentence_similarity_task.py +++ b/caikit_nlp/modules/text_embedding/sentence_similarity_task.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Standard +from typing import List + # First Party from caikit.core import TaskBase, task -from typing import List +# Local from ...data_model import SentenceListScores, SentenceScores @task( required_parameters={"source_sentence": str, "sentences": List[str]}, - output_type=SentenceScores + output_type=SentenceScores, ) class SentenceSimilarityTask(TaskBase): """Compare the source_sentence to each of the sentences. @@ -31,7 +34,7 @@ class SentenceSimilarityTask(TaskBase): @task( required_parameters={"source_sentences": List[str], "sentences": List[str]}, - output_type=SentenceListScores + output_type=SentenceListScores, ) class SentenceSimilarityTasks(TaskBase): """Compare each of the source_sentences to each of the sentences. 
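To make the task methods introduced by this patch easier to follow, here is a minimal usage sketch. It is illustrative only: the method, parameter, and result field names come from the patch above, but the import path, model id, and the sample queries/documents are assumptions made for the example, not values taken from the patch.

    # Hypothetical usage of the TextEmbedding module added in this patch.
    # Assumes a sentence-transformers model id or local model path is available.
    from caikit_nlp.modules.embedding_retrieval.embedding import TextEmbedding

    model = TextEmbedding.bootstrap("sentence-transformers/all-MiniLM-L6-v2")

    # Sentence similarity: score one source sentence against candidate sentences
    sim = model.run_sentence_similarity(
        source_sentence="first sentence",
        sentences=["first sentence", "another sentence", "a third sentence"],
    )
    print(sim.scores)  # one float per candidate sentence, in input order

    # Rerank: each document is a dict; its "text" (or "_text") value is compared to each query
    docs = [
        {"text": "first document", "title": "one"},
        {"_text": "second document"},
        {"text": "third document"},
    ]
    prediction = model.run_rerank_queries(queries=["first", "third"], documents=docs, top_n=2)
    for result in prediction.results:   # one RerankQueryResult per query, in query order
        for score in result.scores:     # top_n scores, most relevant first
            print(score.index, score.score)

Each result's scores come back most-relevant first, and the index field points back into the original documents list.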
From 0ce4b38b21c6f88f56dd82dc62c0c318a440b6aa Mon Sep 17 00:00:00 2001 From: markstur Date: Tue, 24 Oct 2023 14:47:27 -0700 Subject: [PATCH 08/17] Updates after rebase Signed-off-by: markstur --- caikit_nlp/data_model/embedding_vectors.py | 7 +- .../modules/embedding_retrieval/embedding.py | 345 ------------------ .../modules/text_embedding/embedding.py | 163 ++++++++- .../modules/text_embedding/embedding_tasks.py | 3 +- tests/data_model/test_embedding_vectors.py | 4 +- .../modules/text_embedding/test_embedding.py | 41 ++- 6 files changed, 196 insertions(+), 367 deletions(-) delete mode 100644 caikit_nlp/modules/embedding_retrieval/embedding.py diff --git a/caikit_nlp/data_model/embedding_vectors.py b/caikit_nlp/data_model/embedding_vectors.py index 22fa8277..4379372b 100644 --- a/caikit_nlp/data_model/embedding_vectors.py +++ b/caikit_nlp/data_model/embedding_vectors.py @@ -79,9 +79,12 @@ def __post_init__(self): @classmethod def from_vector(cls, vector): - if vector.dtype == np.float32: + dtype = getattr(vector, "dtype", False) + if dtype is None: + data = PyFloatSequence(vector) + elif dtype == np.float32: data = NpFloat32Sequence(vector) - elif vector.dtype == np.float64: + elif dtype == np.float64: data = NpFloat64Sequence(vector) else: data = PyFloatSequence(vector) diff --git a/caikit_nlp/modules/embedding_retrieval/embedding.py b/caikit_nlp/modules/embedding_retrieval/embedding.py deleted file mode 100644 index 2f247522..00000000 --- a/caikit_nlp/modules/embedding_retrieval/embedding.py +++ /dev/null @@ -1,345 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Standard -from typing import List, Optional -import os - -# Third Party -from sentence_transformers import SentenceTransformer -from sentence_transformers.util import ( - cos_sim, - dot_score, - normalize_embeddings, - semantic_search, -) - -# First Party -from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module -from caikit.core.data_model.json_dict import JsonDict -from caikit.core.exceptions import error_handler -import alog - -# Local -from .embedding_retrieval_task import EmbeddingTask, EmbeddingTasks -from .rerank_task import RerankTask, RerankTasks -from .sentence_similarity_task import SentenceSimilarityTask, SentenceSimilarityTasks -from caikit_nlp.data_model import ( - ListOfVector1D, - RerankPrediction, - RerankQueryResult, - RerankScore, - SentenceListScores, - SentenceScores, - Vector1D, -) - -logger = alog.use_channel("") -error = error_handler.get(logger) - - -@module( - "EEB12558-B4FA-4F34-A9FD-3F5890E9CD3F", - "Text Embedding", - "0.0.1", - tasks=[ - EmbeddingTask, - EmbeddingTasks, - SentenceSimilarityTask, - SentenceSimilarityTasks, - RerankTask, - RerankTasks, - ], -) -class TextEmbedding(ModuleBase): - - _ARTIFACTS_PATH_KEY = "artifacts_path" - _ARTIFACTS_PATH_DEFAULT = "artifacts" - - def __init__( - self, - model: SentenceTransformer, - ): - super().__init__() - self.model = model - - @classmethod - def load(cls, model_path: str, *args, **kwargs) -> "TextEmbedding": - """Load model - - Args: - model_path: str - Path to the config dir under the model_id (where the config.yml lives) - - Returns: - EmbeddingModule - Instance of this class built from the model. - """ - - config = ModuleConfig.load(model_path) - artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY) - - error.value_check( - "", - artifacts_path, - ValueError(f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"), - ) - - artifacts_path = os.path.abspath(os.path.join(model_path, artifacts_path)) - error.dir_check("", artifacts_path) - - return cls.bootstrap(model_name_or_path=artifacts_path) - - @EmbeddingTask.taskmethod() - def run_embedding(self, text: str) -> Vector1D: # pylint: disable=redefined-builtin - """Get embedding for a string. - Args: - text: str - Input text to be processed - Returns: - Vector1D: the output - """ - error.type_check("", str, text=text) - - return Vector1D.from_embeddings(self.model.encode(text)) - - @EmbeddingTasks.taskmethod() - def run_embeddings( - self, texts: List[str] # pylint: disable=redefined-builtin - ) -> ListOfVector1D: - """Run inference on model. - Args: - texts: List[str] - List of input texts to be processed - Returns: - List[Vector1D]: List vectors. One for each input text (in order). - Each vector is a list of floats (supports various float types). - """ - if isinstance( - texts, str - ): # encode allows str, but the result would lack a dimension - texts = [texts] - - embeddings = self.model.encode(texts) - results = [Vector1D.from_embeddings(e) for e in embeddings] - return ListOfVector1D(results=results) - - @SentenceSimilarityTask.taskmethod() - def run_sentence_similarity( - self, source_sentence: str, sentences: List[str] - ) -> SentenceScores: # pylint: disable=arguments-differ - """Run inference on model. 
- Args: - source_sentence: str - sentences: List[str] - Sentences to compare to source_sentence - Returns: - SentenceScores - """ - - source_embedding = self.model.encode(source_sentence) - embeddings = self.model.encode(sentences) - - res = cos_sim(source_embedding, embeddings) - return SentenceScores(res.tolist()[0]) - - @SentenceSimilarityTasks.taskmethod() - def run_sentence_similarities( - self, source_sentences: List[str], sentences: List[str] - ) -> SentenceListScores: # pylint: disable=arguments-differ - """Run inference on model. - Args: - source_sentences: List[str] - sentences: List[str] - Sentences to compare to source_sentences - Returns: - SentenceListScores Similarity scores for each source sentence in order. - each SentenceScores contains the source-sentence's score for each sentence in order. - """ - - source_embedding = self.model.encode(source_sentences) - embeddings = self.model.encode(sentences) - - res = cos_sim(source_embedding, embeddings) - float_list_list = res.tolist() - return SentenceListScores( - results=[SentenceScores(fl) for fl in float_list_list] - ) - - @RerankTask.taskmethod() - def run_rerank_query( - self, - query: str, - documents: List[JsonDict], - top_n: Optional[int] = None, - return_documents: bool = True, - return_query: bool = True, - return_text: bool = True, - ) -> RerankQueryResult: - """Run inference on model. - Args: - query: str - documents: List[JsonDict] - top_n: Optional[int] - Returns: - RerankQueryResult - """ - - error.type_check( - "", - str, - query=query, - ) - - results = self.run_rerank_queries( - queries=[query], - documents=documents, - top_n=top_n, - return_documents=return_documents, - return_queries=return_query, - return_text=return_text, - ).results - - return ( - results[0] - if len(results) > 0 - else RerankQueryResult(scores=[], query=query if return_query else None) - ) - - @RerankTasks.taskmethod() - def run_rerank_queries( - self, - queries: List[str], - documents: List[JsonDict], - top_n: Optional[int] = None, - return_documents: bool = True, - return_queries: bool = True, - return_text: bool = True, - ) -> RerankPrediction: - """Run inference on model. 
- Args: - queries: List[str] - documents: List[JsonDict] - top_n: Optional[int] - return_documents: bool - return_queries: bool - return_text: bool - Returns: - RerankPrediction - """ - - error.type_check( - "", - list, - queries=queries, - documents=documents, - ) - - if len(queries) < 1 or len(documents) < 1: - return RerankPrediction([]) - - if top_n is None or top_n < 1: - top_n = len(documents) - - # Using input document dicts so get "text" else "_text" else default to "" - def get_text(srd): - return srd.get("text") or srd.get("_text", "") - - doc_texts = [get_text(srd) for srd in documents] - - doc_embeddings = self.model.encode(doc_texts, convert_to_tensor=True) - doc_embeddings = doc_embeddings.to(self.model.device) - doc_embeddings = normalize_embeddings(doc_embeddings) - - query_embeddings = self.model.encode(queries, convert_to_tensor=True) - query_embeddings = query_embeddings.to(self.model.device) - query_embeddings = normalize_embeddings(query_embeddings) - - res = semantic_search( - query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score - ) - - # Fixup result dicts - for r in res: - for x in r: - # Renaming corpus_id to index - corpus_id = x.pop("corpus_id") - x["index"] = corpus_id - # Optionally adding the original document and/or just the text that was used - if return_documents: - x["document"] = documents[corpus_id] - if return_text: - x["text"] = get_text(documents[corpus_id]) - - def add_query(q): - return queries[q] if return_queries else None - - results = [ - RerankQueryResult(query=add_query(q), scores=[RerankScore(**x) for x in r]) - for q, r in enumerate(res) - ] - - return RerankPrediction(results=results) - - @classmethod - def bootstrap(cls, model_name_or_path: str) -> "TextEmbedding": - """Bootstrap a sentence-transformers model - - Args: - model_name_or_path: str - Model name (Hugging Face hub) or path to model to load. 
- """ - return cls(model=SentenceTransformer(model_name_or_path=model_name_or_path)) - - def save(self, model_path: str, *args, **kwargs): - """Save model using config in model_path - - Args: - model_path: str - Path to model config - """ - - model_config_path = model_path # because the param name is misleading - - error.type_check("", str, model_path=model_config_path) - error.value_check( - "", - model_config_path is not None and model_config_path.strip(), - f"model_path '{model_config_path}' is invalid", - ) - - model_config_path = os.path.abspath( - model_config_path.strip() - ) # No leading/trailing spaces sneaky weirdness - - os.makedirs(model_config_path, exist_ok=False) - saver = ModuleSaver( - module=self, - model_path=model_config_path, - ) - - # Get and update config (artifacts_path) - artifacts_path = saver.config.get(self._ARTIFACTS_PATH_KEY) - if not artifacts_path: - artifacts_path = self._ARTIFACTS_PATH_DEFAULT - saver.update_config({self._ARTIFACTS_PATH_KEY: artifacts_path}) - - # Save the model - artifacts_path = os.path.abspath( - os.path.join(model_config_path, artifacts_path) - ) - self.model.save(artifacts_path, create_model_card=True) - - # Save the config - ModuleConfig(saver.config).save(model_config_path) diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index e8c05cca..e8f7526d 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -103,7 +103,9 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": return cls.bootstrap(model_name_or_path=artifacts_path) @EmbeddingTask.taskmethod() - def run_embedding(self, text: str) -> EmbeddingResult: # pylint: disable=redefined-builtin + def run_embedding( + self, text: str + ) -> EmbeddingResult: # pylint: disable=redefined-builtin """Get embedding for a string. Args: text: str @@ -133,9 +135,166 @@ def run_embeddings( texts = [texts] embeddings = self.model.encode(texts) - results = [Vector1D.from_embeddings(e) for e in embeddings] + results = [Vector1D.from_vector(e) for e in embeddings] return ListOfVector1D(results=results) + @SentenceSimilarityTask.taskmethod() + def run_sentence_similarity( + self, source_sentence: str, sentences: List[str] + ) -> SentenceScores: # pylint: disable=arguments-differ + """Run inference on model. + Args: + source_sentence: str + sentences: List[str] + Sentences to compare to source_sentence + Returns: + SentenceScores + """ + + source_embedding = self.model.encode(source_sentence) + embeddings = self.model.encode(sentences) + + res = cos_sim(source_embedding, embeddings) + return SentenceScores(res.tolist()[0]) + + @SentenceSimilarityTasks.taskmethod() + def run_sentence_similarities( + self, source_sentences: List[str], sentences: List[str] + ) -> SentenceListScores: # pylint: disable=arguments-differ + """Run inference on model. + Args: + source_sentences: List[str] + sentences: List[str] + Sentences to compare to source_sentences + Returns: + SentenceListScores Similarity scores for each source sentence in order. + each SentenceScores contains the source-sentence's score for each sentence in order. 
+ """ + + source_embedding = self.model.encode(source_sentences) + embeddings = self.model.encode(sentences) + + res = cos_sim(source_embedding, embeddings) + float_list_list = res.tolist() + return SentenceListScores( + results=[SentenceScores(fl) for fl in float_list_list] + ) + + @RerankTask.taskmethod() + def run_rerank_query( + self, + query: str, + documents: List[JsonDict], + top_n: Optional[int] = None, + return_documents: bool = True, + return_query: bool = True, + return_text: bool = True, + ) -> RerankQueryResult: + """Run inference on model. + Args: + query: str + documents: List[JsonDict] + top_n: Optional[int] + Returns: + RerankQueryResult + """ + + error.type_check( + "", + str, + query=query, + ) + + results = self.run_rerank_queries( + queries=[query], + documents=documents, + top_n=top_n, + return_documents=return_documents, + return_queries=return_query, + return_text=return_text, + ).results + + return ( + results[0] + if len(results) > 0 + else RerankQueryResult(scores=[], query=query if return_query else None) + ) + + @RerankTasks.taskmethod() + def run_rerank_queries( + self, + queries: List[str], + documents: List[JsonDict], + top_n: Optional[int] = None, + return_documents: bool = True, + return_queries: bool = True, + return_text: bool = True, + ) -> RerankPrediction: + """Run inference on model. + Args: + queries: List[str] + documents: List[JsonDict] + top_n: Optional[int] + return_documents: bool + return_queries: bool + return_text: bool + Returns: + RerankPrediction + """ + + error.type_check( + "", + list, + queries=queries, + documents=documents, + ) + + if len(queries) < 1 or len(documents) < 1: + return RerankPrediction([]) + + if top_n is None or top_n < 1: + top_n = len(documents) + + # Using input document dicts so get "text" else "_text" else default to "" + def get_text(srd): + return srd.get("text") or srd.get("_text", "") + + doc_texts = [get_text(srd) for srd in documents] + + doc_embeddings = self.model.encode(doc_texts, convert_to_tensor=True) + doc_embeddings = doc_embeddings.to(self.model.device) + doc_embeddings = normalize_embeddings(doc_embeddings) + + query_embeddings = self.model.encode(queries, convert_to_tensor=True) + query_embeddings = query_embeddings.to(self.model.device) + query_embeddings = normalize_embeddings(query_embeddings) + + res = semantic_search( + query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score + ) + + # Fixup result dicts + for r in res: + for x in r: + # Renaming corpus_id to index + corpus_id = x.pop("corpus_id") + x["index"] = corpus_id + # Optionally adding the original document and/or just the text that was used + if return_documents: + x["document"] = documents[corpus_id] + if return_text: + x["text"] = get_text(documents[corpus_id]) + + def add_query(q): + return queries[q] if return_queries else None + + results = [ + RerankQueryResult(query=add_query(q), scores=[RerankScore(**x) for x in r]) + for q, r in enumerate(res) + ] + + return RerankPrediction(results=results) + @classmethod def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule": """Bootstrap a sentence-transformers model diff --git a/caikit_nlp/modules/text_embedding/embedding_tasks.py b/caikit_nlp/modules/text_embedding/embedding_tasks.py index a9a4ebf9..e084b920 100644 --- a/caikit_nlp/modules/text_embedding/embedding_tasks.py +++ b/caikit_nlp/modules/text_embedding/embedding_tasks.py @@ -19,8 +19,7 @@ from caikit.core import TaskBase, task # Local -from ...data_model import EmbeddingResult -from ...data_model 
import ListOfVector1D +from ...data_model import EmbeddingResult, ListOfVector1D @task( diff --git a/tests/data_model/test_embedding_vectors.py b/tests/data_model/test_embedding_vectors.py index 6d65cb4a..54930283 100644 --- a/tests/data_model/test_embedding_vectors.py +++ b/tests/data_model/test_embedding_vectors.py @@ -123,8 +123,8 @@ def test_vector1d_dm(float_seq_class, random_values, float_type): (dm.NpFloat64Sequence, random_numpy_vector1d_float64, np.float64), ], ) -def test_vector1d_dm_from_embeddings(float_seq_class, random_values, float_type): - v = dm.Vector1D.from_embeddings(random_values) +def test_vector1d_dm_from_vector(float_seq_class, random_values, float_type): + v = dm.Vector1D.from_vector(random_values) assert isinstance(v.data, float_seq_class) assert isinstance(v.data.values[0], float_type) _assert_array_check(v, random_values, float_type) diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index c92f114e..af565f25 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -16,6 +16,7 @@ # Local from caikit_nlp.data_model import ( EmbeddingResult, + ListOfVector1D, RerankPrediction, RerankQueryResult, RerankScore, @@ -71,21 +72,33 @@ ## Tests ######################################################################## +def _assert_is_expected_vector(vector): + assert isinstance(vector.data.values[0], np.float32) + assert len(vector.data.values) == 32 + # Just testing a few values for readability + assert approx(vector.data.values[0]) == 0.3244932293891907 + assert approx(vector.data.values[1]) == -0.4934631288051605 + assert approx(vector.data.values[2]) == 0.5721234083175659 + + def _assert_is_expected_embedding_result(actual): assert isinstance(actual, EmbeddingResult) - assert isinstance(actual.result.data.values[0], np.float32) - assert len(actual.result.data.values) == 32 - # Just testing a few values for readability - assert approx(actual.result.data.values[0]) == 0.3244932293891907 - assert approx(actual.result.data.values[1]) == -0.4934631288051605 - assert approx(actual.result.data.values[2]) == 0.5721234083175659 + vector = actual.result + _assert_is_expected_vector(vector) + + +def _assert_is_expected_embeddings_results(actual): + assert isinstance(actual, ListOfVector1D) + vectors = actual.results + _assert_is_expected_vector(vectors[0]) + + +def test_bootstrap(): + assert isinstance( + EmbeddingModule.bootstrap(SEQ_CLASS_MODEL), EmbeddingModule + ), "bootstrap error" -def test_bootstrap_and_run(): - """Check if we can bootstrap and run embedding""" - model = EmbeddingModule.bootstrap(SEQ_CLASS_MODEL) - result = model.run(INPUT) - _assert_is_expected_embedding_result(result) def _assert_types_found(types_found): assert type(types_found["str_test"]) == str, "passthru str value type check" assert type(types_found["int_test"]) == int, "passthru int value type check" @@ -203,9 +216,9 @@ def test_run_embeddings_str_type(): def test_run_embeddings(): model = BOOTSTRAPPED_MODEL - res = model.run_embeddings(texts=[INPUT]) - assert isinstance(res.results, list) - _assert_is_expected_embedding_result(res) + results = model.run_embeddings(texts=[INPUT]) + assert isinstance(results.results, list) + _assert_is_expected_embeddings_results(results) @pytest.mark.parametrize( From 0c97da73be4ecbb983e2c5a80104d0d35ce1fbde Mon Sep 17 00:00:00 2001 From: markstur Date: Thu, 26 Oct 2023 13:21:04 -0700 Subject: [PATCH 09/17] More docstring and rename for 
RerankPredictions * More docstrings to help code readers (doc viewers?) * Renamed RerankPrediction -> RerankPredictions since plural is better as it is being used for multiple queries each with a RerankQueryResult with scores. Signed-off-by: markstur --- caikit_nlp/data_model/reranker.py | 11 +++- .../modules/text_embedding/embedding.py | 57 +++++++++++++++---- .../modules/text_embedding/rerank_task.py | 4 +- tests/data_model/test_reranker.py | 2 +- .../modules/text_embedding/test_embedding.py | 8 +-- 5 files changed, 62 insertions(+), 20 deletions(-) diff --git a/caikit_nlp/data_model/reranker.py b/caikit_nlp/data_model/reranker.py index b0ff7be9..995038c9 100644 --- a/caikit_nlp/data_model/reranker.py +++ b/caikit_nlp/data_model/reranker.py @@ -32,14 +32,19 @@ class RerankScore(DataObjectBase): @dataobject() class RerankQueryResult(DataObjectBase): - """Result for one query in a rerank task""" + """Result for one query in a rerank task. + This is a list of n ReRankScore where n is based on top_n documents and each score indicates + the relevance of that document for this query. Results are ordered most-relevant first. + """ query: Optional[str] scores: List[RerankScore] @dataobject() -class RerankPrediction(DataObjectBase): - """Result for a rerank task""" +class RerankPredictions(DataObjectBase): + """Result for a rerank tasks (supporting multiple queries). + For multiple queries, each one has a RerankQueryResult (ranking the documents for that query). + """ results: List[RerankQueryResult] diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index e8f7526d..e224d1b7 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -38,7 +38,7 @@ from caikit_nlp.data_model import ( EmbeddingResult, ListOfVector1D, - RerankPrediction, + RerankPredictions, RerankQueryResult, RerankScore, SentenceListScores, @@ -121,7 +121,7 @@ def run_embedding( def run_embeddings( self, texts: List[str] # pylint: disable=redefined-builtin ) -> ListOfVector1D: - """Run inference on model. + """Get embedding vectors for texts. Args: texts: List[str] List of input texts to be processed @@ -142,7 +142,7 @@ def run_embeddings( def run_sentence_similarity( self, source_sentence: str, sentences: List[str] ) -> SentenceScores: # pylint: disable=arguments-differ - """Run inference on model. + """Get similarity scores for each of sentences compared to the source_sentence. Args: source_sentence: str sentences: List[str] @@ -161,7 +161,7 @@ def run_sentence_similarity( def run_sentence_similarities( self, source_sentences: List[str], sentences: List[str] ) -> SentenceListScores: # pylint: disable=arguments-differ - """Run inference on model. + """Run sentence-similarities on model. Args: source_sentences: List[str] sentences: List[str] @@ -190,13 +190,33 @@ def run_rerank_query( return_query: bool = True, return_text: bool = True, ) -> RerankQueryResult: - """Run inference on model. + """Rerank the documents returning the most relevant top_n in order for this query. Args: query: str + Query is the source string to be compared to the text of the documents. documents: List[JsonDict] + Each document is a dict. The text value is used for comparison to the query. + If there is no text key, then _text is used and finally default is "". top_n: Optional[int] + Results for the top n most relevant documents will be returned. + If top_n is not provided or (not > 0), then all are returned. 
+ return_documents: bool + Default True + Setting to False will disable returning of the input document (index is returned). + return_query: bool + Default True + Setting to False will disable returning of the query (results are in query order) + return_text: bool + Default True + Setting to False will disable returning of document text string that was used. Returns: RerankQueryResult + Returns the (top_n) scores in relevance order (most relevant first). + The results always include a score and index which may be used to find the document + in the original documents list. Optionally, the results also contain the entire + document with its score (for use in chaining) and for convenience the query and + text used for comparison may be returned. + """ error.type_check( @@ -229,17 +249,34 @@ def run_rerank_queries( return_documents: bool = True, return_queries: bool = True, return_text: bool = True, - ) -> RerankPrediction: - """Run inference on model. + ) -> RerankPredictions: + """Rerank the documents returning the most relevant top_n in order for each of the queries. Args: queries: List[str] + Each of the queries will be compared to the text of each of the documents. documents: List[JsonDict] + Each document is a dict. The text value is used for comparison to the query. + If there is no text key, then _text is used and finally default is "". top_n: Optional[int] + Results for the top n most relevant documents will be returned. + If top_n is not provided or (not > 0), then all are returned. return_documents: bool + Default True + Setting to False will disable returning of the input document (index is returned). return_queries: bool + Default True + Setting to False will disable returning of the query (results are in query order) return_text: bool + Default True + Setting to False will disable returning of document text string that was used. Returns: - RerankPrediction + RerankPredictions + For each query in queries (in the original order)... + Returns the (top_n) scores in relevance order (most relevant first). + The results always include a score and index which may be used to find the document + in the original documents list. Optionally, the results also contain the entire + document with its score (for use in chaining) and for convenience the query and + text used for comparison may be returned. 
""" error.type_check( @@ -250,7 +287,7 @@ def run_rerank_queries( ) if len(queries) < 1 or len(documents) < 1: - return RerankPrediction([]) + return RerankPredictions([]) if top_n is None or top_n < 1: top_n = len(documents) @@ -293,7 +330,7 @@ def add_query(q): for q, r in enumerate(res) ] - return RerankPrediction(results=results) + return RerankPredictions(results=results) @classmethod def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule": diff --git a/caikit_nlp/modules/text_embedding/rerank_task.py b/caikit_nlp/modules/text_embedding/rerank_task.py index caa9d24f..f92b34d1 100644 --- a/caikit_nlp/modules/text_embedding/rerank_task.py +++ b/caikit_nlp/modules/text_embedding/rerank_task.py @@ -22,7 +22,7 @@ import alog # Local -from caikit_nlp.data_model.reranker import RerankPrediction, RerankQueryResult +from caikit_nlp.data_model.reranker import RerankPredictions, RerankQueryResult logger = alog.use_channel("") error = error_handler.get(logger) @@ -54,7 +54,7 @@ class RerankTask(TaskBase): "documents": List[JsonDict], "queries": List[str], }, - output_type=RerankPrediction, + output_type=RerankPredictions, ) class RerankTasks(TaskBase): """Returns an ordered list for each query ranking the most relevant documents for the query diff --git a/tests/data_model/test_reranker.py b/tests/data_model/test_reranker.py index 853a0151..bf947afc 100644 --- a/tests/data_model/test_reranker.py +++ b/tests/data_model/test_reranker.py @@ -103,7 +103,7 @@ (dm.RerankScore, input_score), (dm.RerankScore, input_random_score), (dm.RerankQueryResult, input_result_1), - (dm.RerankPrediction, {"results": input_results}), + (dm.RerankPredictions, {"results": input_results}), (dm.SentenceScores, input_sentence_similarity_scores_1), (dm.SentenceListScores, {"results": input_sentence_similarities_scores}), ], diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index af565f25..a8282898 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -17,7 +17,7 @@ from caikit_nlp.data_model import ( EmbeddingResult, ListOfVector1D, - RerankPrediction, + RerankPredictions, RerankQueryResult, RerankScore, Vector1D, @@ -314,7 +314,7 @@ def test_run_rerank_queries_top_n(top_n, expected): res = BOOTSTRAPPED_MODEL.run_rerank_queries( queries=QUERIES, documents=DOCS, top_n=top_n ) - assert isinstance(res, RerankPrediction) + assert isinstance(res, RerankPredictions) assert len(res.results) == len(QUERIES) for result in res.results: assert len(result.scores) == expected @@ -334,7 +334,7 @@ def test_run_rerank_queries_no_queries_or_no_docs(queries, docs): res = BOOTSTRAPPED_MODEL.run_rerank_queries( queries=queries, documents=docs, top_n=9 ) - assert isinstance(res, RerankPrediction) + assert isinstance(res, RerankPredictions) assert len(res.results) == 0 @@ -343,7 +343,7 @@ def test_run_rerank_queries(): rerank_result = BOOTSTRAPPED_MODEL.run_rerank_queries( queries=QUERIES, documents=DOCS, top_n=top_n ) - assert isinstance(rerank_result, RerankPrediction) + assert isinstance(rerank_result, RerankPredictions) results = rerank_result.results assert isinstance(results, list) From d8795b3f128137ed79e0463d5f3599b6c01532b1 Mon Sep 17 00:00:00 2001 From: markstur Date: Tue, 31 Oct 2023 23:24:20 -0700 Subject: [PATCH 10/17] Use pytest fixtures in embedding/rerank tests * Some misc clean-up based on review feedback * Use pytest fixtures in the tests Signed-off-by: markstur --- caikit_nlp/data_model/reranker.py 
| 6 +- .../modules/text_embedding/embedding.py | 12 +- .../modules/text_embedding/rerank_task.py | 5 - tests/data_model/test_embedding_vectors.py | 67 +++--- tests/data_model/test_reranker.py | 210 +++++++++++------- 5 files changed, 180 insertions(+), 120 deletions(-) diff --git a/caikit_nlp/data_model/reranker.py b/caikit_nlp/data_model/reranker.py index 995038c9..924a7909 100644 --- a/caikit_nlp/data_model/reranker.py +++ b/caikit_nlp/data_model/reranker.py @@ -20,7 +20,7 @@ from caikit.core.data_model.json_dict import JsonDict -@dataobject() +@dataobject(package="caikit_data_model.caikit_nlp") class RerankScore(DataObjectBase): """The score for one document (one query)""" @@ -30,7 +30,7 @@ class RerankScore(DataObjectBase): text: Optional[str] -@dataobject() +@dataobject(package="caikit_data_model.caikit_nlp") class RerankQueryResult(DataObjectBase): """Result for one query in a rerank task. This is a list of n ReRankScore where n is based on top_n documents and each score indicates @@ -41,7 +41,7 @@ class RerankQueryResult(DataObjectBase): scores: List[RerankScore] -@dataobject() +@dataobject(package="caikit_data_model.caikit_nlp") class RerankPredictions(DataObjectBase): """Result for a rerank tasks (supporting multiple queries). For multiple queries, each one has a RerankQueryResult (ranking the documents for that query). diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index e224d1b7..19d016d1 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -103,9 +103,7 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": return cls.bootstrap(model_name_or_path=artifacts_path) @EmbeddingTask.taskmethod() - def run_embedding( - self, text: str - ) -> EmbeddingResult: # pylint: disable=redefined-builtin + def run_embedding(self, text: str) -> EmbeddingResult: """Get embedding for a string. Args: text: str @@ -118,9 +116,7 @@ def run_embedding( return EmbeddingResult(Vector1D.from_vector(self.model.encode(text))) @EmbeddingTasks.taskmethod() - def run_embeddings( - self, texts: List[str] # pylint: disable=redefined-builtin - ) -> ListOfVector1D: + def run_embeddings(self, texts: List[str]) -> ListOfVector1D: """Get embedding vectors for texts. Args: texts: List[str] @@ -141,7 +137,7 @@ def run_embeddings( @SentenceSimilarityTask.taskmethod() def run_sentence_similarity( self, source_sentence: str, sentences: List[str] - ) -> SentenceScores: # pylint: disable=arguments-differ + ) -> SentenceScores: """Get similarity scores for each of sentences compared to the source_sentence. Args: source_sentence: str @@ -160,7 +156,7 @@ def run_sentence_similarity( @SentenceSimilarityTasks.taskmethod() def run_sentence_similarities( self, source_sentences: List[str], sentences: List[str] - ) -> SentenceListScores: # pylint: disable=arguments-differ + ) -> SentenceListScores: """Run sentence-similarities on model. 
Args: source_sentences: List[str] diff --git a/caikit_nlp/modules/text_embedding/rerank_task.py b/caikit_nlp/modules/text_embedding/rerank_task.py index f92b34d1..4b7d67e5 100644 --- a/caikit_nlp/modules/text_embedding/rerank_task.py +++ b/caikit_nlp/modules/text_embedding/rerank_task.py @@ -18,15 +18,10 @@ # First Party from caikit.core import TaskBase, task from caikit.core.data_model.json_dict import JsonDict -from caikit.core.toolkit.errors import error_handler -import alog # Local from caikit_nlp.data_model.reranker import RerankPredictions, RerankQueryResult -logger = alog.use_channel("") -error = error_handler.get(logger) - @task( required_parameters={ diff --git a/tests/data_model/test_embedding_vectors.py b/tests/data_model/test_embedding_vectors.py index 54930283..0894041c 100644 --- a/tests/data_model/test_embedding_vectors.py +++ b/tests/data_model/test_embedding_vectors.py @@ -25,23 +25,34 @@ ## Setup ######################################################################### -RANDOM_SEED = 77 DUMMY_VECTOR_SHAPE = (5,) +RANDOM_SEED = 77 +np.random.seed(RANDOM_SEED) +random_number_generator = np.random.default_rng() # To tests the limits of our type-checking, this can replace our legit data objects TRICK_SEQUENCE = namedtuple("Trick", "values") -np.random.seed(RANDOM_SEED) -random_number_generator = np.random.default_rng() +@pytest.fixture +def simple_array_of_floats(): + return [1.1, 2.2] + + +@pytest.fixture +def random_numpy_vector1d_float32(): + return random_number_generator.random(DUMMY_VECTOR_SHAPE, dtype=np.float32) + + +@pytest.fixture +def random_numpy_vector1d_float64(): + return random_number_generator.random(DUMMY_VECTOR_SHAPE, dtype=np.float64) + + +@pytest.fixture +def random_python_vector1d_float(random_numpy_vector1d_float32): + return random_numpy_vector1d_float32.tolist() -random_numpy_vector1d_float32 = random_number_generator.random( - DUMMY_VECTOR_SHAPE, dtype=np.float32 -) -random_numpy_vector1d_float64 = random_number_generator.random( - DUMMY_VECTOR_SHAPE, dtype=np.float64 -) -random_python_vector1d_float = random_numpy_vector1d_float32.tolist() ## Tests ######################################################################## @@ -90,41 +101,47 @@ def _assert_array_check(new_array, data_values, float_type): @pytest.mark.parametrize( "float_seq_class, random_values, float_type", [ - (dm.PyFloatSequence, random_python_vector1d_float, float), - (dm.NpFloat32Sequence, random_numpy_vector1d_float32, np.float32), - (dm.NpFloat64Sequence, random_numpy_vector1d_float64, np.float64), - (TRICK_SEQUENCE, [1.1, 2.2], float), # Sneaky but tests corner cases for now + (dm.PyFloatSequence, "random_python_vector1d_float", float), + (dm.NpFloat32Sequence, "random_numpy_vector1d_float32", np.float32), + (dm.NpFloat64Sequence, "random_numpy_vector1d_float64", np.float64), + ( + TRICK_SEQUENCE, + "simple_array_of_floats", + float, + ), # Sneaky but tests corner cases for now ], ) -def test_vector1d_dm(float_seq_class, random_values, float_type): +def test_vector1d_dm(float_seq_class, random_values, float_type, request): # Test init - dm_init = dm.Vector1D(data=float_seq_class(random_values)) - _assert_array_check(dm_init, random_values, float_type) + fixture_values = request.getfixturevalue(random_values) + dm_init = dm.Vector1D(data=float_seq_class(fixture_values)) + _assert_array_check(dm_init, fixture_values, float_type) # Test proto dm_to_proto = dm_init.to_proto() dm_from_proto = dm.Vector1D.from_proto(dm_to_proto) - _assert_array_check(dm_from_proto, random_values, 
float_type) + _assert_array_check(dm_from_proto, fixture_values, float_type) # Test json dm_to_json = dm_init.to_json() dm_from_json = dm.Vector1D.from_json(dm_to_json) _assert_array_check( - dm_from_json, random_values, float + dm_from_json, fixture_values, float ) # NOTE: always float after json @pytest.mark.parametrize( "float_seq_class, random_values, float_type", [ - (dm.PyFloatSequence, random_python_vector1d_float, float), - (dm.NpFloat32Sequence, random_numpy_vector1d_float32, np.float32), - (dm.NpFloat64Sequence, random_numpy_vector1d_float64, np.float64), + (dm.PyFloatSequence, "random_python_vector1d_float", float), + (dm.NpFloat32Sequence, "random_numpy_vector1d_float32", np.float32), + (dm.NpFloat64Sequence, "random_numpy_vector1d_float64", np.float64), ], ) -def test_vector1d_dm_from_vector(float_seq_class, random_values, float_type): - v = dm.Vector1D.from_vector(random_values) +def test_vector1d_dm_from_vector(float_seq_class, random_values, float_type, request): + fixture_values = request.getfixturevalue(random_values) + v = dm.Vector1D.from_vector(fixture_values) assert isinstance(v.data, float_seq_class) assert isinstance(v.data.values[0], float_type) - _assert_array_check(v, random_values, float_type) + _assert_array_check(v, fixture_values, float_type) diff --git a/tests/data_model/test_reranker.py b/tests/data_model/test_reranker.py index bf947afc..754272f3 100644 --- a/tests/data_model/test_reranker.py +++ b/tests/data_model/test_reranker.py @@ -26,72 +26,124 @@ ## Setup ######################################################################### -input_document = { - "text": "this is the input text", - "_text": "alternate _text here", - "title": "some title attribute here", - "anything": "another string attribute", - "str_test": "test string", - "int_test": 1234, - "float_test": 9876.4321, -} - -key = "".join(random.choices(string.ascii_letters, k=20)) -value = "".join(random.choices(string.printable, k=100)) -input_random_document = { - "text": "".join(random.choices(string.printable, k=100)), - "random_str": "".join(random.choices(string.printable, k=100)), - "random_int": random.randint(-99999, 99999), - "random_float": random.uniform(-99999, 99999), -} - -input_documents = [input_document, input_random_document] - -input_score = { - "document": input_document, - "index": 1234, - "score": 9876.54321, - "text": "this is the input text", -} - -input_random_score = { - "document": input_random_document, - "index": random.randint(-99999, 99999), - "score": random.uniform(-99999, 99999), - "text": "".join(random.choices(string.printable, k=100)), -} - -input_random_score_3 = { - "document": {"text": "random foo3"}, - "index": random.randint(-99999, 99999), - "score": random.uniform(-99999, 99999), - "text": "".join(random.choices(string.printable, k=100)), -} - -input_scores = [dm.RerankScore(**input_score), dm.RerankScore(**input_random_score)] -input_scores2 = [ - dm.RerankScore(**input_random_score), - dm.RerankScore(**input_random_score_3), -] - -input_result_1 = {"query": "foo", "scores": input_scores} -input_result_2 = {"query": "bar", "scores": input_scores2} -input_results = [ - dm.RerankQueryResult(**input_result_1), - dm.RerankQueryResult(**input_result_2), -] - -input_sentence_similarity_scores_1 = { - "scores": [random.uniform(-99999, 99999) for _ in range(10)] -} -input_sentence_similarity_scores_2 = { - "scores": [random.uniform(-99999, 99999) for _ in range(10)] -} - -input_sentence_similarities_scores = [ - 
dm.SentenceScores(**input_sentence_similarity_scores_1), - dm.SentenceScores(**input_sentence_similarity_scores_2), -] + +@pytest.fixture +def input_document(): + return { + "text": "this is the input text", + "_text": "alternate _text here", + "title": "some title attribute here", + "anything": "another string attribute", + "str_test": "test string", + "int_test": 1234, + "float_test": 9876.4321, + } + + +@pytest.fixture +def input_random_document(): + return { + "text": "".join(random.choices(string.printable, k=100)), + "random_str": "".join(random.choices(string.printable, k=100)), + "random_int": random.randint(-99999, 99999), + "random_float": random.uniform(-99999, 99999), + } + + +@pytest.fixture +def input_documents(input_document, input_random_document): + return [input_document, input_random_document] + + +@pytest.fixture +def input_score(input_document): + return { + "document": input_document, + "index": 1234, + "score": 9876.54321, + "text": "this is the input text", + } + + +@pytest.fixture +def input_random_score(input_random_document): + return { + "document": input_random_document, + "index": random.randint(-99999, 99999), + "score": random.uniform(-99999, 99999), + "text": "".join(random.choices(string.printable, k=100)), + } + + +@pytest.fixture +def input_random_score_3(): + return { + "document": {"text": "random foo3"}, + "index": random.randint(-99999, 99999), + "score": random.uniform(-99999, 99999), + "text": "".join(random.choices(string.printable, k=100)), + } + + +@pytest.fixture +def input_scores(input_score, input_random_score): + return [dm.RerankScore(**input_score), dm.RerankScore(**input_random_score)] + + +@pytest.fixture +def input_scores2(input_random_score, input_random_score_3): + return [ + dm.RerankScore(**input_random_score), + dm.RerankScore(**input_random_score_3), + ] + + +@pytest.fixture +def input_result_1(input_scores): + return {"query": "foo", "scores": input_scores} + + +@pytest.fixture +def input_result_2(input_scores2): + return {"query": "bar", "scores": input_scores2} + + +@pytest.fixture +def input_results(input_result_1, input_result_2): + return [ + dm.RerankQueryResult(**input_result_1), + dm.RerankQueryResult(**input_result_2), + ] + + +@pytest.fixture +def input_sentence_similarity_scores_1(): + return {"scores": [random.uniform(-99999, 99999) for _ in range(10)]} + + +@pytest.fixture +def input_rerank_predictions(input_results): + return {"results": input_results} + + +@pytest.fixture +def input_sentence_list_scores(input_sentence_similarities_scores): + return {"results": input_sentence_similarities_scores} + + +@pytest.fixture +def input_sentence_similarity_scores_2(): + return {"scores": [random.uniform(-99999, 99999) for _ in range(10)]} + + +@pytest.fixture +def input_sentence_similarities_scores( + input_sentence_similarity_scores_1, input_sentence_similarity_scores_2 +): + return [ + dm.SentenceScores(**input_sentence_similarity_scores_1), + dm.SentenceScores(**input_sentence_similarity_scores_2), + ] ## Tests ######################################################################## @@ -100,35 +152,35 @@ @pytest.mark.parametrize( "data_object, inputs", [ - (dm.RerankScore, input_score), - (dm.RerankScore, input_random_score), - (dm.RerankQueryResult, input_result_1), - (dm.RerankPredictions, {"results": input_results}), - (dm.SentenceScores, input_sentence_similarity_scores_1), - (dm.SentenceListScores, {"results": input_sentence_similarities_scores}), + (dm.RerankScore, "input_score"), + (dm.RerankScore, 
"input_random_score"), + (dm.RerankQueryResult, "input_result_1"), + (dm.RerankPredictions, "input_rerank_predictions"), + (dm.SentenceScores, "input_sentence_similarity_scores_1"), + (dm.SentenceListScores, "input_sentence_list_scores"), ], ) -def test_data_object(data_object, inputs): +def test_data_object(data_object, inputs, request): # Init data object - new_do_from_init = data_object(**inputs) + fixture_values = request.getfixturevalue(inputs) + new_do_from_init = data_object(**fixture_values) assert isinstance(new_do_from_init, data_object) - assert_fields_match(new_do_from_init, inputs) + assert_fields_match(new_do_from_init, fixture_values) # Test to/from proto proto_from_dm = new_do_from_init.to_proto() new_do_from_proto = data_object.from_proto(proto_from_dm) assert isinstance(new_do_from_proto, data_object) - assert_fields_match(new_do_from_proto, inputs) + assert_fields_match(new_do_from_proto, fixture_values) assert new_do_from_init == new_do_from_proto # Test to/from json json_from_dm = new_do_from_init.to_json() new_do_from_json = data_object.from_json(json_from_dm) assert isinstance(new_do_from_json, data_object) - assert_fields_match(new_do_from_json, inputs) + assert_fields_match(new_do_from_json, fixture_values) assert new_do_from_init == new_do_from_json def assert_fields_match(data_object, inputs): - for k, v in inputs.items(): - assert getattr(data_object, k) == inputs[k] + assert all(getattr(data_object, key) == value for key, value in inputs.items()) From 25dce3e2d2f5909e2fff77cb57dfcb8a63274e21 Mon Sep 17 00:00:00 2001 From: markstur Date: Wed, 1 Nov 2023 14:57:05 -0700 Subject: [PATCH 11/17] Updated with suggested improvements from code review Signed-off-by: markstur --- caikit_nlp/data_model/embedding_vectors.py | 6 +-- .../modules/text_embedding/embedding.py | 41 ++++++++++--------- .../modules/text_embedding/test_embedding.py | 12 +++--- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/caikit_nlp/data_model/embedding_vectors.py b/caikit_nlp/data_model/embedding_vectors.py index 4379372b..bb982b72 100644 --- a/caikit_nlp/data_model/embedding_vectors.py +++ b/caikit_nlp/data_model/embedding_vectors.py @@ -15,7 +15,7 @@ """ # Standard from dataclasses import dataclass, field -from typing import List, Union +from typing import Any, List, Union import json # Third Party @@ -91,7 +91,7 @@ def from_vector(cls, vector): return cls(data=data) @classmethod - def from_json(cls, json_str): + def from_json(cls, json_str: Union[dict[str, Any], str]) -> "Vector1D": """JSON does not have different float types. 
Move data into data_pyfloatsequence""" json_obj = json.loads(json_str) if isinstance(json_str, str) else json_str @@ -169,7 +169,7 @@ def __post_init__(self): error.type_check_all("", Vector1D, results=self.results) @classmethod - def from_json(cls, json_str): + def from_json(cls, json_str: Union[dict[str, Any], str]) -> "ListOfVector1D": """Fill in the vector data in an appropriate data_""" json_obj = json.loads(json_str) if isinstance(json_str, str) else json_str diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index 19d016d1..3c062175 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -230,11 +230,10 @@ def run_rerank_query( return_text=return_text, ).results - return ( - results[0] - if len(results) > 0 - else RerankQueryResult(scores=[], query=query if return_query else None) - ) + if results: + return results[0] + + RerankQueryResult(scores=[], query=query if return_query else None) @RerankTasks.taskmethod() def run_rerank_queries( @@ -282,25 +281,28 @@ def run_rerank_queries( documents=documents, ) - if len(queries) < 1 or len(documents) < 1: - return RerankPredictions([]) + error.value_check( + "", + queries and documents, + "Cannot rerank without a query and at least one document", + ) if top_n is None or top_n < 1: top_n = len(documents) # Using input document dicts so get "text" else "_text" else default to "" - def get_text(srd): - return srd.get("text") or srd.get("_text", "") + def get_text(doc): + return doc.get("text") or doc.get("_text", "") - doc_texts = [get_text(srd) for srd in documents] + doc_texts = [get_text(doc) for doc in documents] - doc_embeddings = self.model.encode(doc_texts, convert_to_tensor=True) - doc_embeddings = doc_embeddings.to(self.model.device) - doc_embeddings = normalize_embeddings(doc_embeddings) + doc_embeddings = normalize_embeddings( + self.model.encode(doc_texts, convert_to_tensor=True).to(self.model.device) + ) - query_embeddings = self.model.encode(queries, convert_to_tensor=True) - query_embeddings = query_embeddings.to(self.model.device) - query_embeddings = normalize_embeddings(query_embeddings) + query_embeddings = normalize_embeddings( + self.model.encode(queries, convert_to_tensor=True).to(self.model.device) + ) res = semantic_search( query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score @@ -346,12 +348,11 @@ def save(self, model_path: str, *args, **kwargs): Path to model config """ - model_config_path = model_path # because the param name is misleading - - error.type_check("", str, model_path=model_config_path) + error.type_check("", str, model_path=model_path) + model_config_path = model_path.strip() error.value_check( "", - model_config_path is not None and model_config_path.strip(), + model_config_path, f"model_path '{model_config_path}' is invalid", ) diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index a8282898..dcfd007c 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -267,8 +267,8 @@ def test_run_rerank_query_no_query(): def test_run_rerank_query_zero_docs(): """No empty doc list therefore result is zero result scores""" - result = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=[], top_n=99) - assert len(result.scores) == 0 + with pytest.raises(ValueError): + BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=[], top_n=99) def 
test_run_rerank_query(): @@ -331,11 +331,9 @@ def test_run_rerank_queries_top_n(top_n, expected): ) def test_run_rerank_queries_no_queries_or_no_docs(queries, docs): """No queries and/or no docs therefore result is zero results""" - res = BOOTSTRAPPED_MODEL.run_rerank_queries( - queries=queries, documents=docs, top_n=9 - ) - assert isinstance(res, RerankPredictions) - assert len(res.results) == 0 + + with pytest.raises(ValueError): + BOOTSTRAPPED_MODEL.run_rerank_queries(queries=queries, documents=docs, top_n=9) def test_run_rerank_queries(): From a031ba52572252c559301f646852bd8b262ff431 Mon Sep 17 00:00:00 2001 From: markstur Date: Tue, 14 Nov 2023 11:27:40 -0800 Subject: [PATCH 12/17] README.md edits Signed-off-by: markstur --- README.md | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 6ed215b2..509485ae 100644 --- a/README.md +++ b/README.md @@ -8,18 +8,15 @@ Caikit-NLP implements concept of "task" from `caikit` framework to define (and c Capabilities provided by `caikit-nlp`: -| Task | Module(s) | Salient Feature(s) | -|-------------------------|------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| TextGenerationTask | 1. `PeftPromptTuning`
2. `TextGeneration` | 1. Prompt Tuning, Multi-task Prompt tuning
2. Fine-tuning Both modules above provide optimized inference capability using Text Generation Inference Server | -| TextClassificationTask | 1. `SequenceClassification` | 1. (Work in progress..) | -| TokenClassificationTask | 1. `FilteredSpanClassification` | 1. (Work in progress..) | -| TokenizationTask | 1. `RegexSentenceSplitter` | 1. Demo purposes only | -| EmbeddingTask | 1. `TextEmbedding` | 1. text/embedding from a local sentence-transformers model -| EmbeddingTasks | 1. `TextEmbedding` | 1. Same as EmbeddingTask but multiple sentences (texts) as input and corresponding list of outputs. -| SentenceSimilarityTask | 1. `TextEmbedding` | 1. text/sentence-similarity from a local sentence-transformers model (Hugging Face style API returns scores only in order of input sentences) | -| SentenceSimilarityTasks | 1. `TextEmbedding` | 1. Same as SentenceSimilarityTask but multiple source_sentences (each to be compared to same list of sentences) as input and corresponding lists of outputs. | -| RerankTask | 1. `TextEmbedding` | 1. text/rerank from a local sentence-transformers model (Cohere style API returns top_n scores in order of relevance with index to source and optionally returning inputs) | -| RerankTasks | 1. `TextEmbedding` | 1. Same as RerankTask but multiple queries as input and corresponding lists of outputs. Same list of documents for all queries. | +| Task | Module(s) | Salient Feature(s) | +|-----------------------------------------------------|------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| TextGenerationTask | 1. `PeftPromptTuning`
2. `TextGeneration` | 1. Prompt Tuning, Multi-task Prompt tuning
2. Fine-tuning Both modules above provide optimized inference capability using Text Generation Inference Server | +| TextClassificationTask | 1. `SequenceClassification` | 1. (Work in progress..) | +| TokenClassificationTask | 1. `FilteredSpanClassification` | 1. (Work in progress..) | +| TokenizationTask | 1. `RegexSentenceSplitter` | 1. Demo purposes only | +| EmbeddingTask
EmbeddingTasks | 1. `TextEmbedding` | 1. TextEmbedding returns a text embedding vector from a local sentence-transformers model
 2. EmbeddingTasks takes multiple input texts and returns a corresponding list of vectors. | +| SentenceSimilarityTask
SentenceSimilarityTasks | 1. `TextEmbedding` | 1. SentenceSimilarityTask compares one source_sentence to a list of sentences and returns similarity scores in order of the sentences.
 2. SentenceSimilarityTasks uses a list of source_sentences (each to be compared to the same list of sentences) and returns corresponding lists of outputs. | +| RerankTask
 RerankTasks                              | 1. `TextEmbedding`                              | 1. RerankTask compares a query to a list of documents and returns top_n scores in order of relevance, with indexes to the source documents, and optionally returns the documents.
2. RerankTasks takes multiple queries as input and returns a corresponding list of outputs. The same list of documents is used for all queries. | ## Getting Started From d1853ed5f6f3df79e098bb321b98921990cc744b Mon Sep 17 00:00:00 2001 From: markstur Date: Wed, 15 Nov 2023 09:06:23 -0800 Subject: [PATCH 13/17] Make sentence-transformers import optional * Handling ModuleNotFound so that we can move extras to extras in the future * Testing with pip install --nodeps of only the minimum (probably to be replaced with full import of sentence-transformers in extras in the future) Signed-off-by: markstur --- .../modules/text_embedding/embedding.py | 31 +++++++++++++------ pyproject.toml | 1 - sentence-transformers.nodeps.txt | 7 +++++ tox.ini | 3 ++ 4 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 sentence-transformers.nodeps.txt diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index 3c062175..f281665f 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -14,17 +14,9 @@ # Standard from typing import List, Optional +import importlib import os -# Third Party -from sentence_transformers import SentenceTransformer -from sentence_transformers.util import ( - cos_sim, - dot_score, - normalize_embeddings, - semantic_search, -) - # First Party from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module from caikit.core.data_model.json_dict import JsonDict @@ -49,6 +41,27 @@ logger = alog.use_channel("TXT_EMB") error = error_handler.get(logger) +# To avoid dependency problems, make sentence-transformers an optional import and +# defer any ModuleNotFoundError until someone actually tries to init a model with this module. +try: + sentence_transformers = importlib.import_module("sentence_transformers") + # Third Party + from sentence_transformers import SentenceTransformer + from sentence_transformers.util import ( + cos_sim, + dot_score, + normalize_embeddings, + semantic_search, + ) +except ModuleNotFoundError: + # When it is not available, create a dummy that raises an error on attempted init() + class SentenceTransformerNotAvailable: + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + # Will reproduce the ModuleNotFoundError if/when anyone actually tries this module/model + importlib.import_module("sentence_transformers") + + SentenceTransformer = SentenceTransformerNotAvailable + @module( "eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f", diff --git a/pyproject.toml b/pyproject.toml index ce5200dc..e99ce0b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ dependencies = [ "pandas>=1.5.0", "scikit-learn>=1.1", "scipy>=1.8.1", - "sentence-transformers~=2.2.2", "tokenizers>=0.13.3", "torch>=2.0.1", "tqdm>=4.65.0", diff --git a/sentence-transformers.nodeps.txt b/sentence-transformers.nodeps.txt new file mode 100644 index 00000000..dad81257 --- /dev/null +++ b/sentence-transformers.nodeps.txt @@ -0,0 +1,7 @@ +# These can be installed with --no-deps. 
+ +# Minimum needed to use sentence-transformers: +sentence-transformers>=2.2.2,<2.3.0 +nltk>=3.8.1,<3.9.0 +Pillow>=10.0.0,<10.1.0 + diff --git a/tox.ini b/tox.ini index ed361a28..586bcd54 100644 --- a/tox.ini +++ b/tox.ini @@ -16,6 +16,9 @@ passenv = LOG_THREAD_ID LOG_CHANNEL_WIDTH commands = pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests} +commands = + python -I -m pip install --force-reinstall --no-deps -r sentence-transformers.nodeps.txt + pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests} ; Unclear: We probably want to test wheel packaging ; But! tox will fail when this is set and _any_ interpreter is missing From 81928c5b56f47f8933fc3add96f700420a8376a5 Mon Sep 17 00:00:00 2001 From: markstur Date: Wed, 15 Nov 2023 18:19:59 -0800 Subject: [PATCH 14/17] Embedding service using interfaces (tasks and data models) from caikit * Moved interfaces (tasks and datamodels) to caikit * Updated code here to the new interfaces with added producer_id and related changes to the data models Signed-off-by: markstur --- caikit_nlp/data_model/__init__.py | 5 +- caikit_nlp/data_model/embedding_vectors.py | 199 ------------------ caikit_nlp/data_model/reranker.py | 50 ----- caikit_nlp/data_model/sentence_similarity.py | 36 ---- caikit_nlp/modules/text_embedding/__init__.py | 3 - .../modules/text_embedding/embedding.py | 93 +++++--- .../modules/text_embedding/embedding_tasks.py | 38 ---- .../modules/text_embedding/rerank_task.py | 68 ------ .../sentence_similarity_task.py | 43 ---- tests/data_model/test_embedding_vectors.py | 49 +++-- tests/data_model/test_reranker.py | 72 ++++--- .../modules/text_embedding/test_embedding.py | 47 ++--- tox.ini | 1 - 13 files changed, 157 insertions(+), 547 deletions(-) delete mode 100644 caikit_nlp/data_model/embedding_vectors.py delete mode 100644 caikit_nlp/data_model/reranker.py delete mode 100644 caikit_nlp/data_model/sentence_similarity.py delete mode 100644 caikit_nlp/modules/text_embedding/embedding_tasks.py delete mode 100644 caikit_nlp/modules/text_embedding/rerank_task.py delete mode 100644 caikit_nlp/modules/text_embedding/sentence_similarity_task.py diff --git a/caikit_nlp/data_model/__init__.py b/caikit_nlp/data_model/__init__.py index 98f1d06b..6826b631 100644 --- a/caikit_nlp/data_model/__init__.py +++ b/caikit_nlp/data_model/__init__.py @@ -15,8 +15,5 @@ """ # Local -from . import embedding_vectors, generation -from .embedding_vectors import * +from . import generation from .generation import * -from .reranker import * -from .sentence_similarity import * diff --git a/caikit_nlp/data_model/embedding_vectors.py b/caikit_nlp/data_model/embedding_vectors.py deleted file mode 100644 index bb982b72..00000000 --- a/caikit_nlp/data_model/embedding_vectors.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Data structures for embedding vector representations -""" -# Standard -from dataclasses import dataclass, field -from typing import Any, List, Union -import json - -# Third Party -from google.protobuf import json_format -import numpy as np - -# First Party -from caikit.core import DataObjectBase, dataobject -from caikit.core.exceptions import error_handler -import alog - -log = alog.use_channel("DATAM") -error = error_handler.get(log) - - -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass -class PyFloatSequence(DataObjectBase): - values: List[float] = field(default_factory=list) - - -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass -class NpFloat32Sequence(DataObjectBase): - values: List[np.float32] - - @classmethod - def from_proto(cls, proto): - values = np.asarray(proto.values, dtype=np.float32) - return cls(values) - - -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass -class NpFloat64Sequence(DataObjectBase): - values: List[np.float64] - - @classmethod - def from_proto(cls, proto): - values = np.asarray(proto.values, dtype=np.float64) - return cls(values) - - -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass -class Vector1D(DataObjectBase): - """Data representation for a 1 dimension vector of float-type data.""" - - data: Union[ - PyFloatSequence, - NpFloat32Sequence, - NpFloat64Sequence, - ] - - def __post_init__(self): - error.value_check( - "", - hasattr(self.data, "values"), - ValueError("Vector1D requires a float sequence data object with values."), - ) - - @classmethod - def from_vector(cls, vector): - dtype = getattr(vector, "dtype", False) - if dtype is None: - data = PyFloatSequence(vector) - elif dtype == np.float32: - data = NpFloat32Sequence(vector) - elif dtype == np.float64: - data = NpFloat64Sequence(vector) - else: - data = PyFloatSequence(vector) - return cls(data=data) - - @classmethod - def from_json(cls, json_str: Union[dict[str, Any], str]) -> "Vector1D": - """JSON does not have different float types. 
Move data into data_pyfloatsequence""" - - json_obj = json.loads(json_str) if isinstance(json_str, str) else json_str - data = json_obj.pop("data") - if data is not None: - json_obj["data_pyfloatsequence"] = data - - json_str = json.dumps(json_obj) - try: - # Parse given JSON into google.protobufs.pyext.cpp_message.GeneratedProtocolMessageType - parsed_proto = json_format.Parse( - json_str, cls.get_proto_class()(), ignore_unknown_fields=False - ) - - # Use from_proto to return the DataBase object from the parsed proto - return cls.from_proto(parsed_proto) - - except json_format.ParseError as ex: - error("", ValueError(ex)) - - def to_dict(self) -> dict: - """to_dict is needed to make things serializable""" - values = self.data.values if self.data.values is not None else [] - return { - "data": { - # coerce numpy.ndarray and numpy.float32 into JSON serializable list of floats - "values": values.tolist() - if isinstance(values, np.ndarray) - else values - } - } - - @classmethod - def from_proto(cls, proto): - """Wrap the data in an appropriate float sequence, wrapped by this class""" - woo = proto.WhichOneof("data") - if woo is None: - return cls(PyFloatSequence()) - - woo_data = getattr(proto, woo) - if woo == "data_npfloat64sequence": - ret = cls(NpFloat64Sequence.from_proto(woo_data)) - elif woo == "data_npfloat32sequence": - ret = cls(NpFloat32Sequence.from_proto(woo_data)) - else: - ret = cls(PyFloatSequence.from_proto(woo_data)) - return ret - - def fill_proto(self, proto): - """Fill in the data in an appropriate data_""" - values = self.data.values - if values is not None and len(values) > 0: - sample = values[0] - error.type_check( - "", float, np.float32, np.float64, sample=sample - ) - if isinstance(sample, np.float64): - proto.data_npfloat64sequence.values.extend(values) - elif isinstance(sample, np.float32): - proto.data_npfloat32sequence.values.extend(values) - else: - proto.data_pyfloatsequence.values.extend(values) - - return proto - - -@dataobject(package="caikit_data_model.caikit_nlp") -class ListOfVector1D(DataObjectBase): - """Data representation for an embedding matrix holding 2D vectors""" - - results: List[Vector1D] - - def __post_init__(self): - error.type_check("", list, results=self.results) - error.type_check_all("", Vector1D, results=self.results) - - @classmethod - def from_json(cls, json_str: Union[dict[str, Any], str]) -> "ListOfVector1D": - """Fill in the vector data in an appropriate data_""" - - json_obj = json.loads(json_str) if isinstance(json_str, str) else json_str - for v in json_obj["results"]: - data = v.pop("data") - if data is not None: - v["data_pyfloatsequence"] = data - json_str = json.dumps(json_obj) - try: - # Parse given JSON into google.protobufs.pyext.cpp_message.GeneratedProtocolMessageType - parsed_proto = json_format.Parse( - json_str, cls.get_proto_class()(), ignore_unknown_fields=False - ) - - # Use from_proto to return the DataBase object from the parsed proto - return cls.from_proto(parsed_proto) - - except json_format.ParseError as ex: - error("", ValueError(ex)) - - -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass -class EmbeddingResult(DataObjectBase): - """Result from text embedding task""" - - result: Vector1D diff --git a/caikit_nlp/data_model/reranker.py b/caikit_nlp/data_model/reranker.py deleted file mode 100644 index 924a7909..00000000 --- a/caikit_nlp/data_model/reranker.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# 
you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Standard -from typing import List, Optional - -# First Party -from caikit.core import DataObjectBase, dataobject -from caikit.core.data_model.json_dict import JsonDict - - -@dataobject(package="caikit_data_model.caikit_nlp") -class RerankScore(DataObjectBase): - """The score for one document (one query)""" - - document: Optional[JsonDict] - index: int - score: float - text: Optional[str] - - -@dataobject(package="caikit_data_model.caikit_nlp") -class RerankQueryResult(DataObjectBase): - """Result for one query in a rerank task. - This is a list of n ReRankScore where n is based on top_n documents and each score indicates - the relevance of that document for this query. Results are ordered most-relevant first. - """ - - query: Optional[str] - scores: List[RerankScore] - - -@dataobject(package="caikit_data_model.caikit_nlp") -class RerankPredictions(DataObjectBase): - """Result for a rerank tasks (supporting multiple queries). - For multiple queries, each one has a RerankQueryResult (ranking the documents for that query). - """ - - results: List[RerankQueryResult] diff --git a/caikit_nlp/data_model/sentence_similarity.py b/caikit_nlp/data_model/sentence_similarity.py deleted file mode 100644 index c24097f7..00000000 --- a/caikit_nlp/data_model/sentence_similarity.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Data structures for embedding vector representations -""" -# Standard -from typing import List - -# First Party -from caikit.core import DataObjectBase, dataobject -from caikit.core.exceptions import error_handler -import alog - -log = alog.use_channel("DATAM") -error = error_handler.get(log) - - -@dataobject(package="caikit_data_model.caikit_nlp") -class SentenceScores(DataObjectBase): - scores: List[float] - - -@dataobject(package="caikit_data_model.caikit_nlp") -class SentenceListScores(DataObjectBase): - - results: List[SentenceScores] diff --git a/caikit_nlp/modules/text_embedding/__init__.py b/caikit_nlp/modules/text_embedding/__init__.py index d72e55be..2451f4a2 100644 --- a/caikit_nlp/modules/text_embedding/__init__.py +++ b/caikit_nlp/modules/text_embedding/__init__.py @@ -30,6 +30,3 @@ # Local from .embedding import EmbeddingModule -from .embedding_tasks import EmbeddingTask, EmbeddingTasks -from .rerank_task import RerankTask, RerankTasks -from .sentence_similarity_task import SentenceSimilarityTask, SentenceSimilarityTasks diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index f281665f..f64e185e 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -21,22 +21,27 @@ from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module from caikit.core.data_model.json_dict import JsonDict from caikit.core.exceptions import error_handler -import alog - -# Local -from .embedding_tasks import EmbeddingTask, EmbeddingTasks -from .rerank_task import RerankTask, RerankTasks -from .sentence_similarity_task import SentenceSimilarityTask, SentenceSimilarityTasks -from caikit_nlp.data_model import ( +from caikit.interfaces.common.data_model.vectors import ListOfVector1D, Vector1D +from caikit.interfaces.nlp.data_model import ( EmbeddingResult, - ListOfVector1D, - RerankPredictions, - RerankQueryResult, + EmbeddingResults, + RerankResult, + RerankResults, RerankScore, - SentenceListScores, - SentenceScores, - Vector1D, + RerankScores, + SentenceSimilarityResult, + SentenceSimilarityResults, + SentenceSimilarityScores, +) +from caikit.interfaces.nlp.tasks import ( + EmbeddingTask, + EmbeddingTasks, + RerankTask, + RerankTasks, + SentenceSimilarityTask, + SentenceSimilarityTasks, ) +import alog logger = alog.use_channel("TXT_EMB") error = error_handler.get(logger) @@ -126,16 +131,19 @@ def run_embedding(self, text: str) -> EmbeddingResult: """ error.type_check("", str, text=text) - return EmbeddingResult(Vector1D.from_vector(self.model.encode(text))) + return EmbeddingResult( + result=Vector1D.from_vector(self.model.encode(text)), + producer_id=self.PRODUCER_ID, + ) @EmbeddingTasks.taskmethod() - def run_embeddings(self, texts: List[str]) -> ListOfVector1D: + def run_embeddings(self, texts: List[str]) -> EmbeddingResults: """Get embedding vectors for texts. Args: texts: List[str] List of input texts to be processed Returns: - List[Vector1D]: List vectors. One for each input text (in order). + EmbeddingResults: List of vectors. One for each input text (in order). Each vector is a list of floats (supports various float types). 
""" if isinstance( @@ -144,40 +152,45 @@ def run_embeddings(self, texts: List[str]) -> ListOfVector1D: texts = [texts] embeddings = self.model.encode(texts) - results = [Vector1D.from_vector(e) for e in embeddings] - return ListOfVector1D(results=results) + vectors = [Vector1D.from_vector(e) for e in embeddings] + return EmbeddingResults( + results=ListOfVector1D(vectors=vectors), producer_id=self.PRODUCER_ID + ) @SentenceSimilarityTask.taskmethod() def run_sentence_similarity( self, source_sentence: str, sentences: List[str] - ) -> SentenceScores: + ) -> SentenceSimilarityResult: """Get similarity scores for each of sentences compared to the source_sentence. Args: source_sentence: str sentences: List[str] Sentences to compare to source_sentence Returns: - SentenceScores + SentenceSimilarityResult: Similarity scores for each sentence. """ source_embedding = self.model.encode(source_sentence) embeddings = self.model.encode(sentences) res = cos_sim(source_embedding, embeddings) - return SentenceScores(res.tolist()[0]) + return SentenceSimilarityResult( + result=SentenceSimilarityScores(scores=res.tolist()[0]), + producer_id=self.PRODUCER_ID, + ) @SentenceSimilarityTasks.taskmethod() def run_sentence_similarities( self, source_sentences: List[str], sentences: List[str] - ) -> SentenceListScores: + ) -> SentenceSimilarityResults: """Run sentence-similarities on model. Args: source_sentences: List[str] sentences: List[str] Sentences to compare to source_sentences Returns: - SentenceListScores Similarity scores for each source sentence in order. - each SentenceScores contains the source-sentence's score for each sentence in order. + SentenceSimilarityResults: Similarity scores for each source sentence in order. + Each one contains the source-sentence's score for each sentence in order. """ source_embedding = self.model.encode(source_sentences) @@ -185,8 +198,9 @@ def run_sentence_similarities( res = cos_sim(source_embedding, embeddings) float_list_list = res.tolist() - return SentenceListScores( - results=[SentenceScores(fl) for fl in float_list_list] + return SentenceSimilarityResults( + results=[SentenceSimilarityScores(fl) for fl in float_list_list], + producer_id=self.PRODUCER_ID, ) @RerankTask.taskmethod() @@ -198,7 +212,7 @@ def run_rerank_query( return_documents: bool = True, return_query: bool = True, return_text: bool = True, - ) -> RerankQueryResult: + ) -> RerankResult: """Rerank the documents returning the most relevant top_n in order for this query. Args: query: str @@ -219,7 +233,7 @@ def run_rerank_query( Default True Setting to False will disable returning of document text string that was used. Returns: - RerankQueryResult + RerankResult Returns the (top_n) scores in relevance order (most relevant first). The results always include a score and index which may be used to find the document in the original documents list. 
Optionally, the results also contain the entire @@ -244,9 +258,15 @@ def run_rerank_query( ).results if results: - return results[0] - - RerankQueryResult(scores=[], query=query if return_query else None) + return RerankResult(result=results[0], producer_id=self.PRODUCER_ID) + + RerankResult( + producer_id=self.PRODUCER_ID, + result=RerankScore( + scores=[], + query=query if return_query else None, + ), + ) @RerankTasks.taskmethod() def run_rerank_queries( @@ -257,7 +277,7 @@ def run_rerank_queries( return_documents: bool = True, return_queries: bool = True, return_text: bool = True, - ) -> RerankPredictions: + ) -> RerankResults: """Rerank the documents returning the most relevant top_n in order for each of the queries. Args: queries: List[str] @@ -278,7 +298,7 @@ def run_rerank_queries( Default True Setting to False will disable returning of document text string that was used. Returns: - RerankPredictions + RerankResults For each query in queries (in the original order)... Returns the (top_n) scores in relevance order (most relevant first). The results always include a score and index which may be used to find the document @@ -337,11 +357,14 @@ def add_query(q): return queries[q] if return_queries else None results = [ - RerankQueryResult(query=add_query(q), scores=[RerankScore(**x) for x in r]) + RerankScores( + query=add_query(q), + scores=[RerankScore(**x) for x in r], + ) for q, r in enumerate(res) ] - return RerankPredictions(results=results) + return RerankResults(results=results, producer_id=self.PRODUCER_ID) @classmethod def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule": diff --git a/caikit_nlp/modules/text_embedding/embedding_tasks.py b/caikit_nlp/modules/text_embedding/embedding_tasks.py deleted file mode 100644 index e084b920..00000000 --- a/caikit_nlp/modules/text_embedding/embedding_tasks.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Standard -from typing import List - -# First Party -from caikit.core import TaskBase, task - -# Local -from ...data_model import EmbeddingResult, ListOfVector1D - - -@task( - required_parameters={"text": str}, - output_type=EmbeddingResult, -) -class EmbeddingTask(TaskBase): - """Return a text embedding for the input text string""" - - -@task( - required_parameters={"texts": List[str]}, - output_type=ListOfVector1D, -) -class EmbeddingTasks(TaskBase): - """Return a text embedding for each text string in the input list""" diff --git a/caikit_nlp/modules/text_embedding/rerank_task.py b/caikit_nlp/modules/text_embedding/rerank_task.py deleted file mode 100644 index 4b7d67e5..00000000 --- a/caikit_nlp/modules/text_embedding/rerank_task.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Standard -from typing import List - -# First Party -from caikit.core import TaskBase, task -from caikit.core.data_model.json_dict import JsonDict - -# Local -from caikit_nlp.data_model.reranker import RerankPredictions, RerankQueryResult - - -@task( - required_parameters={ - "documents": List[JsonDict], - "query": str, - }, - output_type=RerankQueryResult, -) -class RerankTask(TaskBase): - """Returns an ordered list ranking the most relevant documents for the query - - Required parameters: - query: The search query - documents: JSON documents containing "text" or alternative "_text" to search - Returns: - The top_n documents in order of relevance (most relevant first). - For each, a score and document index (position in input) is returned. - The original document JSON is returned depending on optional args. - The top_n optional parameter limits the results when used. - """ - - -@task( - required_parameters={ - "documents": List[JsonDict], - "queries": List[str], - }, - output_type=RerankPredictions, -) -class RerankTasks(TaskBase): - """Returns an ordered list for each query ranking the most relevant documents for the query - - Required parameters: - queries: The search queries - documents: JSON documents containing "text" or alternative "_text" to search - Returns: - Results in order of the queries. - In each query result: - The query text is optionally included for visual convenience. - The top_n documents in order of relevance (most relevant first). - For each, a score and document index (position in input) is returned. - The original document JSON is returned depending on optional args. - The top_n optional parameter limits the results when used. - """ diff --git a/caikit_nlp/modules/text_embedding/sentence_similarity_task.py b/caikit_nlp/modules/text_embedding/sentence_similarity_task.py deleted file mode 100644 index 021aec23..00000000 --- a/caikit_nlp/modules/text_embedding/sentence_similarity_task.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Standard -from typing import List - -# First Party -from caikit.core import TaskBase, task - -# Local -from ...data_model import SentenceListScores, SentenceScores - - -@task( - required_parameters={"source_sentence": str, "sentences": List[str]}, - output_type=SentenceScores, -) -class SentenceSimilarityTask(TaskBase): - """Compare the source_sentence to each of the sentences. - Result contains a list of scores in the order of the input sentences. 
- """ - - -@task( - required_parameters={"source_sentences": List[str], "sentences": List[str]}, - output_type=SentenceListScores, -) -class SentenceSimilarityTasks(TaskBase): - """Compare each of the source_sentences to each of the sentences. - Returns a list of results in the order of the source_sentences. - Each result contains a list of scores in the order of the input sentences. - """ diff --git a/tests/data_model/test_embedding_vectors.py b/tests/data_model/test_embedding_vectors.py index 0894041c..7ffc5ad9 100644 --- a/tests/data_model/test_embedding_vectors.py +++ b/tests/data_model/test_embedding_vectors.py @@ -20,8 +20,15 @@ import numpy as np import pytest -# Local -from caikit_nlp import data_model as dm +# First Party +from caikit.interfaces.common.data_model.vectors import ( + ListOfVector1D, + NpFloat32Sequence, + NpFloat64Sequence, + PyFloatSequence, + Vector1D, +) +from caikit.interfaces.nlp import data_model as dm ## Setup ######################################################################### @@ -60,36 +67,36 @@ def random_python_vector1d_float(random_numpy_vector1d_float32): @pytest.mark.parametrize( "sequence", [ - dm.PyFloatSequence(), - dm.NpFloat32Sequence(), - dm.NpFloat64Sequence(), + PyFloatSequence(), + NpFloat32Sequence(), + NpFloat64Sequence(), TRICK_SEQUENCE(values=None), ], ids=type, ) def test_empty_sequences(sequence): """No type check error with empty sequences""" - new_dm_from_init = dm.Vector1D(sequence) + new_dm_from_init = Vector1D(sequence) assert isinstance(new_dm_from_init.data, type(sequence)) assert new_dm_from_init.data.values is None # Test proto proto_from_dm = new_dm_from_init.to_proto() - new_dm_from_proto = dm.Vector1D.from_proto(proto_from_dm) - assert isinstance(new_dm_from_proto, dm.Vector1D) + new_dm_from_proto = Vector1D.from_proto(proto_from_dm) + assert isinstance(new_dm_from_proto, Vector1D) assert new_dm_from_proto.data.values is None # Test json json_from_dm = new_dm_from_init.to_json() - new_dm_from_json = dm.Vector1D.from_json(json_from_dm) - assert isinstance(new_dm_from_json, dm.Vector1D) + new_dm_from_json = Vector1D.from_json(json_from_dm) + assert isinstance(new_dm_from_json, Vector1D) assert new_dm_from_json.data.values == [] def test_vector1d_iterator_error(): """Cannot just shove in an iterator and expect it to work""" with pytest.raises(ValueError): - dm.Vector1D(data=[1.1, 2.2, 3.3]) + Vector1D(data=[1.1, 2.2, 3.3]) def _assert_array_check(new_array, data_values, float_type): @@ -101,9 +108,9 @@ def _assert_array_check(new_array, data_values, float_type): @pytest.mark.parametrize( "float_seq_class, random_values, float_type", [ - (dm.PyFloatSequence, "random_python_vector1d_float", float), - (dm.NpFloat32Sequence, "random_numpy_vector1d_float32", np.float32), - (dm.NpFloat64Sequence, "random_numpy_vector1d_float64", np.float64), + (PyFloatSequence, "random_python_vector1d_float", float), + (NpFloat32Sequence, "random_numpy_vector1d_float32", np.float32), + (NpFloat64Sequence, "random_numpy_vector1d_float64", np.float64), ( TRICK_SEQUENCE, "simple_array_of_floats", @@ -115,17 +122,17 @@ def test_vector1d_dm(float_seq_class, random_values, float_type, request): # Test init fixture_values = request.getfixturevalue(random_values) - dm_init = dm.Vector1D(data=float_seq_class(fixture_values)) + dm_init = Vector1D(data=float_seq_class(fixture_values)) _assert_array_check(dm_init, fixture_values, float_type) # Test proto dm_to_proto = dm_init.to_proto() - dm_from_proto = dm.Vector1D.from_proto(dm_to_proto) + dm_from_proto = 
Vector1D.from_proto(dm_to_proto) _assert_array_check(dm_from_proto, fixture_values, float_type) # Test json dm_to_json = dm_init.to_json() - dm_from_json = dm.Vector1D.from_json(dm_to_json) + dm_from_json = Vector1D.from_json(dm_to_json) _assert_array_check( dm_from_json, fixture_values, float ) # NOTE: always float after json @@ -134,14 +141,14 @@ def test_vector1d_dm(float_seq_class, random_values, float_type, request): @pytest.mark.parametrize( "float_seq_class, random_values, float_type", [ - (dm.PyFloatSequence, "random_python_vector1d_float", float), - (dm.NpFloat32Sequence, "random_numpy_vector1d_float32", np.float32), - (dm.NpFloat64Sequence, "random_numpy_vector1d_float64", np.float64), + (PyFloatSequence, "random_python_vector1d_float", float), + (NpFloat32Sequence, "random_numpy_vector1d_float32", np.float32), + (NpFloat64Sequence, "random_numpy_vector1d_float64", np.float64), ], ) def test_vector1d_dm_from_vector(float_seq_class, random_values, float_type, request): fixture_values = request.getfixturevalue(random_values) - v = dm.Vector1D.from_vector(fixture_values) + v = Vector1D.from_vector(fixture_values) assert isinstance(v.data, float_seq_class) assert isinstance(v.data.values[0], float_type) _assert_array_check(v, fixture_values, float_type) diff --git a/tests/data_model/test_reranker.py b/tests/data_model/test_reranker.py index 754272f3..169ce737 100644 --- a/tests/data_model/test_reranker.py +++ b/tests/data_model/test_reranker.py @@ -21,8 +21,16 @@ # Third Party import pytest -# Local -from caikit_nlp import data_model as dm +# First Party +from caikit.interfaces.nlp.data_model import ( + RerankResult, + RerankResults, + RerankScore, + RerankScores, + SentenceSimilarityResult, + SentenceSimilarityResults, + SentenceSimilarityScores, +) ## Setup ######################################################################### @@ -87,43 +95,57 @@ def input_random_score_3(): @pytest.fixture def input_scores(input_score, input_random_score): - return [dm.RerankScore(**input_score), dm.RerankScore(**input_random_score)] + return [RerankScore(**input_score), RerankScore(**input_random_score)] @pytest.fixture def input_scores2(input_random_score, input_random_score_3): return [ - dm.RerankScore(**input_random_score), - dm.RerankScore(**input_random_score_3), + RerankScore(**input_random_score), + RerankScore(**input_random_score_3), ] @pytest.fixture -def input_result_1(input_scores): - return {"query": "foo", "scores": input_scores} +def rerank_result_1(input_scores): + return {"result": RerankScores(query="foo", scores=input_scores)} @pytest.fixture -def input_result_2(input_scores2): - return {"query": "bar", "scores": input_scores2} +def rerank_result_2(input_scores2): + return {"result": RerankScores(query="bar", scores=input_scores2)} @pytest.fixture -def input_results(input_result_1, input_result_2): - return [ - dm.RerankQueryResult(**input_result_1), - dm.RerankQueryResult(**input_result_2), - ] +def input_sentence_similarity_scores_1(): + return {"scores": [random.uniform(-99999, 99999) for _ in range(10)]} @pytest.fixture -def input_sentence_similarity_scores_1(): +def input_sentence_similarity_scores_2(): return {"scores": [random.uniform(-99999, 99999) for _ in range(10)]} @pytest.fixture -def input_rerank_predictions(input_results): - return {"results": input_results} +def sentence_similarity_result(input_sentence_similarity_scores_1): + return {"result": SentenceSimilarityScores(**input_sentence_similarity_scores_1)} + + +@pytest.fixture +def 
sentence_similarity_results( + input_sentence_similarity_scores_1, input_sentence_similarity_scores_2 +): + return { + "results": [ + SentenceSimilarityScores(**input_sentence_similarity_scores_1), + SentenceSimilarityScores(**input_sentence_similarity_scores_2), + ] + } + + +@pytest.fixture +def rerank_results(rerank_result_1, rerank_result_2): + return {"results": [rerank_result_1["result"], rerank_result_2["result"]]} @pytest.fixture @@ -141,8 +163,8 @@ def input_sentence_similarities_scores( input_sentence_similarity_scores_1, input_sentence_similarity_scores_2 ): return [ - dm.SentenceScores(**input_sentence_similarity_scores_1), - dm.SentenceScores(**input_sentence_similarity_scores_2), + SentenceSimilarityResult(**input_sentence_similarity_scores_1), + SentenceSimilarityResult(**input_sentence_similarity_scores_2), ] @@ -152,12 +174,12 @@ def input_sentence_similarities_scores( @pytest.mark.parametrize( "data_object, inputs", [ - (dm.RerankScore, "input_score"), - (dm.RerankScore, "input_random_score"), - (dm.RerankQueryResult, "input_result_1"), - (dm.RerankPredictions, "input_rerank_predictions"), - (dm.SentenceScores, "input_sentence_similarity_scores_1"), - (dm.SentenceListScores, "input_sentence_list_scores"), + (RerankScore, "input_score"), + (RerankScore, "input_random_score"), + (RerankResult, "rerank_result_1"), + (RerankResults, "rerank_results"), + (SentenceSimilarityResult, "sentence_similarity_result"), + (SentenceSimilarityResults, "sentence_similarity_results"), ], ) def test_data_object(data_object, inputs, request): diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index dcfd007c..6dd25ff8 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -12,16 +12,16 @@ # First Party from caikit.core import ModuleConfig - -# Local -from caikit_nlp.data_model import ( +from caikit.interfaces.common.data_model.vectors import ListOfVector1D +from caikit.interfaces.nlp.data_model import ( EmbeddingResult, - ListOfVector1D, - RerankPredictions, - RerankQueryResult, + RerankResult, + RerankResults, RerankScore, - Vector1D, + RerankScores, ) + +# Local from caikit_nlp.modules.text_embedding import EmbeddingModule from tests.fixtures import SEQ_CLASS_MODEL @@ -89,8 +89,7 @@ def _assert_is_expected_embedding_result(actual): def _assert_is_expected_embeddings_results(actual): assert isinstance(actual, ListOfVector1D) - vectors = actual.results - _assert_is_expected_vector(vectors[0]) + _assert_is_expected_vector(actual.vectors[0]) def test_bootstrap(): @@ -210,15 +209,15 @@ def test_run_embeddings_str_type(): """Supposed to be a list, gets fixed automatically.""" model = BOOTSTRAPPED_MODEL res = model.run_embeddings(texts=INPUT) - assert isinstance(res.results, list) - assert len(res.results) == 1 + assert isinstance(res.results.vectors, list) + assert len(res.results.vectors) == 1 def test_run_embeddings(): model = BOOTSTRAPPED_MODEL - results = model.run_embeddings(texts=[INPUT]) - assert isinstance(results.results, list) - _assert_is_expected_embeddings_results(results) + res = model.run_embeddings(texts=[INPUT]) + assert isinstance(res.results.vectors, list) + _assert_is_expected_embeddings_results(res.results) @pytest.mark.parametrize( @@ -256,8 +255,8 @@ def test_run_rerank_query_no_type_error(): ) def test_run_rerank_query_top_n(top_n, expected): res = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n) - assert isinstance(res, 
RerankQueryResult) - assert len(res.scores) == expected + assert isinstance(res, RerankResult) + assert len(res.result.scores) == expected def test_run_rerank_query_no_query(): @@ -272,10 +271,10 @@ def test_run_rerank_query_zero_docs(): def test_run_rerank_query(): - result = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS) - assert isinstance(result, RerankQueryResult) + res = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS) + assert isinstance(res, RerankResult) - scores = result.scores + scores = res.result.scores assert isinstance(scores, list) assert len(scores) == len(DOCS) @@ -314,7 +313,7 @@ def test_run_rerank_queries_top_n(top_n, expected): res = BOOTSTRAPPED_MODEL.run_rerank_queries( queries=QUERIES, documents=DOCS, top_n=top_n ) - assert isinstance(res, RerankPredictions) + assert isinstance(res, RerankResults) assert len(res.results) == len(QUERIES) for result in res.results: assert len(result.scores) == expected @@ -341,7 +340,7 @@ def test_run_rerank_queries(): rerank_result = BOOTSTRAPPED_MODEL.run_rerank_queries( queries=QUERIES, documents=DOCS, top_n=top_n ) - assert isinstance(rerank_result, RerankPredictions) + assert isinstance(rerank_result, RerankResults) results = rerank_result.results assert isinstance(results, list) @@ -350,7 +349,7 @@ def test_run_rerank_queries(): types_found = {} # Gather the type tests from any of the results for result in results: - assert isinstance(result, RerankQueryResult) + assert isinstance(result, RerankScores) scores = result.scores assert isinstance(scores, list) assert len(scores) == top_n @@ -362,8 +361,8 @@ def test_run_rerank_queries(): def test_run_sentence_similarity(): model = BOOTSTRAPPED_MODEL - result = model.run_sentence_similarity(source_sentence=QUERY, sentences=SENTENCES) - scores = result.scores + res = model.run_sentence_similarity(source_sentence=QUERY, sentences=SENTENCES) + scores = res.result.scores assert len(scores) == len(SENTENCES) for score in scores: assert isinstance(score, float) diff --git a/tox.ini b/tox.ini index 586bcd54..d336b2a5 100644 --- a/tox.ini +++ b/tox.ini @@ -15,7 +15,6 @@ passenv = LOG_FORMATTER LOG_THREAD_ID LOG_CHANNEL_WIDTH -commands = pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests} commands = python -I -m pip install --force-reinstall --no-deps -r sentence-transformers.nodeps.txt pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests} From 4cab37b866cf1d02003771490839725e849d1f94 Mon Sep 17 00:00:00 2001 From: markstur Date: Mon, 20 Nov 2023 10:28:26 -0800 Subject: [PATCH 15/17] TextEmbedding: Remove data model tests that were moved to caikit interfaces Signed-off-by: markstur --- tests/data_model/test_embedding_vectors.py | 154 --------------- tests/data_model/test_reranker.py | 208 --------------------- 2 files changed, 362 deletions(-) delete mode 100644 tests/data_model/test_embedding_vectors.py delete mode 100644 tests/data_model/test_reranker.py diff --git a/tests/data_model/test_embedding_vectors.py b/tests/data_model/test_embedding_vectors.py deleted file mode 100644 index 7ffc5ad9..00000000 --- a/tests/data_model/test_embedding_vectors.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Test for embedding vectors -""" -# Standard -from collections import namedtuple - -# Third Party -import numpy as np -import pytest - -# First Party -from caikit.interfaces.common.data_model.vectors import ( - ListOfVector1D, - NpFloat32Sequence, - NpFloat64Sequence, - PyFloatSequence, - Vector1D, -) -from caikit.interfaces.nlp import data_model as dm - -## Setup ######################################################################### - -DUMMY_VECTOR_SHAPE = (5,) -RANDOM_SEED = 77 -np.random.seed(RANDOM_SEED) -random_number_generator = np.random.default_rng() - -# To tests the limits of our type-checking, this can replace our legit data objects -TRICK_SEQUENCE = namedtuple("Trick", "values") - - -@pytest.fixture -def simple_array_of_floats(): - return [1.1, 2.2] - - -@pytest.fixture -def random_numpy_vector1d_float32(): - return random_number_generator.random(DUMMY_VECTOR_SHAPE, dtype=np.float32) - - -@pytest.fixture -def random_numpy_vector1d_float64(): - return random_number_generator.random(DUMMY_VECTOR_SHAPE, dtype=np.float64) - - -@pytest.fixture -def random_python_vector1d_float(random_numpy_vector1d_float32): - return random_numpy_vector1d_float32.tolist() - - -## Tests ######################################################################## - - -@pytest.mark.parametrize( - "sequence", - [ - PyFloatSequence(), - NpFloat32Sequence(), - NpFloat64Sequence(), - TRICK_SEQUENCE(values=None), - ], - ids=type, -) -def test_empty_sequences(sequence): - """No type check error with empty sequences""" - new_dm_from_init = Vector1D(sequence) - assert isinstance(new_dm_from_init.data, type(sequence)) - assert new_dm_from_init.data.values is None - - # Test proto - proto_from_dm = new_dm_from_init.to_proto() - new_dm_from_proto = Vector1D.from_proto(proto_from_dm) - assert isinstance(new_dm_from_proto, Vector1D) - assert new_dm_from_proto.data.values is None - - # Test json - json_from_dm = new_dm_from_init.to_json() - new_dm_from_json = Vector1D.from_json(json_from_dm) - assert isinstance(new_dm_from_json, Vector1D) - assert new_dm_from_json.data.values == [] - - -def test_vector1d_iterator_error(): - """Cannot just shove in an iterator and expect it to work""" - with pytest.raises(ValueError): - Vector1D(data=[1.1, 2.2, 3.3]) - - -def _assert_array_check(new_array, data_values, float_type): - for value in new_array.data.values: - assert isinstance(value, float_type) - np.testing.assert_array_equal(new_array.data.values, data_values) - - -@pytest.mark.parametrize( - "float_seq_class, random_values, float_type", - [ - (PyFloatSequence, "random_python_vector1d_float", float), - (NpFloat32Sequence, "random_numpy_vector1d_float32", np.float32), - (NpFloat64Sequence, "random_numpy_vector1d_float64", np.float64), - ( - TRICK_SEQUENCE, - "simple_array_of_floats", - float, - ), # Sneaky but tests corner cases for now - ], -) -def test_vector1d_dm(float_seq_class, random_values, float_type, request): - - # Test init - fixture_values = request.getfixturevalue(random_values) - dm_init = Vector1D(data=float_seq_class(fixture_values)) - _assert_array_check(dm_init, fixture_values, float_type) - - 
# Test proto - dm_to_proto = dm_init.to_proto() - dm_from_proto = Vector1D.from_proto(dm_to_proto) - _assert_array_check(dm_from_proto, fixture_values, float_type) - - # Test json - dm_to_json = dm_init.to_json() - dm_from_json = Vector1D.from_json(dm_to_json) - _assert_array_check( - dm_from_json, fixture_values, float - ) # NOTE: always float after json - - -@pytest.mark.parametrize( - "float_seq_class, random_values, float_type", - [ - (PyFloatSequence, "random_python_vector1d_float", float), - (NpFloat32Sequence, "random_numpy_vector1d_float32", np.float32), - (NpFloat64Sequence, "random_numpy_vector1d_float64", np.float64), - ], -) -def test_vector1d_dm_from_vector(float_seq_class, random_values, float_type, request): - fixture_values = request.getfixturevalue(random_values) - v = Vector1D.from_vector(fixture_values) - assert isinstance(v.data, float_seq_class) - assert isinstance(v.data.values[0], float_type) - _assert_array_check(v, fixture_values, float_type) diff --git a/tests/data_model/test_reranker.py b/tests/data_model/test_reranker.py deleted file mode 100644 index 169ce737..00000000 --- a/tests/data_model/test_reranker.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Test for reranker -""" - -# Standard -import random -import string - -# Third Party -import pytest - -# First Party -from caikit.interfaces.nlp.data_model import ( - RerankResult, - RerankResults, - RerankScore, - RerankScores, - SentenceSimilarityResult, - SentenceSimilarityResults, - SentenceSimilarityScores, -) - -## Setup ######################################################################### - - -@pytest.fixture -def input_document(): - return { - "text": "this is the input text", - "_text": "alternate _text here", - "title": "some title attribute here", - "anything": "another string attribute", - "str_test": "test string", - "int_test": 1234, - "float_test": 9876.4321, - } - - -@pytest.fixture -def input_random_document(): - return { - "text": "".join(random.choices(string.printable, k=100)), - "random_str": "".join(random.choices(string.printable, k=100)), - "random_int": random.randint(-99999, 99999), - "random_float": random.uniform(-99999, 99999), - } - - -@pytest.fixture -def input_documents(input_document, input_random_document): - return [input_document, input_random_document] - - -@pytest.fixture -def input_score(input_document): - return { - "document": input_document, - "index": 1234, - "score": 9876.54321, - "text": "this is the input text", - } - - -@pytest.fixture -def input_random_score(input_random_document): - return { - "document": input_random_document, - "index": random.randint(-99999, 99999), - "score": random.uniform(-99999, 99999), - "text": "".join(random.choices(string.printable, k=100)), - } - - -@pytest.fixture -def input_random_score_3(): - return { - "document": {"text": "random foo3"}, - "index": random.randint(-99999, 99999), - "score": random.uniform(-99999, 99999), - "text": 
"".join(random.choices(string.printable, k=100)), - } - - -@pytest.fixture -def input_scores(input_score, input_random_score): - return [RerankScore(**input_score), RerankScore(**input_random_score)] - - -@pytest.fixture -def input_scores2(input_random_score, input_random_score_3): - return [ - RerankScore(**input_random_score), - RerankScore(**input_random_score_3), - ] - - -@pytest.fixture -def rerank_result_1(input_scores): - return {"result": RerankScores(query="foo", scores=input_scores)} - - -@pytest.fixture -def rerank_result_2(input_scores2): - return {"result": RerankScores(query="bar", scores=input_scores2)} - - -@pytest.fixture -def input_sentence_similarity_scores_1(): - return {"scores": [random.uniform(-99999, 99999) for _ in range(10)]} - - -@pytest.fixture -def input_sentence_similarity_scores_2(): - return {"scores": [random.uniform(-99999, 99999) for _ in range(10)]} - - -@pytest.fixture -def sentence_similarity_result(input_sentence_similarity_scores_1): - return {"result": SentenceSimilarityScores(**input_sentence_similarity_scores_1)} - - -@pytest.fixture -def sentence_similarity_results( - input_sentence_similarity_scores_1, input_sentence_similarity_scores_2 -): - return { - "results": [ - SentenceSimilarityScores(**input_sentence_similarity_scores_1), - SentenceSimilarityScores(**input_sentence_similarity_scores_2), - ] - } - - -@pytest.fixture -def rerank_results(rerank_result_1, rerank_result_2): - return {"results": [rerank_result_1["result"], rerank_result_2["result"]]} - - -@pytest.fixture -def input_sentence_list_scores(input_sentence_similarities_scores): - return {"results": input_sentence_similarities_scores} - - -@pytest.fixture -def input_sentence_similarity_scores_2(): - return {"scores": [random.uniform(-99999, 99999) for _ in range(10)]} - - -@pytest.fixture -def input_sentence_similarities_scores( - input_sentence_similarity_scores_1, input_sentence_similarity_scores_2 -): - return [ - SentenceSimilarityResult(**input_sentence_similarity_scores_1), - SentenceSimilarityResult(**input_sentence_similarity_scores_2), - ] - - -## Tests ######################################################################## - - -@pytest.mark.parametrize( - "data_object, inputs", - [ - (RerankScore, "input_score"), - (RerankScore, "input_random_score"), - (RerankResult, "rerank_result_1"), - (RerankResults, "rerank_results"), - (SentenceSimilarityResult, "sentence_similarity_result"), - (SentenceSimilarityResults, "sentence_similarity_results"), - ], -) -def test_data_object(data_object, inputs, request): - # Init data object - fixture_values = request.getfixturevalue(inputs) - new_do_from_init = data_object(**fixture_values) - assert isinstance(new_do_from_init, data_object) - assert_fields_match(new_do_from_init, fixture_values) - - # Test to/from proto - proto_from_dm = new_do_from_init.to_proto() - new_do_from_proto = data_object.from_proto(proto_from_dm) - assert isinstance(new_do_from_proto, data_object) - assert_fields_match(new_do_from_proto, fixture_values) - assert new_do_from_init == new_do_from_proto - - # Test to/from json - json_from_dm = new_do_from_init.to_json() - new_do_from_json = data_object.from_json(json_from_dm) - assert isinstance(new_do_from_json, data_object) - assert_fields_match(new_do_from_json, fixture_values) - assert new_do_from_init == new_do_from_json - - -def assert_fields_match(data_object, inputs): - assert all(getattr(data_object, key) == value for key, value in inputs.items()) From 30f2faa228312a21476b29f01e2d00729601a584 Mon Sep 17 
00:00:00 2001 From: markstur Date: Mon, 20 Nov 2023 11:34:41 -0800 Subject: [PATCH 16/17] Text Embedding: Require caikit 0.25 or later Signed-off-by: markstur --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e99ce0b0..616923bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers=[ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "caikit[runtime-grpc,runtime-http]>=0.24.0,<0.25.0", + "caikit[runtime-grpc,runtime-http]>=0.25.0,<0.26.0", "caikit-tgis-backend>=0.1.17,<0.2.0", # TODO: loosen dependencies "accelerate>=0.22.0", From 6faabe28eb30fa2f4c81c4b8da8bd7ba0de99321 Mon Sep 17 00:00:00 2001 From: markstur Date: Mon, 20 Nov 2023 15:01:54 -0800 Subject: [PATCH 17/17] Bump caikit-tgis-backend to new minimum avoiding conflicts Signed-off-by: markstur --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 616923bd..7df014ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers=[ ] dependencies = [ "caikit[runtime-grpc,runtime-http]>=0.25.0,<0.26.0", - "caikit-tgis-backend>=0.1.17,<0.2.0", + "caikit-tgis-backend>=0.1.25,<0.2.0", # TODO: loosen dependencies "accelerate>=0.22.0", "datasets>=2.4.0",