Commit: iter
glemaitre committed Dec 10, 2023
1 parent 36b8ef9 commit cbb5c91
Showing 6 changed files with 199 additions and 21 deletions.
2 changes: 1 addition & 1 deletion rag_based_llm/prompt/_agent.py
@@ -46,7 +46,7 @@ def __call__(self, query, **prompt_kwargs):
"""
max_tokens = prompt_kwargs.get("max_tokens", 1024)

api_context = self.api_semantic_retriever.k_neighbors(query)[0]
api_context = self.api_semantic_retriever.k_neighbors(query)
context = "\n".join(
f"source: {api['source']} \n content: {api['text']}\n"
for api in api_context
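For context, a sketch of the call site after this hunk: k_neighbors now returns a flat list of results for a single string query, so the trailing [0] is dropped. Here retriever stands in for the fitted api_semantic_retriever and the query string is invented:

    api_context = retriever.k_neighbors("how do I pass max_tokens?")
    # api_context is now a flat list of {"source", "text"} dicts rather than a
    # one-element list of lists, matching the join below.
    context = "\n".join(
        f"source: {api['source']} \n content: {api['text']}\n"
        for api in api_context
    )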
3 changes: 2 additions & 1 deletion rag_based_llm/retrieval/__init__.py
@@ -1,3 +1,4 @@
+from ._lexical import BM25Retriever
 from ._semantic import SemanticRetriever

-__all__ = ["SemanticRetriever"]
+__all__ = ["BM25Retriever", "SemanticRetriever"]
127 changes: 127 additions & 0 deletions rag_based_llm/retrieval/_lexical.py
@@ -0,0 +1,127 @@
from numbers import Integral, Real

import numpy as np
from sklearn.base import BaseEstimator, clone
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.utils._param_validation import HasMethods, Interval
from sklearn.utils.validation import check_is_fitted


class BM25Retriever(BaseEstimator):
    """Retrieve the k-nearest neighbors using a lexical search based on BM25.

    Parameters
    ----------
    count_vectorizer : transformer, default=None
        A count vectorizer to compute the count of terms in documents. If None, a
        :class:`sklearn.feature_extraction.text.CountVectorizer` is used.

    n_neighbors : int, default=1
        Number of neighbors to retrieve.

    b : float, default=0.75
        Strength of the document-length normalization in the BM25 scoring.

    k1 : float, default=1.6
        Term-frequency saturation parameter of the BM25 scoring.

    Attributes
    ----------
    X_fit_ : list of str or dict
        The input data.

    count_vectorizer_ : transformer
        The fitted count vectorizer.

    X_counts_ : sparse matrix of shape (n_documents, n_terms)
        The term counts of the fitted documents.

    idf_ : ndarray of shape (n_terms,)
        The inverse document frequency of each term in the vocabulary.
    """

    _parameter_constraints = {
        "count_vectorizer": [HasMethods(["fit_transform", "transform"]), None],
        "n_neighbors": [Interval(Integral, left=1, right=None, closed="left")],
        "b": [Interval(Real, left=0, right=1, closed="both")],
        "k1": [Interval(Real, left=0, right=None, closed="left")],
    }

    def __init__(self, *, count_vectorizer=None, n_neighbors=1, b=0.75, k1=1.6):
        self.count_vectorizer = count_vectorizer
        self.n_neighbors = n_neighbors
        self.b = b
        self.k1 = k1

    def fit(self, X, y=None):
        """Compute the vocabulary and the idf.

        Parameters
        ----------
        X : list of str or dict
            The input data.

        y : None
            This parameter is ignored.

        Returns
        -------
        self
            The fitted estimator.
        """
        self._validate_params()
        self.X_fit_ = X

        if isinstance(X[0], dict):
            X = [x["text"] for x in X]

        if self.count_vectorizer is None:
            self.count_vectorizer_ = CountVectorizer().fit(X)
        else:
            self.count_vectorizer_ = clone(self.count_vectorizer).fit(X)

        self.X_counts_ = self.count_vectorizer_.transform(X)
        self.n_terms_by_document_ = self.X_counts_.sum(axis=1).A1
        self.averaged_document_length_ = self.n_terms_by_document_.mean()
        # compute the idf from the document frequency, i.e. the number of
        # documents containing each term (not the total term counts)
        n_documents = len(self.X_fit_)
        n_documents_by_term = (self.X_counts_ > 0).sum(axis=0).A1
        numerator = n_documents - n_documents_by_term + 0.5
        denominator = n_documents_by_term + 0.5
        self.idf_ = np.log(numerator / denominator + 1)
        # guard against degenerate idf values by flooring them
        self.idf_[self.idf_ < 0] = 0.25 * np.mean(self.idf_)
        return self

    def k_neighbors(self, query, *, n_neighbors=None):
        """Retrieve the k-nearest neighbors.

        Parameters
        ----------
        query : str
            The input data.

        n_neighbors : int, default=None
            The number of neighbors to retrieve. If None, the `n_neighbors` from the
            constructor is used.

        Returns
        -------
        list of str or dict
            The k-nearest neighbors from the training set.
        """
        check_is_fitted(self, "X_fit_")
        if not isinstance(query, str):
            raise TypeError(f"query should be a string, got {type(query)}.")
        n_neighbors = n_neighbors or self.n_neighbors
        query_terms_indices = self.count_vectorizer_.transform([query]).indices
        counts_query_in_X_fit = self.X_counts_[:, query_terms_indices].toarray()
        idf = self.idf_[query_terms_indices]
        numerator = counts_query_in_X_fit * (self.k1 + 1)
        denominator = counts_query_in_X_fit + self.k1 * (
            1
            - self.b
            + self.b
            * (
                self.n_terms_by_document_.reshape(-1, 1)
                / self.averaged_document_length_
            )
        )
        scores = (idf * numerator / denominator).sum(axis=1)
        indices = scores.argsort()[::-1][:n_neighbors]
        if isinstance(self.X_fit_[0], dict):
            return [
                {
                    "source": self.X_fit_[neighbor]["source"],
                    "text": self.X_fit_[neighbor]["text"],
                }
                for neighbor in indices
            ]
        else:  # isinstance(self.X_fit_[0], str)
            return [self.X_fit_[neighbor] for neighbor in indices]
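For reference, the new retriever scores documents with the Okapi BM25 formula; in the code above, counts_query_in_X_fit holds f(q_i, D), n_terms_by_document_ is |D|, averaged_document_length_ is avgdl, and n(q_i) is the number of fitted documents (out of N) containing query term q_i:

    \mathrm{score}(D, Q) = \sum_{i} \mathrm{idf}(q_i) \cdot
        \frac{f(q_i, D)\,(k_1 + 1)}{f(q_i, D) + k_1 \left(1 - b + b\,\frac{|D|}{\mathrm{avgdl}}\right)}

    \mathrm{idf}(q_i) = \ln\left(\frac{N - n(q_i) + 0.5}{n(q_i) + 0.5} + 1\right)

The +1 inside the logarithm keeps the idf positive; the floor applied afterwards mirrors the epsilon guard found in common BM25 implementations.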
31 changes: 14 additions & 17 deletions rag_based_llm/retrieval/_semantic.py
@@ -2,6 +2,7 @@

import faiss
from sklearn.base import BaseEstimator
+from sklearn.utils.validation import check_is_fitted
from sklearn.utils._param_validation import HasMethods, Interval


@@ -60,12 +61,12 @@ def fit(self, X, y=None):
        self.index_.add(self.X_embedded_)
        return self

-    def k_neighbors(self, X, n_neighbors=None):
+    def k_neighbors(self, query, *, n_neighbors=None):
        """Retrieve the k-nearest neighbors.

        Parameters
        ----------
-        X : str or list of str or dict
+        query : str
            The input data.

        n_neighbors : int, default=None
@@ -77,25 +78,21 @@ def k_neighbors(self, X, n_neighbors=None):
        list of str or dict
            The k-nearest neighbors from the training set.
        """
+        check_is_fitted(self, "X_fit_")
+        if not isinstance(query, str):
+            raise TypeError(f"query should be a string, got {type(query)}.")
        n_neighbors = n_neighbors or self.n_neighbors
-        X_embedded = self.embedding.transform(X)
+        X_embedded = self.embedding.transform(query)
        # normalize vectors to compute the cosine similarity
        faiss.normalize_L2(X_embedded)
-        xxx, indices = self.index_.search(X_embedded, n_neighbors)
-        print(xxx)
+        _, indices = self.index_.search(X_embedded, n_neighbors)
        if isinstance(self.X_fit_[0], dict):
            return [
-                [
-                    {
-                        "source": self.X_fit_[neighbor]["source"],
-                        "text": self.X_fit_[neighbor]["text"],
-                    }
-                    for neighbor in neighbors
-                ]
-                for neighbors in indices
+                {
+                    "source": self.X_fit_[neighbor]["source"],
+                    "text": self.X_fit_[neighbor]["text"],
+                }
+                for neighbor in indices[0]
            ]
        else:  # isinstance(self.X_fit_[0], str)
-            return [
-                [self.X_fit_[neighbor] for neighbor in neighbors]
-                for neighbors in indices
-            ]
+            return [self.X_fit_[neighbor] for neighbor in indices[0]]
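The move from nested result lists to indices[0] follows from the faiss API: Index.search takes a 2-D batch of query vectors and returns (distances, indices) arrays of shape (n_queries, n_neighbors), so the single query's neighbors sit in row 0. A toy sketch:

    import faiss
    import numpy as np

    index = faiss.IndexFlatIP(4)                         # inner-product index, d=4
    index.add(np.eye(4, dtype="float32"))                # four toy vectors
    _, indices = index.search(np.eye(1, 4, dtype="float32"), 2)
    print(indices.shape)  # (1, 2): one query row, two neighbor columns
    print(indices[0])     # the neighbor ids for that single query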
34 changes: 34 additions & 0 deletions rag_based_llm/retrieval/tests/test_lexical.py
@@ -0,0 +1,34 @@
import pytest
from sklearn.feature_extraction.text import CountVectorizer

from rag_based_llm.retrieval import BM25Retriever


@pytest.mark.parametrize(
    "input_texts, output",
    [
        (
            [
                {"source": "source 1", "text": "xxx"},
                {"source": "source 2", "text": "yyy"},
            ],
            [{"source": "source 1", "text": "xxx"}],
        ),
        (["xxx", "yyy"], ["xxx"]),
    ],
)
@pytest.mark.parametrize("count_vectorizer", [None, CountVectorizer()])
def test_lexical_retriever(input_texts, output, count_vectorizer):
    """Check that the BM25Retriever works as expected."""
    bm25 = BM25Retriever(count_vectorizer=count_vectorizer, n_neighbors=1).fit(
        input_texts
    )
    assert bm25.k_neighbors("xxx") == output


def test_lexical_retriever_error():
    """Check that we raise an error when the input is not a string at inference time."""
    input_texts = [{"source": "source 1", "text": "xxx"}]
    bm25 = BM25Retriever(n_neighbors=1).fit(input_texts)
    with pytest.raises(TypeError):
        bm25.k_neighbors(["xxx"])
23 changes: 21 additions & 2 deletions rag_based_llm/retrieval/tests/test_semantic.py
@@ -14,9 +14,9 @@
{"source": "source 1", "text": "xxx"},
{"source": "source 2", "text": "yyy"},
],
[[{"source": "source 1", "text": "xxx"}]],
[{"source": "source 1", "text": "xxx"}],
),
(["xxx", "yyy"], [["xxx"]]),
(["xxx", "yyy"], ["xxx"]),
],
)
def test_semantic_retriever(input_texts, output):
@@ -34,3 +34,22 @@ def test_semantic_retriever(input_texts, output):

    faiss = SemanticRetriever(embedding=embedder, n_neighbors=1).fit(input_texts)
    assert faiss.k_neighbors("xx") == output


def test_semantic_retriever_error():
    """Check that we raise an error when the input is not a string at inference time."""
    cache_folder_path = (
        Path(__file__).parent.parent.parent / "embedding" / "tests" / "data"
    )
    model_name_or_path = "sentence-transformers/paraphrase-albert-small-v2"

    embedder = SentenceTransformer(
        model_name_or_path=model_name_or_path,
        cache_folder=str(cache_folder_path),
        show_progress_bar=False,
    )

    input_texts = [{"source": "source 1", "text": "xxx"}]
    faiss = SemanticRetriever(embedding=embedder, n_neighbors=1).fit(input_texts)
    with pytest.raises(TypeError):
        faiss.k_neighbors(["xxxx"])
