From 63efa1d44a341a9abbadd353eb43ff9ecdb689b6 Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Sat, 3 Feb 2024 14:42:33 -0600 Subject: [PATCH] oops --- src/marvin/tools/chroma.py | 266 ++++++++++++++++++++----------------- 1 file changed, 144 insertions(+), 122 deletions(-) diff --git a/src/marvin/tools/chroma.py b/src/marvin/tools/chroma.py index 1e15cb29a..27617449d 100644 --- a/src/marvin/tools/chroma.py +++ b/src/marvin/tools/chroma.py @@ -1,136 +1,158 @@ -import re -from typing import Any, Iterable, Literal, Optional +import asyncio +import os +import uuid +from typing import TYPE_CHECKING, Any, Optional try: - from chromadb.api.models.Collection import Collection - from chromadb.api.types import Include, QueryResult + from chromadb import Documents, EmbeddingFunction, Embeddings, GetResult, HttpClient except ImportError: raise ImportError( - "You must have `chromadb` installed to use the Chroma vector store. " - "Install it with `pip install 'raggy[chroma]'`." + "The chromadb package is required to query Chroma. Please install" + " it with `pip install chromadb` or `pip install marvin[chroma]`." ) -import raggy -from pydantic import BaseModel, Field, model_validator -from raggy.documents import Document -from raggy.utils import get_distinct_documents - -from marvin.tools.chroma import OpenAIEmbeddingFunction, get_client -from marvin.utilities.asyncio import run_async - - -class Chroma(BaseModel): - """A wrapper for chromadb.Client - used as an async context manager""" - - client_type: Literal["base", "http"] = "base" - embedding_fn: Any = Field(default_factory=OpenAIEmbeddingFunction) - collection: Optional[Collection] = None - - _in_context: bool = False - - @model_validator(mode="after") - def validate_collection(self): - if not self.collection: - self.collection = get_client(self.client_type).get_or_create_collection( - name="raggy", embedding_function=self.embedding_fn - ) - return self - - async def delete( - self, - ids: list[str] = None, - where: dict = None, - where_document: Document = None, - ): - await run_async( - self.collection.delete, - ids=ids, - where=where, - where_document=where_document, - ) - async def add(self, documents: list[Document]) -> Iterable[Document]: - documents = get_distinct_documents(documents) - kwargs = dict( - ids=[document.id for document in documents], - documents=[document.text for document in documents], - metadatas=[ - document.metadata.model_dump(exclude_none=True) or None - for document in documents - ], - embeddings=[document.embedding or [] for document in documents], - ) - await run_async(self.collection.add, **kwargs) - - get_result = await run_async(self.collection.get, ids=kwargs["ids"]) - - return get_result.get("documents") - - async def query( - self, - query_embeddings: list[list[float]] = None, - query_texts: list[str] = None, - n_results: int = 10, - where: dict = None, - where_document: dict = None, - include: "Include" = ["metadatas"], - **kwargs, - ) -> "QueryResult": - return await run_async( - self.collection.query, - query_embeddings=query_embeddings, - query_texts=query_texts, - n_results=n_results, - where=where, - where_document=where_document, - include=include, - **kwargs, - ) +from typing import Literal + +import marvin - async def count(self) -> int: - return await run_async(self.collection.count) - - async def upsert(self, documents: list[Document]): - documents = get_distinct_documents(documents) - kwargs = dict( - ids=[document.id for document in documents], - documents=[document.text for document in documents], - metadatas=[ - document.metadata.model_dump(exclude_none=True) or None - for document in documents - ], - embeddings=[document.embedding or [] for document in documents], +if TYPE_CHECKING: + from openai.types import CreateEmbeddingResponse + +QueryResultType = Literal["documents", "distances", "metadatas"] + +try: + HOST, PORT = ( + getattr(marvin.settings, "chroma_server_host"), + getattr(marvin.settings, "chroma_server_http_port"), + ) + DEFAULT_COLLECTION_NAME = getattr( + marvin.settings, "chroma_default_collection_name", "marvin" + ) +except AttributeError: + HOST = os.environ.get("MARVIN_CHROMA_SERVER_HOST", "localhost") # type: ignore + PORT = os.environ.get("MARVIN_CHROMA_SERVER_HTTP_PORT", 8000) # type: ignore + DEFAULT_COLLECTION_NAME = os.environ.get( + "MARVIN_CHROMA_DEFAULT_COLLECTION_NAME", "marvin" + ) + + +def create_openai_embeddings(texts: list[str]) -> list[float]: + """Create OpenAI embeddings for a list of texts.""" + + try: + import numpy # noqa F401 # type: ignore + except ImportError: + raise ImportError( + "The numpy package is required to create OpenAI embeddings. Please install" + " it with `pip install numpy`." ) - await run_async(self.collection.upsert, **kwargs) + from marvin.client.openai import MarvinClient + + embedding: "CreateEmbeddingResponse" = MarvinClient().client.embeddings.create( + input=[text.replace("\n", " ") for text in texts], + model="text-embedding-ada-002", + ) - get_result = await run_async(self.collection.get, ids=kwargs["ids"]) + return embedding.data[0].embedding - return get_result.get("documents") - async def reset_collection(self): - """Delete and recreate the collection.""" - client = get_client(self.client_type) - await run_async(client.delete_collection, self.collection.name) - self.collection = await run_async( - client.create_collection, - name=self.collection.name, - embedding_function=self.embedding_fn, +class OpenAIEmbeddingFunction(EmbeddingFunction): + def __call__(self, input: Documents) -> Embeddings: + return [create_openai_embeddings(input)] + + +def get_http_client() -> HttpClient: + """Get a Chroma HTTP client.""" + return HttpClient(host=HOST, port=PORT) + + +async def query_chroma( + query: str, + collection: str = "marvin", + n_results: int = 5, + where: Optional[dict[str, Any]] = None, + where_document: Optional[dict[str, Any]] = None, + include: Optional[list[QueryResultType]] = None, + max_characters: int = 2000, +) -> str: + """Query a collection of document excerpts for a query. + + Example: + User: "What are prefect blocks?" + Assistant: >>> query_chroma("What are prefect blocks?") + """ + collection_object = get_http_client().get_or_create_collection( + name=collection or DEFAULT_COLLECTION_NAME, + embedding_function=OpenAIEmbeddingFunction(), + ) + query_result = collection_object.query( + query_texts=[query], + n_results=n_results, + where=where, + where_document=where_document, + include=include or ["documents"], + ) + return "".join(doc for doclist in query_result["documents"] for doc in doclist)[ + :max_characters + ] + + +async def multi_query_chroma( + queries: list[str], + collection: str = "marvin", + n_results: int = 5, + where: Optional[dict[str, Any]] = None, + where_document: Optional[dict[str, Any]] = None, + include: Optional[list[QueryResultType]] = None, + max_characters: int = 2000, +) -> str: + """Retrieve excerpts to aid in answering multifacted questions. + + Example: + User: "What are prefect blocks and tasks?" + Assistant: >>> multi_query_chroma( + ["What are prefect blocks?", "What are prefect tasks?"] + ) + multi_query_chroma -> document excerpts explaining both blocks and tasks + """ + + coros = [ + query_chroma( + query, + collection, + n_results, + where, + where_document, + include, + max_characters // len(queries), ) + for query in queries + ] + return "\n".join(await asyncio.gather(*coros))[:max_characters] + + +def store_document( + document: str, metadata: dict[str, Any], collection_name: str = "glacial" +) -> GetResult: + """Store a document in Chroma for future reference. + + Args: + document: The document to store. + metadata: The metadata to store with the document. + + Returns: + The stored document. + """ + collection = get_http_client().get_or_create_collection( + name=collection_name, embedding_function=OpenAIEmbeddingFunction() + ) + doc_id = metadata.get("msg_id", str(uuid.uuid4())) + + collection.add( + ids=[doc_id], + documents=[document], + metadatas=[metadata], + ) - def ok(self) -> bool: - logger = raggy.utilities.logging.get_logger() - try: - version = self.client.get_version() - except Exception as e: - logger.error(f"Cannot connect to Chroma: {e}") - if re.match(r"^\d+\.\d+\.\d+$", version): - logger.debug(f"Connected to Chroma v{version}") - return True - return False - - async def __aenter__(self): - self._in_context = True - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - self._in_context = False + return collection.get(ids=doc_id)