-
Notifications
You must be signed in to change notification settings - Fork 350
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 1 changed file with 144 additions and 122 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,136 +1,158 @@ | ||
import re | ||
from typing import Any, Iterable, Literal, Optional | ||
import asyncio | ||
import os | ||
import uuid | ||
from typing import TYPE_CHECKING, Any, Optional | ||
|
||
try: | ||
from chromadb.api.models.Collection import Collection | ||
from chromadb.api.types import Include, QueryResult | ||
from chromadb import Documents, EmbeddingFunction, Embeddings, GetResult, HttpClient | ||
except ImportError: | ||
raise ImportError( | ||
"You must have `chromadb` installed to use the Chroma vector store. " | ||
"Install it with `pip install 'raggy[chroma]'`." | ||
"The chromadb package is required to query Chroma. Please install" | ||
" it with `pip install chromadb` or `pip install marvin[chroma]`." | ||
) | ||
import raggy | ||
from pydantic import BaseModel, Field, model_validator | ||
from raggy.documents import Document | ||
from raggy.utils import get_distinct_documents | ||
|
||
from marvin.tools.chroma import OpenAIEmbeddingFunction, get_client | ||
from marvin.utilities.asyncio import run_async | ||
|
||
|
||
class Chroma(BaseModel):
    """A wrapper for chromadb.Client - used as an async context manager."""

    # Forwarded to get_client(); semantics of "base" vs "http" are defined there.
    client_type: Literal["base", "http"] = "base"
    # Embedding callable handed to the chroma collection.
    embedding_fn: Any = Field(default_factory=OpenAIEmbeddingFunction)
    # Lazily created by validate_collection when not supplied by the caller.
    collection: Optional[Collection] = None

    # Set/cleared by __aenter__/__aexit__ while inside `async with`.
    _in_context: bool = False
|
||
@model_validator(mode="after") | ||
def validate_collection(self): | ||
if not self.collection: | ||
self.collection = get_client(self.client_type).get_or_create_collection( | ||
name="raggy", embedding_function=self.embedding_fn | ||
) | ||
return self | ||
|
||
async def delete( | ||
self, | ||
ids: list[str] = None, | ||
where: dict = None, | ||
where_document: Document = None, | ||
): | ||
await run_async( | ||
self.collection.delete, | ||
ids=ids, | ||
where=where, | ||
where_document=where_document, | ||
) | ||
|
||
async def add(self, documents: list[Document]) -> Iterable[Document]: | ||
documents = get_distinct_documents(documents) | ||
kwargs = dict( | ||
ids=[document.id for document in documents], | ||
documents=[document.text for document in documents], | ||
metadatas=[ | ||
document.metadata.model_dump(exclude_none=True) or None | ||
for document in documents | ||
], | ||
embeddings=[document.embedding or [] for document in documents], | ||
) | ||
|
||
await run_async(self.collection.add, **kwargs) | ||
|
||
get_result = await run_async(self.collection.get, ids=kwargs["ids"]) | ||
|
||
return get_result.get("documents") | ||
|
||
async def query( | ||
self, | ||
query_embeddings: list[list[float]] = None, | ||
query_texts: list[str] = None, | ||
n_results: int = 10, | ||
where: dict = None, | ||
where_document: dict = None, | ||
include: "Include" = ["metadatas"], | ||
**kwargs, | ||
) -> "QueryResult": | ||
return await run_async( | ||
self.collection.query, | ||
query_embeddings=query_embeddings, | ||
query_texts=query_texts, | ||
n_results=n_results, | ||
where=where, | ||
where_document=where_document, | ||
include=include, | ||
**kwargs, | ||
) | ||
from typing import Literal | ||
|
||
import marvin | ||
|
||
async def count(self) -> int: | ||
return await run_async(self.collection.count) | ||
|
||
async def upsert(self, documents: list[Document]): | ||
documents = get_distinct_documents(documents) | ||
kwargs = dict( | ||
ids=[document.id for document in documents], | ||
documents=[document.text for document in documents], | ||
metadatas=[ | ||
document.metadata.model_dump(exclude_none=True) or None | ||
for document in documents | ||
], | ||
embeddings=[document.embedding or [] for document in documents], | ||
if TYPE_CHECKING:
    from openai.types import CreateEmbeddingResponse

# Result fields the query helpers below allow in `include`.
QueryResultType = Literal["documents", "distances", "metadatas"]

try:
    # Prefer explicit marvin settings when they exist; getattr without a
    # default raises AttributeError, which triggers the env-var fallback.
    HOST, PORT = (
        getattr(marvin.settings, "chroma_server_host"),
        getattr(marvin.settings, "chroma_server_http_port"),
    )
    DEFAULT_COLLECTION_NAME = getattr(
        marvin.settings, "chroma_default_collection_name", "marvin"
    )
except AttributeError:
    HOST = os.environ.get("MARVIN_CHROMA_SERVER_HOST", "localhost")  # type: ignore
    # Normalize to int: os.environ.get returns a str when the variable is set,
    # while the previous default was the int 8000 — keep the type consistent.
    PORT = int(os.environ.get("MARVIN_CHROMA_SERVER_HTTP_PORT", 8000))  # type: ignore
    DEFAULT_COLLECTION_NAME = os.environ.get(
        "MARVIN_CHROMA_DEFAULT_COLLECTION_NAME", "marvin"
    )
|
||
|
||
def create_openai_embeddings(texts: list[str]) -> list[float]:
    """Create OpenAI embeddings for a list of texts.

    Newlines in each text are replaced with spaces before embedding.

    NOTE(review): only the embedding of the FIRST input text is returned
    (`embedding.data[0].embedding`), even when several texts are passed —
    confirm callers expect this, since the signature suggests per-list input.

    Raises:
        ImportError: if numpy is not installed.
    """
    try:
        import numpy  # noqa F401 # type: ignore
    except ImportError:
        raise ImportError(
            "The numpy package is required to create OpenAI embeddings. Please install"
            " it with `pip install numpy`."
        )
    # Imported lazily to avoid a hard dependency at module import time.
    from marvin.client.openai import MarvinClient

    embedding: "CreateEmbeddingResponse" = MarvinClient().client.embeddings.create(
        input=[text.replace("\n", " ") for text in texts],
        model="text-embedding-ada-002",
    )

    return embedding.data[0].embedding
|
||
return get_result.get("documents") | ||
|
||
async def reset_collection(self): | ||
"""Delete and recreate the collection.""" | ||
client = get_client(self.client_type) | ||
await run_async(client.delete_collection, self.collection.name) | ||
self.collection = await run_async( | ||
client.create_collection, | ||
name=self.collection.name, | ||
embedding_function=self.embedding_fn, | ||
class OpenAIEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by create_openai_embeddings."""

    def __call__(self, input: Documents) -> Embeddings:
        # NOTE(review): the whole `input` batch is folded into a single
        # embedding (create_openai_embeddings returns data[0].embedding) —
        # verify multi-document behavior against callers.
        result = [create_openai_embeddings(input)]
        return result
|
||
|
||
def get_http_client() -> HttpClient:
    """Get a Chroma HTTP client for the configured HOST and PORT."""
    client = HttpClient(host=HOST, port=PORT)
    return client
|
||
|
||
async def query_chroma(
    query: str,
    collection: str = "marvin",
    n_results: int = 5,
    where: Optional[dict[str, Any]] = None,
    where_document: Optional[dict[str, Any]] = None,
    include: Optional[list[QueryResultType]] = None,
    max_characters: int = 2000,
) -> str:
    """Query a collection of document excerpts for a query.
    Example:
        User: "What are prefect blocks?"
        Assistant: >>> query_chroma("What are prefect blocks?")
    """
    # Falls back to DEFAULT_COLLECTION_NAME only if `collection` is empty.
    chroma_collection = get_http_client().get_or_create_collection(
        name=collection or DEFAULT_COLLECTION_NAME,
        embedding_function=OpenAIEmbeddingFunction(),
    )
    result = chroma_collection.query(
        query_texts=[query],
        n_results=n_results,
        where=where,
        where_document=where_document,
        include=include or ["documents"],
    )
    # Flatten the per-query document lists, then trim to the character budget.
    excerpts = (doc for doclist in result["documents"] for doc in doclist)
    return "".join(excerpts)[:max_characters]
|
||
|
||
async def multi_query_chroma(
    queries: list[str],
    collection: str = "marvin",
    n_results: int = 5,
    where: Optional[dict[str, Any]] = None,
    where_document: Optional[dict[str, Any]] = None,
    include: Optional[list["QueryResultType"]] = None,
    max_characters: int = 2000,
) -> str:
    """Retrieve excerpts to aid in answering multifaceted questions.
    Example:
        User: "What are prefect blocks and tasks?"
        Assistant: >>> multi_query_chroma(
            ["What are prefect blocks?", "What are prefect tasks?"]
        )
        multi_query_chroma -> document excerpts explaining both blocks and tasks
    """
    # Guard: `max_characters // len(queries)` below would raise
    # ZeroDivisionError on an empty query list.
    if not queries:
        return ""

    # Give each sub-query an equal share of the overall character budget.
    per_query_budget = max_characters // len(queries)
    coros = [
        query_chroma(
            query,
            collection,
            n_results,
            where,
            where_document,
            include,
            per_query_budget,
        )
        for query in queries
    ]
    # Run the sub-queries concurrently; re-trim in case of join overhead.
    return "\n".join(await asyncio.gather(*coros))[:max_characters]
|
||
|
||
def store_document(
    document: str, metadata: dict[str, Any], collection_name: str = "glacial"
) -> GetResult:
    """Store a document in Chroma for future reference.
    Args:
        document: The document to store.
        metadata: The metadata to store with the document. If it contains a
            `msg_id` key, that value is used as the record id; otherwise a
            random UUID is generated.
        collection_name: Name of the collection to store the document in.
    Returns:
        The stored document, as returned by `Collection.get`.
    """
    collection = get_http_client().get_or_create_collection(
        name=collection_name, embedding_function=OpenAIEmbeddingFunction()
    )
    doc_id = metadata.get("msg_id", str(uuid.uuid4()))

    collection.add(
        ids=[doc_id],
        documents=[document],
        metadatas=[metadata],
    )

    # Read the record back so the caller receives what was stored; without
    # this the function fell through and returned None despite its
    # `-> GetResult` annotation.
    return collection.get(ids=doc_id)
|
||
def ok(self) -> bool: | ||
logger = raggy.utilities.logging.get_logger() | ||
try: | ||
version = self.client.get_version() | ||
except Exception as e: | ||
logger.error(f"Cannot connect to Chroma: {e}") | ||
if re.match(r"^\d+\.\d+\.\d+$", version): | ||
logger.debug(f"Connected to Chroma v{version}") | ||
return True | ||
return False | ||
|
||
async def __aenter__(self): | ||
self._in_context = True | ||
return self | ||
|
||
async def __aexit__(self, exc_type, exc_value, traceback): | ||
self._in_context = False | ||
return collection.get(ids=doc_id) |