diff --git a/README.md b/README.md
index aab33851..509485ae 100644
--- a/README.md
+++ b/README.md
@@ -8,13 +8,15 @@ Caikit-NLP implements concept of "task" from `caikit` framework to define (and c
 Capabilities provided by `caikit-nlp`:
 
-| Task                 | Module(s)                                    | Salient Feature(s)                                                                                                                                                |
-|----------------------|----------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| Text Generation      | 1. `PeftPromptTuning`<br>2. `TextGeneration` | 1. Prompt Tuning, Multi-task Prompt tuning<br>2. Fine-tuning<br>Both modules above provide optimized inference capability using Text Generation Inference Server |
-| Text Classification  | 1. `SequenceClassification`                  | 1. (Work in progress..)                                                                                                                                           |
-| Token Classification | 1. `FilteredSpanClassification`              | 1. (Work in progress..)                                                                                                                                           |
-| Tokenization         | 1. `RegexSentenceSplitter`                   | 1. Demo purposes only                                                                                                                                             |
-| Embedding            | [COMING SOON]                                | [COMING SOON]                                                                                                                                                     |
+| Task                                               | Module(s)                                    | Salient Feature(s)                                                                                                                                                                                                                                                                                   |
+|----------------------------------------------------|----------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| TextGenerationTask                                 | 1. `PeftPromptTuning`<br>2. `TextGeneration` | 1. Prompt Tuning, Multi-task Prompt tuning<br>2. Fine-tuning<br>Both modules above provide optimized inference capability using Text Generation Inference Server                                                                                                                                      |
+| TextClassificationTask                             | 1. `SequenceClassification`                  | 1. (Work in progress)                                                                                                                                                                                                                                                                                 |
+| TokenClassificationTask                            | 1. `FilteredSpanClassification`              | 1. (Work in progress)                                                                                                                                                                                                                                                                                 |
+| TokenizationTask                                   | 1. `RegexSentenceSplitter`                   | 1. Demo purposes only                                                                                                                                                                                                                                                                                 |
+| EmbeddingTask<br>EmbeddingTasks                    | 1. `TextEmbedding`                           | 1. TextEmbedding returns a text embedding vector from a local sentence-transformers model.<br>2. EmbeddingTasks takes multiple input texts and returns a corresponding list of vectors.                                                                                                              |
+| SentenceSimilarityTask<br>SentenceSimilarityTasks  | 1. `TextEmbedding`                           | 1. SentenceSimilarityTask compares one source_sentence to a list of sentences and returns similarity scores in the order of the sentences.<br>2. SentenceSimilarityTasks takes a list of source_sentences (each compared to the same list of sentences) and returns corresponding lists of outputs. |
+| RerankTask<br>RerankTasks                          | 1. `TextEmbedding`                           | 1. RerankTask compares a query to a list of documents and returns top_n scores in order of relevance, with indexes to the source documents and, optionally, the documents themselves.<br>
2. RerankTasks takes multiple queries as input and returns a corresponding list of outputs. The same list of documents is used for all queries. | ## Getting Started diff --git a/caikit_nlp/data_model/__init__.py b/caikit_nlp/data_model/__init__.py index dcbf0585..6826b631 100644 --- a/caikit_nlp/data_model/__init__.py +++ b/caikit_nlp/data_model/__init__.py @@ -15,6 +15,5 @@ """ # Local -from . import embedding_vectors, generation -from .embedding_vectors import * +from . import generation from .generation import * diff --git a/caikit_nlp/data_model/embedding_vectors.py b/caikit_nlp/data_model/embedding_vectors.py deleted file mode 100644 index 67f3d5b3..00000000 --- a/caikit_nlp/data_model/embedding_vectors.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Data structures for embedding vector representations -""" -# Standard -from dataclasses import dataclass, field -from typing import List, Union -import json - -# Third Party -from google.protobuf import json_format -import numpy as np - -# First Party -from caikit.core import DataObjectBase, dataobject -from caikit.core.exceptions import error_handler -import alog - -log = alog.use_channel("DATAM") -error = error_handler.get(log) - - -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass -class PyFloatSequence(DataObjectBase): - values: List[float] = field(default_factory=list) - - -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass -class NpFloat32Sequence(DataObjectBase): - values: List[np.float32] - - @classmethod - def from_proto(cls, proto): - values = np.asarray(proto.values, dtype=np.float32) - return cls(values) - - -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass -class NpFloat64Sequence(DataObjectBase): - values: List[np.float64] - - @classmethod - def from_proto(cls, proto): - values = np.asarray(proto.values, dtype=np.float64) - return cls(values) - - -@dataobject(package="caikit_data_model.caikit_nlp") -@dataclass -class Vector1D(DataObjectBase): - """Data representation for a 1 dimension vector of float-type data.""" - - data: Union[ - PyFloatSequence, - NpFloat32Sequence, - NpFloat64Sequence, - ] - - def __post_init__(self): - error.value_check( - "", - hasattr(self.data, "values"), - ValueError("Vector1D requires a float sequence data object with values."), - ) - - @classmethod - def from_vector(cls, vector): - if vector.dtype == np.float32: - data = NpFloat32Sequence(vector) - elif vector.dtype == np.float64: - data = NpFloat64Sequence(vector) - else: - data = PyFloatSequence(vector) - return cls(data=data) - - @classmethod - def from_json(cls, json_str): - """JSON does not have different float types. 
Move data into data_pyfloatsequence"""
-
-        json_obj = json.loads(json_str) if isinstance(json_str, str) else json_str
-        data = json_obj.pop("data")
-        if data is not None:
-            json_obj["data_pyfloatsequence"] = data
-
-        json_str = json.dumps(json_obj)
-        try:
-            # Parse given JSON into google.protobufs.pyext.cpp_message.GeneratedProtocolMessageType
-            parsed_proto = json_format.Parse(
-                json_str, cls.get_proto_class()(), ignore_unknown_fields=False
-            )
-
-            # Use from_proto to return the DataBase object from the parsed proto
-            return cls.from_proto(parsed_proto)
-
-        except json_format.ParseError as ex:
-            error("", ValueError(ex))
-
-    def to_dict(self) -> dict:
-        """to_dict is needed to make things serializable"""
-        values = self.data.values if self.data.values is not None else []
-        return {
-            "data": {
-                # coerce numpy.ndarray and numpy.float32 into JSON serializable list of floats
-                "values": values.tolist()
-                if isinstance(values, np.ndarray)
-                else values
-            }
-        }
-
-    @classmethod
-    def from_proto(cls, proto):
-        """Wrap the data in an appropriate float sequence, wrapped by this class"""
-        woo = proto.WhichOneof("data")
-        if woo is None:
-            return cls(PyFloatSequence())
-
-        woo_data = getattr(proto, woo)
-        if woo == "data_npfloat64sequence":
-            ret = cls(NpFloat64Sequence.from_proto(woo_data))
-        elif woo == "data_npfloat32sequence":
-            ret = cls(NpFloat32Sequence.from_proto(woo_data))
-        else:
-            ret = cls(PyFloatSequence.from_proto(woo_data))
-        return ret
-
-    def fill_proto(self, proto):
-        """Fill in the data in an appropriate data_"""
-        values = self.data.values
-        if values is not None and len(values) > 0:
-            sample = values[0]
-            error.type_check(
-                "", float, np.float32, np.float64, sample=sample
-            )
-            if isinstance(sample, np.float64):
-                proto.data_npfloat64sequence.values.extend(values)
-            elif isinstance(sample, np.float32):
-                proto.data_npfloat32sequence.values.extend(values)
-            else:
-                proto.data_pyfloatsequence.values.extend(values)
-
-        return proto
-
-
-@dataobject(package="caikit_data_model.caikit_nlp")
-@dataclass
-class EmbeddingResult(DataObjectBase):
-    """Result from text embedding task"""
-
-    result: Vector1D
diff --git a/caikit_nlp/modules/text_embedding/__init__.py b/caikit_nlp/modules/text_embedding/__init__.py
index f56694f6..2451f4a2 100644
--- a/caikit_nlp/modules/text_embedding/__init__.py
+++ b/caikit_nlp/modules/text_embedding/__init__.py
@@ -12,6 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+Text Embedding Module
+=====================
+
+Implements the following tasks:
+
+    1. EmbeddingTask: Returns an embedding from an input text string
+    2. EmbeddingTasks: EmbeddingTask but with a list of inputs producing a list of outputs
+    3. SentenceSimilarityTask: Compare one source sentence to a list of sentences
+    4. SentenceSimilarityTasks: SentenceSimilarityTask but with a list of source sentences producing
+       a list of outputs
+    5. RerankTask: Return top_n documents ordered by relevance given a query
+    6. RerankTasks: RerankTask but with a list of queries producing a list of outputs
+
+"""
+
 # Local
 from .embedding import EmbeddingModule
-from .embedding_tasks import EmbeddingTask
diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py
index ea415048..f64e185e 100644
--- a/caikit_nlp/modules/text_embedding/embedding.py
+++ b/caikit_nlp/modules/text_embedding/embedding.py
@@ -13,29 +13,73 @@
 # limitations under the License.
 
# Standard +from typing import List, Optional +import importlib import os -# Third Party -from sentence_transformers import SentenceTransformer - # First Party from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module +from caikit.core.data_model.json_dict import JsonDict from caikit.core.exceptions import error_handler +from caikit.interfaces.common.data_model.vectors import ListOfVector1D, Vector1D +from caikit.interfaces.nlp.data_model import ( + EmbeddingResult, + EmbeddingResults, + RerankResult, + RerankResults, + RerankScore, + RerankScores, + SentenceSimilarityResult, + SentenceSimilarityResults, + SentenceSimilarityScores, +) +from caikit.interfaces.nlp.tasks import ( + EmbeddingTask, + EmbeddingTasks, + RerankTask, + RerankTasks, + SentenceSimilarityTask, + SentenceSimilarityTasks, +) import alog -# Local -from .embedding_tasks import EmbeddingTask -from caikit_nlp.data_model.embedding_vectors import EmbeddingResult, Vector1D - logger = alog.use_channel("TXT_EMB") error = error_handler.get(logger) +# To avoid dependency problems, make sentence-transformers an optional import and +# defer any ModuleNotFoundError until someone actually tries to init a model with this module. +try: + sentence_transformers = importlib.import_module("sentence_transformers") + # Third Party + from sentence_transformers import SentenceTransformer + from sentence_transformers.util import ( + cos_sim, + dot_score, + normalize_embeddings, + semantic_search, + ) +except ModuleNotFoundError: + # When it is not available, create a dummy that raises an error on attempted init() + class SentenceTransformerNotAvailable: + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + # Will reproduce the ModuleNotFoundError if/when anyone actually tries this module/model + importlib.import_module("sentence_transformers") + + SentenceTransformer = SentenceTransformerNotAvailable + @module( "eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f", "EmbeddingModule", "0.0.1", - EmbeddingTask, + tasks=[ + EmbeddingTask, + EmbeddingTasks, + SentenceSimilarityTask, + SentenceSimilarityTasks, + RerankTask, + RerankTasks, + ], ) class EmbeddingModule(ModuleBase): @@ -76,19 +120,251 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": return cls.bootstrap(model_name_or_path=artifacts_path) - def run( - self, input: str, **kwargs # pylint: disable=redefined-builtin - ) -> EmbeddingResult: - """Run inference on model. + @EmbeddingTask.taskmethod() + def run_embedding(self, text: str) -> EmbeddingResult: + """Get embedding for a string. Args: - input: str + text: str Input text to be processed Returns: EmbeddingResult: the result vector nicely wrapped up """ - error.type_check("", str, input=input) + error.type_check("", str, text=text) - return EmbeddingResult(Vector1D.from_vector(self.model.encode(input))) + return EmbeddingResult( + result=Vector1D.from_vector(self.model.encode(text)), + producer_id=self.PRODUCER_ID, + ) + + @EmbeddingTasks.taskmethod() + def run_embeddings(self, texts: List[str]) -> EmbeddingResults: + """Get embedding vectors for texts. + Args: + texts: List[str] + List of input texts to be processed + Returns: + EmbeddingResults: List of vectors. One for each input text (in order). + Each vector is a list of floats (supports various float types). 
+ """ + if isinstance( + texts, str + ): # encode allows str, but the result would lack a dimension + texts = [texts] + + embeddings = self.model.encode(texts) + vectors = [Vector1D.from_vector(e) for e in embeddings] + return EmbeddingResults( + results=ListOfVector1D(vectors=vectors), producer_id=self.PRODUCER_ID + ) + + @SentenceSimilarityTask.taskmethod() + def run_sentence_similarity( + self, source_sentence: str, sentences: List[str] + ) -> SentenceSimilarityResult: + """Get similarity scores for each of sentences compared to the source_sentence. + Args: + source_sentence: str + sentences: List[str] + Sentences to compare to source_sentence + Returns: + SentenceSimilarityResult: Similarity scores for each sentence. + """ + + source_embedding = self.model.encode(source_sentence) + embeddings = self.model.encode(sentences) + + res = cos_sim(source_embedding, embeddings) + return SentenceSimilarityResult( + result=SentenceSimilarityScores(scores=res.tolist()[0]), + producer_id=self.PRODUCER_ID, + ) + + @SentenceSimilarityTasks.taskmethod() + def run_sentence_similarities( + self, source_sentences: List[str], sentences: List[str] + ) -> SentenceSimilarityResults: + """Run sentence-similarities on model. + Args: + source_sentences: List[str] + sentences: List[str] + Sentences to compare to source_sentences + Returns: + SentenceSimilarityResults: Similarity scores for each source sentence in order. + Each one contains the source-sentence's score for each sentence in order. + """ + + source_embedding = self.model.encode(source_sentences) + embeddings = self.model.encode(sentences) + + res = cos_sim(source_embedding, embeddings) + float_list_list = res.tolist() + return SentenceSimilarityResults( + results=[SentenceSimilarityScores(fl) for fl in float_list_list], + producer_id=self.PRODUCER_ID, + ) + + @RerankTask.taskmethod() + def run_rerank_query( + self, + query: str, + documents: List[JsonDict], + top_n: Optional[int] = None, + return_documents: bool = True, + return_query: bool = True, + return_text: bool = True, + ) -> RerankResult: + """Rerank the documents returning the most relevant top_n in order for this query. + Args: + query: str + Query is the source string to be compared to the text of the documents. + documents: List[JsonDict] + Each document is a dict. The text value is used for comparison to the query. + If there is no text key, then _text is used and finally default is "". + top_n: Optional[int] + Results for the top n most relevant documents will be returned. + If top_n is not provided or (not > 0), then all are returned. + return_documents: bool + Default True + Setting to False will disable returning of the input document (index is returned). + return_query: bool + Default True + Setting to False will disable returning of the query (results are in query order) + return_text: bool + Default True + Setting to False will disable returning of document text string that was used. + Returns: + RerankResult + Returns the (top_n) scores in relevance order (most relevant first). + The results always include a score and index which may be used to find the document + in the original documents list. Optionally, the results also contain the entire + document with its score (for use in chaining) and for convenience the query and + text used for comparison may be returned. 
+
+        """
+
+        error.type_check(
+            "",
+            str,
+            query=query,
+        )
+
+        results = self.run_rerank_queries(
+            queries=[query],
+            documents=documents,
+            top_n=top_n,
+            return_documents=return_documents,
+            return_queries=return_query,
+            return_text=return_text,
+        ).results
+
+        if results:
+            return RerankResult(result=results[0], producer_id=self.PRODUCER_ID)
+
+        # Empty results: return an empty scores list for this query
+        return RerankResult(
+            producer_id=self.PRODUCER_ID,
+            result=RerankScores(
+                scores=[],
+                query=query if return_query else None,
+            ),
+        )
+
+    @RerankTasks.taskmethod()
+    def run_rerank_queries(
+        self,
+        queries: List[str],
+        documents: List[JsonDict],
+        top_n: Optional[int] = None,
+        return_documents: bool = True,
+        return_queries: bool = True,
+        return_text: bool = True,
+    ) -> RerankResults:
+        """Rerank the documents, returning the most relevant top_n in order for each of the queries.
+        Args:
+            queries: List[str]
+                Each of the queries will be compared to the text of each of the documents.
+            documents: List[JsonDict]
+                Each document is a dict. The text value is used for comparison to the query.
+                If there is no text key, then _text is used and finally default is "".
+            top_n: Optional[int]
+                Results for the top n most relevant documents will be returned.
+                If top_n is not provided or not greater than 0, then all are returned.
+            return_documents: bool
+                Default True
+                Setting to False will disable returning of the input document (index is returned).
+            return_queries: bool
+                Default True
+                Setting to False will disable returning of the query (results are in query order).
+            return_text: bool
+                Default True
+                Setting to False will disable returning of document text string that was used.
+        Returns:
+            RerankResults
+                For each query in queries (in the original order)...
+                Returns the (top_n) scores in relevance order (most relevant first).
+                The results always include a score and index which may be used to find the document
+                in the original documents list. Optionally, the results also contain the entire
+                document with its score (for use in chaining) and for convenience the query and
+                text used for comparison may be returned.
+ """ + + error.type_check( + "", + list, + queries=queries, + documents=documents, + ) + + error.value_check( + "", + queries and documents, + "Cannot rerank without a query and at least one document", + ) + + if top_n is None or top_n < 1: + top_n = len(documents) + + # Using input document dicts so get "text" else "_text" else default to "" + def get_text(doc): + return doc.get("text") or doc.get("_text", "") + + doc_texts = [get_text(doc) for doc in documents] + + doc_embeddings = normalize_embeddings( + self.model.encode(doc_texts, convert_to_tensor=True).to(self.model.device) + ) + + query_embeddings = normalize_embeddings( + self.model.encode(queries, convert_to_tensor=True).to(self.model.device) + ) + + res = semantic_search( + query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score + ) + + # Fixup result dicts + for r in res: + for x in r: + # Renaming corpus_id to index + corpus_id = x.pop("corpus_id") + x["index"] = corpus_id + # Optionally adding the original document and/or just the text that was used + if return_documents: + x["document"] = documents[corpus_id] + if return_text: + x["text"] = get_text(documents[corpus_id]) + + def add_query(q): + return queries[q] if return_queries else None + + results = [ + RerankScores( + query=add_query(q), + scores=[RerankScore(**x) for x in r], + ) + for q, r in enumerate(res) + ] + + return RerankResults(results=results, producer_id=self.PRODUCER_ID) @classmethod def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule": @@ -108,12 +384,11 @@ def save(self, model_path: str, *args, **kwargs): Path to model config """ - model_config_path = model_path # because the param name is misleading - - error.type_check("", str, model_path=model_config_path) + error.type_check("", str, model_path=model_path) + model_config_path = model_path.strip() error.value_check( "", - model_config_path is not None and model_config_path.strip(), + model_config_path, f"model_path '{model_config_path}' is invalid", ) diff --git a/caikit_nlp/modules/text_embedding/embedding_tasks.py b/caikit_nlp/modules/text_embedding/embedding_tasks.py deleted file mode 100644 index 07b31dba..00000000 --- a/caikit_nlp/modules/text_embedding/embedding_tasks.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Standard - -# First Party -from caikit.core import TaskBase, task - -# Local -from ...data_model import EmbeddingResult - - -@task( - required_parameters={"input": str}, - output_type=EmbeddingResult, -) -class EmbeddingTask(TaskBase): - pass diff --git a/pyproject.toml b/pyproject.toml index ce5200dc..7df014ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,8 +14,8 @@ classifiers=[ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "caikit[runtime-grpc,runtime-http]>=0.24.0,<0.25.0", - "caikit-tgis-backend>=0.1.17,<0.2.0", + "caikit[runtime-grpc,runtime-http]>=0.25.0,<0.26.0", + "caikit-tgis-backend>=0.1.25,<0.2.0", # TODO: loosen dependencies "accelerate>=0.22.0", "datasets>=2.4.0", @@ -24,7 +24,6 @@ dependencies = [ "pandas>=1.5.0", "scikit-learn>=1.1", "scipy>=1.8.1", - "sentence-transformers~=2.2.2", "tokenizers>=0.13.3", "torch>=2.0.1", "tqdm>=4.65.0", diff --git a/sentence-transformers.nodeps.txt b/sentence-transformers.nodeps.txt new file mode 100644 index 00000000..dad81257 --- /dev/null +++ b/sentence-transformers.nodeps.txt @@ -0,0 +1,7 @@ +# These can be installed with --no-deps. + +# Minimum needed to use sentence-transformers: +sentence-transformers>=2.2.2,<2.3.0 +nltk>=3.8.1,<3.9.0 +Pillow>=10.0.0,<10.1.0 + diff --git a/tests/data_model/test_embedding_vectors.py b/tests/data_model/test_embedding_vectors.py deleted file mode 100644 index 010eac10..00000000 --- a/tests/data_model/test_embedding_vectors.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright The Caikit Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Test for embedding vectors -""" -# Standard -from collections import namedtuple - -# Third Party -import numpy as np -import pytest - -# Local -from caikit_nlp import data_model as dm - -## Setup ######################################################################### - -RANDOM_SEED = 77 -DUMMY_VECTOR_SHAPE = (5,) - -# To tests the limits of our type-checking, this can replace our legit data objects -TRICK_SEQUENCE = namedtuple("Trick", "values") - -np.random.seed(RANDOM_SEED) - -random_number_generator = np.random.default_rng() - -random_numpy_vector1d_float32 = random_number_generator.random( - DUMMY_VECTOR_SHAPE, dtype=np.float32 -) -random_numpy_vector1d_float64 = random_number_generator.random( - DUMMY_VECTOR_SHAPE, dtype=np.float64 -) -random_python_vector1d_float = random_numpy_vector1d_float32.tolist() - -## Tests ######################################################################## - - -@pytest.mark.parametrize( - "sequence", - [ - dm.PyFloatSequence(), - dm.NpFloat32Sequence(), - dm.NpFloat64Sequence(), - TRICK_SEQUENCE(values=None), - ], - ids=type, -) -def test_empty_sequences(sequence): - """No type check error with empty sequences""" - new_dm_from_init = dm.Vector1D(sequence) - assert isinstance(new_dm_from_init.data, type(sequence)) - assert new_dm_from_init.data.values is None - - # Test proto - proto_from_dm = new_dm_from_init.to_proto() - new_dm_from_proto = dm.Vector1D.from_proto(proto_from_dm) - assert isinstance(new_dm_from_proto, dm.Vector1D) - assert new_dm_from_proto.data.values is None - - # Test json - json_from_dm = new_dm_from_init.to_json() - new_dm_from_json = dm.Vector1D.from_json(json_from_dm) - assert isinstance(new_dm_from_json, dm.Vector1D) - assert new_dm_from_json.data.values == [] - - -def test_vector1d_iterator_error(): - """Cannot just shove in an iterator and expect it to work""" - with pytest.raises(ValueError): - dm.Vector1D(data=[1.1, 2.2, 3.3]) - - -def _assert_array_check(new_array, data_values, float_type): - for value in new_array.data.values: - assert isinstance(value, float_type) - np.testing.assert_array_equal(new_array.data.values, data_values) - - -@pytest.mark.parametrize( - "float_seq_class, random_values, float_type", - [ - (dm.PyFloatSequence, random_python_vector1d_float, float), - (dm.NpFloat32Sequence, random_numpy_vector1d_float32, np.float32), - (dm.NpFloat64Sequence, random_numpy_vector1d_float64, np.float64), - (TRICK_SEQUENCE, [1.1, 2.2], float), # Sneaky but tests corner cases for now - ], -) -def test_vector1d_dm(float_seq_class, random_values, float_type): - - # Test init - dm_init = dm.Vector1D(data=float_seq_class(random_values)) - _assert_array_check(dm_init, random_values, float_type) - - # Test proto - dm_to_proto = dm_init.to_proto() - dm_from_proto = dm.Vector1D.from_proto(dm_to_proto) - _assert_array_check(dm_from_proto, random_values, float_type) - - # Test json - dm_to_json = dm_init.to_json() - dm_from_json = dm.Vector1D.from_json(dm_to_json) - _assert_array_check( - dm_from_json, random_values, float - ) # NOTE: always float after json diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index af4d0556..6dd25ff8 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -1,6 +1,7 @@ """Tests for text embedding module """ # Standard +from typing import List import os import tempfile @@ -9,8 +10,18 @@ import numpy as np import pytest +# First Party +from caikit.core import ModuleConfig +from 
caikit.interfaces.common.data_model.vectors import ListOfVector1D +from caikit.interfaces.nlp.data_model import ( + EmbeddingResult, + RerankResult, + RerankResults, + RerankScore, + RerankScores, +) + # Local -from caikit_nlp.data_model import EmbeddingResult from caikit_nlp.modules.text_embedding import EmbeddingModule from tests.fixtures import SEQ_CLASS_MODEL @@ -22,35 +33,109 @@ INPUT = "The quick brown fox jumps over the lazy dog." +QUERY = "What is foo bar?" + +QUERIES: List[str] = [ + "Who is foo?", + "Where is the bar?", +] + +# These are used to test that documents can handle different types in and out +TYPE_KEYS = "str_test", "int_test", "float_test", "nested_dict_test" + +DOCS = [ + { + "text": "foo", + "title": "title or whatever", + "str_test": "test string", + "int_test": 1, + "float_test": 1.234, + "score": 99999, + "nested_dict_test": {"deep1": 1, "deep string": "just testing"}, + }, + { + "_text": "bar", + "title": "title 2", + }, + { + "text": "foo and bar", + }, + { + "_text": "Where is the bar", + "another": "something else", + }, +] + +# Use text or _text from DOCS for our test sentences +SENTENCES = [d.get("text", d.get("_text")) for d in DOCS] + ## Tests ######################################################################## +def _assert_is_expected_vector(vector): + assert isinstance(vector.data.values[0], np.float32) + assert len(vector.data.values) == 32 + # Just testing a few values for readability + assert approx(vector.data.values[0]) == 0.3244932293891907 + assert approx(vector.data.values[1]) == -0.4934631288051605 + assert approx(vector.data.values[2]) == 0.5721234083175659 + + def _assert_is_expected_embedding_result(actual): assert isinstance(actual, EmbeddingResult) - assert isinstance(actual.result.data.values[0], np.float32) - assert len(actual.result.data.values) == 32 - # Just testing a few values for readability - assert approx(actual.result.data.values[0]) == 0.3244932293891907 - assert approx(actual.result.data.values[1]) == -0.4934631288051605 - assert approx(actual.result.data.values[2]) == 0.5721234083175659 + vector = actual.result + _assert_is_expected_vector(vector) -def test_bootstrap_and_run(): - """Check if we can bootstrap and run embedding""" - model = EmbeddingModule.bootstrap(SEQ_CLASS_MODEL) - result = model.run(INPUT) - _assert_is_expected_embedding_result(result) +def _assert_is_expected_embeddings_results(actual): + assert isinstance(actual, ListOfVector1D) + _assert_is_expected_vector(actual.vectors[0]) -def test_run_type_check(): - """Input cannot be a list""" - model = BOOTSTRAPPED_MODEL - with pytest.raises(TypeError): - model.run([INPUT]) - pytest.fail("Should not reach here") +def test_bootstrap(): + assert isinstance( + EmbeddingModule.bootstrap(SEQ_CLASS_MODEL), EmbeddingModule + ), "bootstrap error" + + +def _assert_types_found(types_found): + assert type(types_found["str_test"]) == str, "passthru str value type check" + assert type(types_found["int_test"]) == int, "passthru int value type check" + assert type(types_found["float_test"]) == float, "passthru float value type check" + assert ( + type(types_found["nested_dict_test"]) == dict + ), "passthru nested dict value type check" + + +def _assert_valid_scores(scores, type_tests={}): + for score in scores: + assert isinstance(score, RerankScore) + assert isinstance(score.score, float) + assert isinstance(score.index, int) + assert isinstance(score.text, str) + + document = score.document + assert isinstance(document, dict) + assert document == DOCS[score.index] + + # 
Test document key named score (None or 9999) is independent of the result score
+        assert score.score != document.get(
+            "score"
+        ), "unexpected passthru score same as result score"
+
+        # Gather various type test values when we have them
+        for k, v in document.items():
+            if k in TYPE_KEYS:
+                type_tests[k] = v
+
+    return type_tests
 
 
-def test_save_load_and_run_model():
+def test_bootstrapped_model():
+    """The module-level bootstrapped model loads as an EmbeddingModule"""
+    assert isinstance(BOOTSTRAPPED_MODEL, EmbeddingModule), "bootstrap error"
+
+
+def test_save_load_and_run():
     """Check if we can load and run a saved model successfully"""
     model_id = "model_id"
     with tempfile.TemporaryDirectory(suffix="-1st") as model_dir:
@@ -58,7 +143,11 @@
         BOOTSTRAPPED_MODEL.save(model_path)
         new_model = EmbeddingModule.load(model_path)
 
-        result = new_model.run(input=INPUT)
+        assert isinstance(new_model, EmbeddingModule), "save and load error"
+        assert new_model != BOOTSTRAPPED_MODEL, "did not load a new model"
+
+        # Use run_embedding just to make sure this new model is usable
+        result = new_model.run_embedding(text=INPUT)
         _assert_is_expected_embedding_result(result)
 
@@ -94,3 +183,198 @@ def test_second_save_hits_exists_check():
 def test_save_type_checks(model_path):
     with pytest.raises(TypeError):
         BOOTSTRAPPED_MODEL.save(model_path)
+
+
+def test_load_without_artifacts():
+    """Test coverage for the error message when config has no artifacts to load"""
+    with pytest.raises(ValueError):
+        EmbeddingModule.load(ModuleConfig({}))
+
+
+def test_run_embedding_type_check():
+    """Input cannot be a list"""
+    model = BOOTSTRAPPED_MODEL
+    with pytest.raises(TypeError):
+        model.run_embedding([INPUT])
+        pytest.fail("Should not reach here")
+
+
+def test_run_embedding():
+    model = BOOTSTRAPPED_MODEL
+    res = model.run_embedding(text=INPUT)
+    _assert_is_expected_embedding_result(res)
+
+
+def test_run_embeddings_str_type():
+    """Supposed to be a list, gets fixed automatically."""
+    model = BOOTSTRAPPED_MODEL
+    res = model.run_embeddings(texts=INPUT)
+    assert isinstance(res.results.vectors, list)
+    assert len(res.results.vectors) == 1
+
+
+def test_run_embeddings():
+    model = BOOTSTRAPPED_MODEL
+    res = model.run_embeddings(texts=[INPUT])
+    assert isinstance(res.results.vectors, list)
+    _assert_is_expected_embeddings_results(res.results)
+
+
+@pytest.mark.parametrize(
+    "query,docs,top_n",
+    [
+        (["test list"], DOCS, None),
+        (None, DOCS, 1234),
+        (False, DOCS, 1234),
+        (QUERY, {"testdict": "not list"}, 1234),
+        (QUERY, DOCS, "topN string is not an integer or None"),
+    ],
+)
+def test_run_rerank_query_type_error(query, docs, top_n):
+    """test for type checks matching task/run signature"""
+    with pytest.raises(TypeError):
+        BOOTSTRAPPED_MODEL.run_rerank_query(query=query, documents=docs, top_n=top_n)
+        pytest.fail("Should not reach here.")
+
+
+def test_run_rerank_query_no_type_error():
+    """no type error with a string query and a list of dict documents"""
+    BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS, top_n=1)
+
+
+@pytest.mark.parametrize(
+    "top_n, expected",
+    [
+        (1, 1),
+        (2, 2),
+        (None, len(DOCS)),
+        (-1, len(DOCS)),
+        (0, len(DOCS)),
+        (9999, len(DOCS)),
+    ],
+)
+def test_run_rerank_query_top_n(top_n, expected):
+    res = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n)
+    assert isinstance(res, RerankResult)
+    assert len(res.result.scores) == expected
+
+
+def test_run_rerank_query_no_query():
+    with pytest.raises(TypeError):
+        BOOTSTRAPPED_MODEL.run_rerank_query(query=None, documents=DOCS, top_n=99)
+
+
+def
test_run_rerank_query_zero_docs(): + """No empty doc list therefore result is zero result scores""" + with pytest.raises(ValueError): + BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=[], top_n=99) + + +def test_run_rerank_query(): + res = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS) + assert isinstance(res, RerankResult) + + scores = res.result.scores + assert isinstance(scores, list) + assert len(scores) == len(DOCS) + + types_found = _assert_valid_scores(scores) + _assert_types_found(types_found) + + +@pytest.mark.parametrize( + "queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})] +) +def test_run_rerank_queries_type_error(queries, docs): + """type error check ensures params are lists and not just 1 string or just one doc (for example)""" + with pytest.raises(TypeError): + BOOTSTRAPPED_MODEL.run_rerank_queries(queries=queries, documents=docs) + pytest.fail("Should not reach here.") + + +def test_run_rerank_queries_no_type_error(): + """no type error with list of string queries and list of dict documents""" + BOOTSTRAPPED_MODEL.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=99) + + +@pytest.mark.parametrize( + "top_n, expected", + [ + (1, 1), + (2, 2), + (None, len(DOCS)), + (-1, len(DOCS)), + (0, len(DOCS)), + (9999, len(DOCS)), + ], +) +def test_run_rerank_queries_top_n(top_n, expected): + """no type error with list of string queries and list of dict documents""" + res = BOOTSTRAPPED_MODEL.run_rerank_queries( + queries=QUERIES, documents=DOCS, top_n=top_n + ) + assert isinstance(res, RerankResults) + assert len(res.results) == len(QUERIES) + for result in res.results: + assert len(result.scores) == expected + + +@pytest.mark.parametrize( + "queries, docs", + [ + ([], DOCS), + (QUERIES, []), + ([], []), + ], + ids=["no queries", "no docs", "no queries and no docs"], +) +def test_run_rerank_queries_no_queries_or_no_docs(queries, docs): + """No queries and/or no docs therefore result is zero results""" + + with pytest.raises(ValueError): + BOOTSTRAPPED_MODEL.run_rerank_queries(queries=queries, documents=docs, top_n=9) + + +def test_run_rerank_queries(): + top_n = 2 + rerank_result = BOOTSTRAPPED_MODEL.run_rerank_queries( + queries=QUERIES, documents=DOCS, top_n=top_n + ) + assert isinstance(rerank_result, RerankResults) + + results = rerank_result.results + assert isinstance(results, list) + assert len(results) == 2 == len(QUERIES) # 2 queries yields 2 result(s) + + types_found = {} # Gather the type tests from any of the results + + for result in results: + assert isinstance(result, RerankScores) + scores = result.scores + assert isinstance(scores, list) + assert len(scores) == top_n + types_found = _assert_valid_scores(scores, types_found) + + # Make sure our document fields of different types made it in/out ok + _assert_types_found(types_found) + + +def test_run_sentence_similarity(): + model = BOOTSTRAPPED_MODEL + res = model.run_sentence_similarity(source_sentence=QUERY, sentences=SENTENCES) + scores = res.result.scores + assert len(scores) == len(SENTENCES) + for score in scores: + assert isinstance(score, float) + + +def test_run_sentence_similarities(): + model = BOOTSTRAPPED_MODEL + res = model.run_sentence_similarities(source_sentences=QUERIES, sentences=SENTENCES) + results = res.results + assert len(results) == len(QUERIES) + for result in results: + scores = result.scores + assert len(scores) == len(SENTENCES) + for score in scores: + assert isinstance(score, float) diff --git a/tox.ini b/tox.ini index 
ed361a28..d336b2a5 100644 --- a/tox.ini +++ b/tox.ini @@ -15,7 +15,9 @@ passenv = LOG_FORMATTER LOG_THREAD_ID LOG_CHANNEL_WIDTH -commands = pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests} +commands = + python -I -m pip install --force-reinstall --no-deps -r sentence-transformers.nodeps.txt + pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests} ; Unclear: We probably want to test wheel packaging ; But! tox will fail when this is set and _any_ interpreter is missing
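
Taken together, the diff above replaces the single `run(input=...)` entry point with one task method per task. Below is a minimal usage sketch of those methods. It is not taken from the diff itself: `path/to/st-model` is a placeholder for a local sentence-transformers model directory, and `sentence-transformers` is assumed installed (for example with `pip install --no-deps -r sentence-transformers.nodeps.txt`, as the tox change does).

```python
# Minimal usage sketch for the new EmbeddingModule task methods.
# Assumes sentence-transformers is installed and "path/to/st-model"
# (a placeholder) points at a local sentence-transformers model.
from caikit_nlp.modules.text_embedding import EmbeddingModule

model = EmbeddingModule.bootstrap("path/to/st-model")

# EmbeddingTask: one text in, one vector out
emb = model.run_embedding(text="What is foo bar?")
print(len(emb.result.data.values))

# EmbeddingTasks: many texts in, one vector per text (in order)
embs = model.run_embeddings(texts=["Who is foo?", "Where is the bar?"])
print(len(embs.results.vectors))

# SentenceSimilarityTask: score candidate sentences against one source sentence
sim = model.run_sentence_similarity(
    source_sentence="What is foo bar?",
    sentences=["foo", "bar", "foo and bar"],
)
print(sim.result.scores)

# RerankTask: rank documents against a query; each document's "text"
# (or "_text") field is what gets compared
rerank = model.run_rerank_query(
    query="What is foo bar?",
    documents=[{"text": "foo"}, {"_text": "bar"}, {"text": "foo and bar"}],
    top_n=2,
)
for score in rerank.result.scores:
    print(score.index, score.score, score.text)
```

Deferring the `sentence_transformers` import (the try/except in `embedding.py` above) keeps the dependency optional: callers that never load an embedding model pay no import cost, and the original `ModuleNotFoundError` is reproduced only when someone actually tries to initialize a model with this module.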