diff --git a/README.md b/README.md
index aab33851..509485ae 100644
--- a/README.md
+++ b/README.md
@@ -8,13 +8,15 @@ Caikit-NLP implements concept of "task" from `caikit` framework to define (and c
Capabilities provided by `caikit-nlp`:
-| Task | Module(s) | Salient Feature(s) |
-|----------------------|-------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| Text Generation | 1. `PeftPromptTuning`<br>2. `TextGeneration` | 1. Prompt Tuning, Multi-task Prompt tuning<br>2. Fine-tuning<br>Both modules above provide optimized inference capability using Text Generation Inference Server |
-| Text Classification | 1. `SequenceClassification` | 1. (Work in progress..) |
-| Token Classification | 1. `FilteredSpanClassification` | 1. (Work in progress..) |
-| Tokenization | 1. `RegexSentenceSplitter` | 1. Demo purposes only |
-| Embedding | [COMING SOON] | [COMING SOON] |
+| Task | Module(s) | Salient Feature(s) |
+|-----------------------------------------------------|------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| TextGenerationTask | 1. `PeftPromptTuning`<br>2. `TextGeneration` | 1. Prompt Tuning, Multi-task Prompt tuning<br>2. Fine-tuning<br>Both modules above provide optimized inference capability using Text Generation Inference Server |
+| TextClassificationTask | 1. `SequenceClassification` | 1. (Work in progress...) |
+| TokenClassificationTask | 1. `FilteredSpanClassification` | 1. (Work in progress...) |
+| TokenizationTask | 1. `RegexSentenceSplitter` | 1. Demo purposes only |
+| EmbeddingTask<br>EmbeddingTasks | 1. `TextEmbedding` | 1. EmbeddingTask returns a text embedding vector from a local sentence-transformers model.<br>2. EmbeddingTasks takes multiple input texts and returns a corresponding list of vectors. |
+| SentenceSimilarityTask<br>SentenceSimilarityTasks | 1. `TextEmbedding` | 1. SentenceSimilarityTask compares one source_sentence to a list of sentences and returns similarity scores in the order of the sentences.<br>2. SentenceSimilarityTasks takes a list of source_sentences (each compared to the same list of sentences) and returns corresponding lists of outputs. |
+| RerankTask<br>RerankTasks | 1. `TextEmbedding` | 1. RerankTask compares a query to a list of documents and returns the top_n scores in order of relevance, with indexes into the source documents and, optionally, the documents themselves.<br>2. RerankTasks takes multiple queries as input and returns a corresponding list of outputs. The same list of documents is used for all queries. |
## Getting Started
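
For reviewers, a minimal usage sketch of the two embedding tasks this diff adds. `my-model` is a hypothetical local sentence-transformers model path; the method names come from `EmbeddingModule` below.

```python
from caikit_nlp.modules.text_embedding import EmbeddingModule

# Hypothetical local sentence-transformers model directory
model = EmbeddingModule.bootstrap(model_name_or_path="my-model")

# EmbeddingTask: one input text -> one embedding vector
embedding = model.run_embedding(text="The quick brown fox jumps over the lazy dog.")
print(len(embedding.result.data.values))  # the embedding dimension

# EmbeddingTasks: a list of texts -> a list of vectors, in input order
embeddings = model.run_embeddings(texts=["first text", "second text"])
print(len(embeddings.results.vectors))  # 2
```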
diff --git a/caikit_nlp/data_model/__init__.py b/caikit_nlp/data_model/__init__.py
index dcbf0585..6826b631 100644
--- a/caikit_nlp/data_model/__init__.py
+++ b/caikit_nlp/data_model/__init__.py
@@ -15,6 +15,5 @@
"""
# Local
-from . import embedding_vectors, generation
-from .embedding_vectors import *
+from . import generation
from .generation import *
diff --git a/caikit_nlp/data_model/embedding_vectors.py b/caikit_nlp/data_model/embedding_vectors.py
deleted file mode 100644
index 67f3d5b3..00000000
--- a/caikit_nlp/data_model/embedding_vectors.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# Copyright The Caikit Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Data structures for embedding vector representations
-"""
-# Standard
-from dataclasses import dataclass, field
-from typing import List, Union
-import json
-
-# Third Party
-from google.protobuf import json_format
-import numpy as np
-
-# First Party
-from caikit.core import DataObjectBase, dataobject
-from caikit.core.exceptions import error_handler
-import alog
-
-log = alog.use_channel("DATAM")
-error = error_handler.get(log)
-
-
-@dataobject(package="caikit_data_model.caikit_nlp")
-@dataclass
-class PyFloatSequence(DataObjectBase):
- values: List[float] = field(default_factory=list)
-
-
-@dataobject(package="caikit_data_model.caikit_nlp")
-@dataclass
-class NpFloat32Sequence(DataObjectBase):
- values: List[np.float32]
-
- @classmethod
- def from_proto(cls, proto):
- values = np.asarray(proto.values, dtype=np.float32)
- return cls(values)
-
-
-@dataobject(package="caikit_data_model.caikit_nlp")
-@dataclass
-class NpFloat64Sequence(DataObjectBase):
- values: List[np.float64]
-
- @classmethod
- def from_proto(cls, proto):
- values = np.asarray(proto.values, dtype=np.float64)
- return cls(values)
-
-
-@dataobject(package="caikit_data_model.caikit_nlp")
-@dataclass
-class Vector1D(DataObjectBase):
- """Data representation for a 1 dimension vector of float-type data."""
-
- data: Union[
- PyFloatSequence,
- NpFloat32Sequence,
- NpFloat64Sequence,
- ]
-
- def __post_init__(self):
- error.value_check(
- "",
- hasattr(self.data, "values"),
- ValueError("Vector1D requires a float sequence data object with values."),
- )
-
- @classmethod
- def from_vector(cls, vector):
- if vector.dtype == np.float32:
- data = NpFloat32Sequence(vector)
- elif vector.dtype == np.float64:
- data = NpFloat64Sequence(vector)
- else:
- data = PyFloatSequence(vector)
- return cls(data=data)
-
- @classmethod
- def from_json(cls, json_str):
- """JSON does not have different float types. Move data into data_pyfloatsequence"""
-
- json_obj = json.loads(json_str) if isinstance(json_str, str) else json_str
- data = json_obj.pop("data")
- if data is not None:
- json_obj["data_pyfloatsequence"] = data
-
- json_str = json.dumps(json_obj)
- try:
- # Parse given JSON into google.protobufs.pyext.cpp_message.GeneratedProtocolMessageType
- parsed_proto = json_format.Parse(
- json_str, cls.get_proto_class()(), ignore_unknown_fields=False
- )
-
- # Use from_proto to return the DataBase object from the parsed proto
- return cls.from_proto(parsed_proto)
-
- except json_format.ParseError as ex:
- error("", ValueError(ex))
-
- def to_dict(self) -> dict:
- """to_dict is needed to make things serializable"""
- values = self.data.values if self.data.values is not None else []
- return {
- "data": {
- # coerce numpy.ndarray and numpy.float32 into JSON serializable list of floats
- "values": values.tolist()
- if isinstance(values, np.ndarray)
- else values
- }
- }
-
- @classmethod
- def from_proto(cls, proto):
- """Wrap the data in an appropriate float sequence, wrapped by this class"""
- woo = proto.WhichOneof("data")
- if woo is None:
- return cls(PyFloatSequence())
-
- woo_data = getattr(proto, woo)
- if woo == "data_npfloat64sequence":
- ret = cls(NpFloat64Sequence.from_proto(woo_data))
- elif woo == "data_npfloat32sequence":
- ret = cls(NpFloat32Sequence.from_proto(woo_data))
- else:
- ret = cls(PyFloatSequence.from_proto(woo_data))
- return ret
-
- def fill_proto(self, proto):
- """Fill in the data in an appropriate data_"""
- values = self.data.values
- if values is not None and len(values) > 0:
- sample = values[0]
- error.type_check(
- "", float, np.float32, np.float64, sample=sample
- )
- if isinstance(sample, np.float64):
- proto.data_npfloat64sequence.values.extend(values)
- elif isinstance(sample, np.float32):
- proto.data_npfloat32sequence.values.extend(values)
- else:
- proto.data_pyfloatsequence.values.extend(values)
-
- return proto
-
-
-@dataobject(package="caikit_data_model.caikit_nlp")
-@dataclass
-class EmbeddingResult(DataObjectBase):
- """Result from text embedding task"""
-
- result: Vector1D
diff --git a/caikit_nlp/modules/text_embedding/__init__.py b/caikit_nlp/modules/text_embedding/__init__.py
index f56694f6..2451f4a2 100644
--- a/caikit_nlp/modules/text_embedding/__init__.py
+++ b/caikit_nlp/modules/text_embedding/__init__.py
@@ -12,6 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Text Embedding Module
+=====================
+
+Implements the following tasks:
+
+ 1. EmbeddingTask: Returns an embedding from an input text string
+    2. EmbeddingTasks: EmbeddingTask but with a list of inputs producing a list of outputs
+ 3. SentenceSimilarityTask: Compare one source sentence to a list of sentences
+ 4. SentenceSimilarityTasks: SentenceSimilarityTask but with a list of source sentences producing
+ a list of outputs
+ 5. RerankTask: Return top_n documents ordered by relevance given a query
+ 6. RerankTasks: RerankTask but with a list of queries producing a list of outputs
+
+"""
+
# Local
from .embedding import EmbeddingModule
-from .embedding_tasks import EmbeddingTask
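
For quick reference while reviewing, the six tasks in the docstring above map onto `EmbeddingModule` methods added in embedding.py below; a comment-only sketch:

```python
# Task -> EmbeddingModule method (as implemented in embedding.py)
# 1. EmbeddingTask           -> run_embedding(text=...)
# 2. EmbeddingTasks          -> run_embeddings(texts=[...])
# 3. SentenceSimilarityTask  -> run_sentence_similarity(source_sentence=..., sentences=[...])
# 4. SentenceSimilarityTasks -> run_sentence_similarities(source_sentences=[...], sentences=[...])
# 5. RerankTask              -> run_rerank_query(query=..., documents=[...], top_n=...)
# 6. RerankTasks             -> run_rerank_queries(queries=[...], documents=[...], top_n=...)
```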
diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py
index ea415048..f64e185e 100644
--- a/caikit_nlp/modules/text_embedding/embedding.py
+++ b/caikit_nlp/modules/text_embedding/embedding.py
@@ -13,29 +13,73 @@
# limitations under the License.
# Standard
+from typing import List, Optional
+import importlib
import os
-# Third Party
-from sentence_transformers import SentenceTransformer
-
# First Party
from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module
+from caikit.core.data_model.json_dict import JsonDict
from caikit.core.exceptions import error_handler
+from caikit.interfaces.common.data_model.vectors import ListOfVector1D, Vector1D
+from caikit.interfaces.nlp.data_model import (
+ EmbeddingResult,
+ EmbeddingResults,
+ RerankResult,
+ RerankResults,
+ RerankScore,
+ RerankScores,
+ SentenceSimilarityResult,
+ SentenceSimilarityResults,
+ SentenceSimilarityScores,
+)
+from caikit.interfaces.nlp.tasks import (
+ EmbeddingTask,
+ EmbeddingTasks,
+ RerankTask,
+ RerankTasks,
+ SentenceSimilarityTask,
+ SentenceSimilarityTasks,
+)
import alog
-# Local
-from .embedding_tasks import EmbeddingTask
-from caikit_nlp.data_model.embedding_vectors import EmbeddingResult, Vector1D
-
logger = alog.use_channel("TXT_EMB")
error = error_handler.get(logger)
+# To avoid dependency problems, make sentence-transformers an optional import and
+# defer any ModuleNotFoundError until someone actually tries to init a model with this module.
+try:
+ sentence_transformers = importlib.import_module("sentence_transformers")
+ # Third Party
+ from sentence_transformers import SentenceTransformer
+ from sentence_transformers.util import (
+ cos_sim,
+ dot_score,
+ normalize_embeddings,
+ semantic_search,
+ )
+except ModuleNotFoundError:
+ # When it is not available, create a dummy that raises an error on attempted init()
+ class SentenceTransformerNotAvailable:
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
+ # Will reproduce the ModuleNotFoundError if/when anyone actually tries this module/model
+ importlib.import_module("sentence_transformers")
+
+ SentenceTransformer = SentenceTransformerNotAvailable
+
@module(
"eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f",
"EmbeddingModule",
"0.0.1",
- EmbeddingTask,
+ tasks=[
+ EmbeddingTask,
+ EmbeddingTasks,
+ SentenceSimilarityTask,
+ SentenceSimilarityTasks,
+ RerankTask,
+ RerankTasks,
+ ],
)
class EmbeddingModule(ModuleBase):
@@ -76,19 +120,251 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
return cls.bootstrap(model_name_or_path=artifacts_path)
- def run(
- self, input: str, **kwargs # pylint: disable=redefined-builtin
- ) -> EmbeddingResult:
- """Run inference on model.
+ @EmbeddingTask.taskmethod()
+ def run_embedding(self, text: str) -> EmbeddingResult:
+ """Get embedding for a string.
Args:
- input: str
+ text: str
Input text to be processed
Returns:
EmbeddingResult: the result vector nicely wrapped up
"""
- error.type_check("", str, input=input)
+ error.type_check("", str, text=text)
- return EmbeddingResult(Vector1D.from_vector(self.model.encode(input)))
+ return EmbeddingResult(
+ result=Vector1D.from_vector(self.model.encode(text)),
+ producer_id=self.PRODUCER_ID,
+ )
+
+ @EmbeddingTasks.taskmethod()
+ def run_embeddings(self, texts: List[str]) -> EmbeddingResults:
+ """Get embedding vectors for texts.
+ Args:
+ texts: List[str]
+ List of input texts to be processed
+ Returns:
+ EmbeddingResults: List of vectors. One for each input text (in order).
+ Each vector is a list of floats (supports various float types).
+ """
+ if isinstance(
+ texts, str
+ ): # encode allows str, but the result would lack a dimension
+ texts = [texts]
+
+ embeddings = self.model.encode(texts)
+ vectors = [Vector1D.from_vector(e) for e in embeddings]
+ return EmbeddingResults(
+ results=ListOfVector1D(vectors=vectors), producer_id=self.PRODUCER_ID
+ )
+
+ @SentenceSimilarityTask.taskmethod()
+ def run_sentence_similarity(
+ self, source_sentence: str, sentences: List[str]
+ ) -> SentenceSimilarityResult:
+ """Get similarity scores for each of sentences compared to the source_sentence.
+ Args:
+ source_sentence: str
+ sentences: List[str]
+ Sentences to compare to source_sentence
+ Returns:
+ SentenceSimilarityResult: Similarity scores for each sentence.
+ """
+
+ source_embedding = self.model.encode(source_sentence)
+ embeddings = self.model.encode(sentences)
+
+ res = cos_sim(source_embedding, embeddings)
+ return SentenceSimilarityResult(
+ result=SentenceSimilarityScores(scores=res.tolist()[0]),
+ producer_id=self.PRODUCER_ID,
+ )
+
+ @SentenceSimilarityTasks.taskmethod()
+ def run_sentence_similarities(
+ self, source_sentences: List[str], sentences: List[str]
+ ) -> SentenceSimilarityResults:
+ """Run sentence-similarities on model.
+ Args:
+ source_sentences: List[str]
+ sentences: List[str]
+ Sentences to compare to source_sentences
+ Returns:
+ SentenceSimilarityResults: Similarity scores for each source sentence in order.
+ Each one contains the source-sentence's score for each sentence in order.
+ """
+
+ source_embedding = self.model.encode(source_sentences)
+ embeddings = self.model.encode(sentences)
+
+ res = cos_sim(source_embedding, embeddings)
+ float_list_list = res.tolist()
+ return SentenceSimilarityResults(
+ results=[SentenceSimilarityScores(fl) for fl in float_list_list],
+ producer_id=self.PRODUCER_ID,
+ )
+
+ @RerankTask.taskmethod()
+ def run_rerank_query(
+ self,
+ query: str,
+ documents: List[JsonDict],
+ top_n: Optional[int] = None,
+ return_documents: bool = True,
+ return_query: bool = True,
+ return_text: bool = True,
+ ) -> RerankResult:
+ """Rerank the documents returning the most relevant top_n in order for this query.
+ Args:
+ query: str
+ Query is the source string to be compared to the text of the documents.
+ documents: List[JsonDict]
+ Each document is a dict. The text value is used for comparison to the query.
+                If there is no "text" key, "_text" is used; if neither is present, "" is used.
+            top_n: Optional[int]
+                Results for the top n most relevant documents will be returned.
+                If top_n is not provided or is not a positive integer, all results are returned.
+ return_documents: bool
+ Default True
+ Setting to False will disable returning of the input document (index is returned).
+ return_query: bool
+ Default True
+ Setting to False will disable returning of the query (results are in query order)
+ return_text: bool
+ Default True
+ Setting to False will disable returning of document text string that was used.
+ Returns:
+ RerankResult
+ Returns the (top_n) scores in relevance order (most relevant first).
+                The results always include a score and an index that can be used to find
+                the document in the original documents list. Optionally, the results also
+                contain the entire document (useful for chaining) and, for convenience,
+                the query and the text used for comparison.
+
+ """
+
+ error.type_check(
+ "",
+ str,
+ query=query,
+ )
+
+ results = self.run_rerank_queries(
+ queries=[query],
+ documents=documents,
+ top_n=top_n,
+ return_documents=return_documents,
+ return_queries=return_query,
+ return_text=return_text,
+ ).results
+
+ if results:
+ return RerankResult(result=results[0], producer_id=self.PRODUCER_ID)
+
+        return RerankResult(
+            result=RerankScores(
+                query=query if return_query else None,
+                scores=[],
+            ),
+            producer_id=self.PRODUCER_ID,
+        )
+
+ @RerankTasks.taskmethod()
+ def run_rerank_queries(
+ self,
+ queries: List[str],
+ documents: List[JsonDict],
+ top_n: Optional[int] = None,
+ return_documents: bool = True,
+ return_queries: bool = True,
+ return_text: bool = True,
+ ) -> RerankResults:
+ """Rerank the documents returning the most relevant top_n in order for each of the queries.
+ Args:
+ queries: List[str]
+ Each of the queries will be compared to the text of each of the documents.
+ documents: List[JsonDict]
+ Each document is a dict. The text value is used for comparison to the query.
+                If there is no "text" key, "_text" is used; if neither is present, "" is used.
+            top_n: Optional[int]
+                Results for the top n most relevant documents will be returned.
+                If top_n is not provided or is not a positive integer, all results are returned.
+ return_documents: bool
+ Default True
+ Setting to False will disable returning of the input document (index is returned).
+ return_queries: bool
+ Default True
+ Setting to False will disable returning of the query (results are in query order)
+ return_text: bool
+ Default True
+ Setting to False will disable returning of document text string that was used.
+ Returns:
+ RerankResults
+ For each query in queries (in the original order)...
+ Returns the (top_n) scores in relevance order (most relevant first).
+                The results always include a score and an index that can be used to find
+                the document in the original documents list. Optionally, the results also
+                contain the entire document (useful for chaining) and, for convenience,
+                the query and the text used for comparison.
+ """
+
+ error.type_check(
+ "",
+ list,
+ queries=queries,
+ documents=documents,
+ )
+
+ error.value_check(
+ "",
+ queries and documents,
+ "Cannot rerank without a query and at least one document",
+ )
+
+ if top_n is None or top_n < 1:
+ top_n = len(documents)
+
+ # Using input document dicts so get "text" else "_text" else default to ""
+ def get_text(doc):
+ return doc.get("text") or doc.get("_text", "")
+
+ doc_texts = [get_text(doc) for doc in documents]
+
+ doc_embeddings = normalize_embeddings(
+ self.model.encode(doc_texts, convert_to_tensor=True).to(self.model.device)
+ )
+
+ query_embeddings = normalize_embeddings(
+ self.model.encode(queries, convert_to_tensor=True).to(self.model.device)
+ )
+
+ res = semantic_search(
+ query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score
+ )
+
+ # Fixup result dicts
+ for r in res:
+ for x in r:
+ # Renaming corpus_id to index
+ corpus_id = x.pop("corpus_id")
+ x["index"] = corpus_id
+ # Optionally adding the original document and/or just the text that was used
+ if return_documents:
+ x["document"] = documents[corpus_id]
+ if return_text:
+ x["text"] = get_text(documents[corpus_id])
+
+ def add_query(q):
+ return queries[q] if return_queries else None
+
+ results = [
+ RerankScores(
+ query=add_query(q),
+ scores=[RerankScore(**x) for x in r],
+ )
+ for q, r in enumerate(res)
+ ]
+
+ return RerankResults(results=results, producer_id=self.PRODUCER_ID)
@classmethod
def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule":
@@ -108,12 +384,11 @@ def save(self, model_path: str, *args, **kwargs):
Path to model config
"""
- model_config_path = model_path # because the param name is misleading
-
- error.type_check("", str, model_path=model_config_path)
+ error.type_check("", str, model_path=model_path)
+ model_config_path = model_path.strip()
error.value_check(
"",
- model_config_path is not None and model_config_path.strip(),
+ model_config_path,
f"model_path '{model_config_path}' is invalid",
)
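
A hedged sketch of the rerank flow implemented above (model path and document contents are illustrative):

```python
from caikit_nlp.modules.text_embedding import EmbeddingModule

model = EmbeddingModule.bootstrap(model_name_or_path="my-model")  # hypothetical path

# Documents are free-form dicts; "text" (falling back to "_text") is the
# field compared against the query.
docs = [
    {"text": "foo", "title": "title or whatever"},
    {"_text": "bar", "title": "title 2"},
]

res = model.run_rerank_query(query="What is foo bar?", documents=docs, top_n=2)
for score in res.result.scores:
    # Each score carries an index into the original documents list; the document,
    # query, and compared text are included unless disabled via the return_* flags.
    print(score.index, score.score, score.text)
```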
diff --git a/caikit_nlp/modules/text_embedding/embedding_tasks.py b/caikit_nlp/modules/text_embedding/embedding_tasks.py
deleted file mode 100644
index 07b31dba..00000000
--- a/caikit_nlp/modules/text_embedding/embedding_tasks.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright The Caikit Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Standard
-
-# First Party
-from caikit.core import TaskBase, task
-
-# Local
-from ...data_model import EmbeddingResult
-
-
-@task(
- required_parameters={"input": str},
- output_type=EmbeddingResult,
-)
-class EmbeddingTask(TaskBase):
- pass
diff --git a/pyproject.toml b/pyproject.toml
index ce5200dc..7df014ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,8 +14,8 @@ classifiers=[
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
- "caikit[runtime-grpc,runtime-http]>=0.24.0,<0.25.0",
- "caikit-tgis-backend>=0.1.17,<0.2.0",
+ "caikit[runtime-grpc,runtime-http]>=0.25.0,<0.26.0",
+ "caikit-tgis-backend>=0.1.25,<0.2.0",
# TODO: loosen dependencies
"accelerate>=0.22.0",
"datasets>=2.4.0",
@@ -24,7 +24,6 @@ dependencies = [
"pandas>=1.5.0",
"scikit-learn>=1.1",
"scipy>=1.8.1",
- "sentence-transformers~=2.2.2",
"tokenizers>=0.13.3",
"torch>=2.0.1",
"tqdm>=4.65.0",
diff --git a/sentence-transformers.nodeps.txt b/sentence-transformers.nodeps.txt
new file mode 100644
index 00000000..dad81257
--- /dev/null
+++ b/sentence-transformers.nodeps.txt
@@ -0,0 +1,7 @@
+# These can be installed with --no-deps.
+
+# Minimum needed to use sentence-transformers:
+sentence-transformers>=2.2.2,<2.3.0
+nltk>=3.8.1,<3.9.0
+Pillow>=10.0.0,<10.1.0
+
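
Because sentence-transformers is now optional, the deferred-import guard in embedding.py means the failure surfaces at model construction rather than at package import. A small sketch of the expected behavior when the extra dependencies are missing (`my-model` is a hypothetical path):

```python
# Importing the module succeeds even without sentence-transformers installed...
from caikit_nlp.modules.text_embedding import EmbeddingModule

try:
    # ...but constructing a model reproduces the ModuleNotFoundError
    EmbeddingModule.bootstrap(model_name_or_path="my-model")
except ModuleNotFoundError:
    print("pip install --no-deps -r sentence-transformers.nodeps.txt")
```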
diff --git a/tests/data_model/test_embedding_vectors.py b/tests/data_model/test_embedding_vectors.py
deleted file mode 100644
index 010eac10..00000000
--- a/tests/data_model/test_embedding_vectors.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright The Caikit Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Test for embedding vectors
-"""
-# Standard
-from collections import namedtuple
-
-# Third Party
-import numpy as np
-import pytest
-
-# Local
-from caikit_nlp import data_model as dm
-
-## Setup #########################################################################
-
-RANDOM_SEED = 77
-DUMMY_VECTOR_SHAPE = (5,)
-
-# To tests the limits of our type-checking, this can replace our legit data objects
-TRICK_SEQUENCE = namedtuple("Trick", "values")
-
-np.random.seed(RANDOM_SEED)
-
-random_number_generator = np.random.default_rng()
-
-random_numpy_vector1d_float32 = random_number_generator.random(
- DUMMY_VECTOR_SHAPE, dtype=np.float32
-)
-random_numpy_vector1d_float64 = random_number_generator.random(
- DUMMY_VECTOR_SHAPE, dtype=np.float64
-)
-random_python_vector1d_float = random_numpy_vector1d_float32.tolist()
-
-## Tests ########################################################################
-
-
-@pytest.mark.parametrize(
- "sequence",
- [
- dm.PyFloatSequence(),
- dm.NpFloat32Sequence(),
- dm.NpFloat64Sequence(),
- TRICK_SEQUENCE(values=None),
- ],
- ids=type,
-)
-def test_empty_sequences(sequence):
- """No type check error with empty sequences"""
- new_dm_from_init = dm.Vector1D(sequence)
- assert isinstance(new_dm_from_init.data, type(sequence))
- assert new_dm_from_init.data.values is None
-
- # Test proto
- proto_from_dm = new_dm_from_init.to_proto()
- new_dm_from_proto = dm.Vector1D.from_proto(proto_from_dm)
- assert isinstance(new_dm_from_proto, dm.Vector1D)
- assert new_dm_from_proto.data.values is None
-
- # Test json
- json_from_dm = new_dm_from_init.to_json()
- new_dm_from_json = dm.Vector1D.from_json(json_from_dm)
- assert isinstance(new_dm_from_json, dm.Vector1D)
- assert new_dm_from_json.data.values == []
-
-
-def test_vector1d_iterator_error():
- """Cannot just shove in an iterator and expect it to work"""
- with pytest.raises(ValueError):
- dm.Vector1D(data=[1.1, 2.2, 3.3])
-
-
-def _assert_array_check(new_array, data_values, float_type):
- for value in new_array.data.values:
- assert isinstance(value, float_type)
- np.testing.assert_array_equal(new_array.data.values, data_values)
-
-
-@pytest.mark.parametrize(
- "float_seq_class, random_values, float_type",
- [
- (dm.PyFloatSequence, random_python_vector1d_float, float),
- (dm.NpFloat32Sequence, random_numpy_vector1d_float32, np.float32),
- (dm.NpFloat64Sequence, random_numpy_vector1d_float64, np.float64),
- (TRICK_SEQUENCE, [1.1, 2.2], float), # Sneaky but tests corner cases for now
- ],
-)
-def test_vector1d_dm(float_seq_class, random_values, float_type):
-
- # Test init
- dm_init = dm.Vector1D(data=float_seq_class(random_values))
- _assert_array_check(dm_init, random_values, float_type)
-
- # Test proto
- dm_to_proto = dm_init.to_proto()
- dm_from_proto = dm.Vector1D.from_proto(dm_to_proto)
- _assert_array_check(dm_from_proto, random_values, float_type)
-
- # Test json
- dm_to_json = dm_init.to_json()
- dm_from_json = dm.Vector1D.from_json(dm_to_json)
- _assert_array_check(
- dm_from_json, random_values, float
- ) # NOTE: always float after json
diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py
index af4d0556..6dd25ff8 100644
--- a/tests/modules/text_embedding/test_embedding.py
+++ b/tests/modules/text_embedding/test_embedding.py
@@ -1,6 +1,7 @@
"""Tests for text embedding module
"""
# Standard
+from typing import List
import os
import tempfile
@@ -9,8 +10,18 @@
import numpy as np
import pytest
+# First Party
+from caikit.core import ModuleConfig
+from caikit.interfaces.common.data_model.vectors import ListOfVector1D
+from caikit.interfaces.nlp.data_model import (
+ EmbeddingResult,
+ RerankResult,
+ RerankResults,
+ RerankScore,
+ RerankScores,
+)
+
# Local
-from caikit_nlp.data_model import EmbeddingResult
from caikit_nlp.modules.text_embedding import EmbeddingModule
from tests.fixtures import SEQ_CLASS_MODEL
@@ -22,35 +33,109 @@
INPUT = "The quick brown fox jumps over the lazy dog."
+QUERY = "What is foo bar?"
+
+QUERIES: List[str] = [
+ "Who is foo?",
+ "Where is the bar?",
+]
+
+# These are used to test that documents can handle different types in and out
+TYPE_KEYS = "str_test", "int_test", "float_test", "nested_dict_test"
+
+DOCS = [
+ {
+ "text": "foo",
+ "title": "title or whatever",
+ "str_test": "test string",
+ "int_test": 1,
+ "float_test": 1.234,
+ "score": 99999,
+ "nested_dict_test": {"deep1": 1, "deep string": "just testing"},
+ },
+ {
+ "_text": "bar",
+ "title": "title 2",
+ },
+ {
+ "text": "foo and bar",
+ },
+ {
+ "_text": "Where is the bar",
+ "another": "something else",
+ },
+]
+
+# Use text or _text from DOCS for our test sentences
+SENTENCES = [d.get("text", d.get("_text")) for d in DOCS]
+
## Tests ########################################################################
+def _assert_is_expected_vector(vector):
+ assert isinstance(vector.data.values[0], np.float32)
+ assert len(vector.data.values) == 32
+ # Just testing a few values for readability
+ assert approx(vector.data.values[0]) == 0.3244932293891907
+ assert approx(vector.data.values[1]) == -0.4934631288051605
+ assert approx(vector.data.values[2]) == 0.5721234083175659
+
+
def _assert_is_expected_embedding_result(actual):
assert isinstance(actual, EmbeddingResult)
- assert isinstance(actual.result.data.values[0], np.float32)
- assert len(actual.result.data.values) == 32
- # Just testing a few values for readability
- assert approx(actual.result.data.values[0]) == 0.3244932293891907
- assert approx(actual.result.data.values[1]) == -0.4934631288051605
- assert approx(actual.result.data.values[2]) == 0.5721234083175659
+ vector = actual.result
+ _assert_is_expected_vector(vector)
-def test_bootstrap_and_run():
- """Check if we can bootstrap and run embedding"""
- model = EmbeddingModule.bootstrap(SEQ_CLASS_MODEL)
- result = model.run(INPUT)
- _assert_is_expected_embedding_result(result)
+def _assert_is_expected_embeddings_results(actual):
+ assert isinstance(actual, ListOfVector1D)
+ _assert_is_expected_vector(actual.vectors[0])
-def test_run_type_check():
- """Input cannot be a list"""
- model = BOOTSTRAPPED_MODEL
- with pytest.raises(TypeError):
- model.run([INPUT])
- pytest.fail("Should not reach here")
+def test_bootstrap():
+ assert isinstance(
+ EmbeddingModule.bootstrap(SEQ_CLASS_MODEL), EmbeddingModule
+ ), "bootstrap error"
+
+
+def _assert_types_found(types_found):
+ assert type(types_found["str_test"]) == str, "passthru str value type check"
+ assert type(types_found["int_test"]) == int, "passthru int value type check"
+ assert type(types_found["float_test"]) == float, "passthru float value type check"
+ assert (
+ type(types_found["nested_dict_test"]) == dict
+ ), "passthru nested dict value type check"
+
+
+def _assert_valid_scores(scores, type_tests=None):
+    # Use None instead of a mutable default argument so state cannot leak across calls
+    if type_tests is None:
+        type_tests = {}
+ for score in scores:
+ assert isinstance(score, RerankScore)
+ assert isinstance(score.score, float)
+ assert isinstance(score.index, int)
+ assert isinstance(score.text, str)
+
+ document = score.document
+ assert isinstance(document, dict)
+ assert document == DOCS[score.index]
+
+        # Test that a document key named "score" (99999 in DOCS[0]) is independent of the result score
+ assert score.score != document.get(
+ "score"
+ ), "unexpected passthru score same as result score"
+
+ # Gather various type test values when we have them
+ for k, v in document.items():
+ if k in TYPE_KEYS:
+ type_tests[k] = v
+
+ return type_tests
-def test_save_load_and_run_model():
+def test_bootstrapped_model():
+    assert isinstance(BOOTSTRAPPED_MODEL, EmbeddingModule), "bootstrap error"
+
+
+def test_save_load_and_run():
"""Check if we can load and run a saved model successfully"""
model_id = "model_id"
with tempfile.TemporaryDirectory(suffix="-1st") as model_dir:
@@ -58,7 +143,11 @@ def test_save_load_and_run_model():
BOOTSTRAPPED_MODEL.save(model_path)
new_model = EmbeddingModule.load(model_path)
- result = new_model.run(input=INPUT)
+ assert isinstance(new_model, EmbeddingModule), "save and load error"
+ assert new_model != BOOTSTRAPPED_MODEL, "did not load a new model"
+
+ # Use run_embedding just to make sure this new model is usable
+ result = new_model.run_embedding(text=INPUT)
_assert_is_expected_embedding_result(result)
@@ -94,3 +183,198 @@ def test_second_save_hits_exists_check():
def test_save_type_checks(model_path):
with pytest.raises(TypeError):
BOOTSTRAPPED_MODEL.save(model_path)
+
+
+def test_load_without_artifacts():
+ """Test coverage for the error message when config has no artifacts to load"""
+ with pytest.raises(ValueError):
+ EmbeddingModule.load(ModuleConfig({}))
+
+
+def test_run_embedding_type_check():
+ """Input cannot be a list"""
+ model = BOOTSTRAPPED_MODEL
+ with pytest.raises(TypeError):
+ model.run_embedding([INPUT])
+ pytest.fail("Should not reach here")
+
+
+def test_run_embedding():
+ model = BOOTSTRAPPED_MODEL
+ res = model.run_embedding(text=INPUT)
+ _assert_is_expected_embedding_result(res)
+
+
+def test_run_embeddings_str_type():
+ """Supposed to be a list, gets fixed automatically."""
+ model = BOOTSTRAPPED_MODEL
+ res = model.run_embeddings(texts=INPUT)
+ assert isinstance(res.results.vectors, list)
+ assert len(res.results.vectors) == 1
+
+
+def test_run_embeddings():
+ model = BOOTSTRAPPED_MODEL
+ res = model.run_embeddings(texts=[INPUT])
+ assert isinstance(res.results.vectors, list)
+ _assert_is_expected_embeddings_results(res.results)
+
+
+@pytest.mark.parametrize(
+ "query,docs,top_n",
+ [
+ (["test list"], DOCS, None),
+ (None, DOCS, 1234),
+ (False, DOCS, 1234),
+ (QUERY, {"testdict": "not list"}, 1234),
+ (QUERY, DOCS, "topN string is not an integer or None"),
+ ],
+)
+def test_run_rerank_query_type_error(query, docs, top_n):
+ """test for type checks matching task/run signature"""
+ with pytest.raises(TypeError):
+ BOOTSTRAPPED_MODEL.run_rerank_query(query=query, documents=docs, top_n=top_n)
+ pytest.fail("Should not reach here.")
+
+
+def test_run_rerank_query_no_type_error():
+ """no type error with list of string queries and list of dict documents"""
+ BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS, top_n=1)
+
+
+@pytest.mark.parametrize(
+ "top_n, expected",
+ [
+ (1, 1),
+ (2, 2),
+ (None, len(DOCS)),
+ (-1, len(DOCS)),
+ (0, len(DOCS)),
+ (9999, len(DOCS)),
+ ],
+)
+def test_run_rerank_query_top_n(top_n, expected):
+ res = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n)
+ assert isinstance(res, RerankResult)
+ assert len(res.result.scores) == expected
+
+
+def test_run_rerank_query_no_query():
+ with pytest.raises(TypeError):
+ BOOTSTRAPPED_MODEL.run_rerank_query(query=None, documents=DOCS, top_n=99)
+
+
+def test_run_rerank_query_zero_docs():
+ """No empty doc list therefore result is zero result scores"""
+ with pytest.raises(ValueError):
+ BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=[], top_n=99)
+
+
+def test_run_rerank_query():
+ res = BOOTSTRAPPED_MODEL.run_rerank_query(query=QUERY, documents=DOCS)
+ assert isinstance(res, RerankResult)
+
+ scores = res.result.scores
+ assert isinstance(scores, list)
+ assert len(scores) == len(DOCS)
+
+ types_found = _assert_valid_scores(scores)
+ _assert_types_found(types_found)
+
+
+@pytest.mark.parametrize(
+ "queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})]
+)
+def test_run_rerank_queries_type_error(queries, docs):
+ """type error check ensures params are lists and not just 1 string or just one doc (for example)"""
+ with pytest.raises(TypeError):
+ BOOTSTRAPPED_MODEL.run_rerank_queries(queries=queries, documents=docs)
+ pytest.fail("Should not reach here.")
+
+
+def test_run_rerank_queries_no_type_error():
+ """no type error with list of string queries and list of dict documents"""
+ BOOTSTRAPPED_MODEL.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=99)
+
+
+@pytest.mark.parametrize(
+ "top_n, expected",
+ [
+ (1, 1),
+ (2, 2),
+ (None, len(DOCS)),
+ (-1, len(DOCS)),
+ (0, len(DOCS)),
+ (9999, len(DOCS)),
+ ],
+)
+def test_run_rerank_queries_top_n(top_n, expected):
+ """no type error with list of string queries and list of dict documents"""
+ res = BOOTSTRAPPED_MODEL.run_rerank_queries(
+ queries=QUERIES, documents=DOCS, top_n=top_n
+ )
+ assert isinstance(res, RerankResults)
+ assert len(res.results) == len(QUERIES)
+ for result in res.results:
+ assert len(result.scores) == expected
+
+
+@pytest.mark.parametrize(
+ "queries, docs",
+ [
+ ([], DOCS),
+ (QUERIES, []),
+ ([], []),
+ ],
+ ids=["no queries", "no docs", "no queries and no docs"],
+)
+def test_run_rerank_queries_no_queries_or_no_docs(queries, docs):
+ """No queries and/or no docs therefore result is zero results"""
+
+ with pytest.raises(ValueError):
+ BOOTSTRAPPED_MODEL.run_rerank_queries(queries=queries, documents=docs, top_n=9)
+
+
+def test_run_rerank_queries():
+ top_n = 2
+ rerank_result = BOOTSTRAPPED_MODEL.run_rerank_queries(
+ queries=QUERIES, documents=DOCS, top_n=top_n
+ )
+ assert isinstance(rerank_result, RerankResults)
+
+ results = rerank_result.results
+ assert isinstance(results, list)
+    assert len(results) == 2 == len(QUERIES)  # 2 queries yield 2 results
+
+ types_found = {} # Gather the type tests from any of the results
+
+ for result in results:
+ assert isinstance(result, RerankScores)
+ scores = result.scores
+ assert isinstance(scores, list)
+ assert len(scores) == top_n
+ types_found = _assert_valid_scores(scores, types_found)
+
+ # Make sure our document fields of different types made it in/out ok
+ _assert_types_found(types_found)
+
+
+def test_run_sentence_similarity():
+ model = BOOTSTRAPPED_MODEL
+ res = model.run_sentence_similarity(source_sentence=QUERY, sentences=SENTENCES)
+ scores = res.result.scores
+ assert len(scores) == len(SENTENCES)
+ for score in scores:
+ assert isinstance(score, float)
+
+
+def test_run_sentence_similarities():
+ model = BOOTSTRAPPED_MODEL
+ res = model.run_sentence_similarities(source_sentences=QUERIES, sentences=SENTENCES)
+ results = res.results
+ assert len(results) == len(QUERIES)
+ for result in results:
+ scores = result.scores
+ assert len(scores) == len(SENTENCES)
+ for score in scores:
+ assert isinstance(score, float)
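
Outside the test fixtures, the sentence-similarity flow exercised above looks like this (a hedged sketch; `my-model` is a hypothetical model path):

```python
from caikit_nlp.modules.text_embedding import EmbeddingModule

model = EmbeddingModule.bootstrap(model_name_or_path="my-model")

sentences = ["foo", "bar", "foo and bar"]
res = model.run_sentence_similarity(source_sentence="What is foo bar?", sentences=sentences)

# One cosine-similarity score per input sentence, in sentence order
assert len(res.result.scores) == len(sentences)
```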
diff --git a/tox.ini b/tox.ini
index ed361a28..d336b2a5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -15,7 +15,9 @@ passenv =
LOG_FORMATTER
LOG_THREAD_ID
LOG_CHANNEL_WIDTH
-commands = pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests}
+commands =
+ python -I -m pip install --force-reinstall --no-deps -r sentence-transformers.nodeps.txt
+ pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests}
; Unclear: We probably want to test wheel packaging
; But! tox will fail when this is set and _any_ interpreter is missing