Skip to content

Commit

Permalink
oops
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Feb 3, 2024
1 parent ef5c917 commit 63efa1d
Showing 1 changed file with 144 additions and 122 deletions.
266 changes: 144 additions & 122 deletions src/marvin/tools/chroma.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,158 @@
import re
from typing import Any, Iterable, Literal, Optional
import asyncio
import os
import uuid
from typing import TYPE_CHECKING, Any, Optional

try:
from chromadb.api.models.Collection import Collection
from chromadb.api.types import Include, QueryResult
from chromadb import Documents, EmbeddingFunction, Embeddings, GetResult, HttpClient
except ImportError:
raise ImportError(
"You must have `chromadb` installed to use the Chroma vector store. "
"Install it with `pip install 'raggy[chroma]'`."
"The chromadb package is required to query Chroma. Please install"
" it with `pip install chromadb` or `pip install marvin[chroma]`."
)
import raggy
from pydantic import BaseModel, Field, model_validator
from raggy.documents import Document
from raggy.utils import get_distinct_documents

from marvin.tools.chroma import OpenAIEmbeddingFunction, get_client
from marvin.utilities.asyncio import run_async


class Chroma(BaseModel):
"""A wrapper for chromadb.Client - used as an async context manager"""

client_type: Literal["base", "http"] = "base"
embedding_fn: Any = Field(default_factory=OpenAIEmbeddingFunction)
collection: Optional[Collection] = None

_in_context: bool = False

@model_validator(mode="after")
def validate_collection(self):
    """Make sure the model holds a collection once validation has run.

    When ``collection`` is falsy, a client is obtained for ``client_type``
    and the collection named ``"raggy"`` is fetched or created with the
    configured ``embedding_fn``. A collection that is already set is left
    untouched.

    Returns:
        The model instance itself.
    """
    if self.collection:
        return self

    client = get_client(self.client_type)
    self.collection = client.get_or_create_collection(
        name="raggy",
        embedding_function=self.embedding_fn,
    )
    return self

async def delete(
    self,
    ids: Optional[list[str]] = None,
    where: Optional[dict] = None,
    where_document: Optional[dict] = None,
):
    """Delete entries from the wrapped collection.

    All three arguments are forwarded unchanged to ``self.collection.delete``
    through ``run_async``.

    Args:
        ids: Identifiers of the entries to delete, if any.
        where: Metadata filter passed through to the collection.
        where_document: Document-content filter passed through to the
            collection. It is a filter mapping (the same kind ``query``
            accepts), not a ``Document`` instance.
    """
    await run_async(
        self.collection.delete,
        ids=ids,
        where=where,
        where_document=where_document,
    )

async def add(self, documents: list[Document]) -> Iterable[Document]:
    """Add the distinct documents to the collection and read them back.

    Args:
        documents: Documents to store; they are first passed through
            ``get_distinct_documents``.

    Returns:
        The ``"documents"`` entry of the collection's ``get`` result for the
        ids that were just added.
    """
    unique_docs = get_distinct_documents(documents)

    doc_ids = [doc.id for doc in unique_docs]
    texts = [doc.text for doc in unique_docs]
    # An empty metadata dump is sent as ``None`` rather than ``{}``.
    metadatas = [
        doc.metadata.model_dump(exclude_none=True) or None for doc in unique_docs
    ]
    # A missing embedding is sent as an empty list.
    embeddings = [doc.embedding or [] for doc in unique_docs]

    await run_async(
        self.collection.add,
        ids=doc_ids,
        documents=texts,
        metadatas=metadatas,
        embeddings=embeddings,
    )

    stored = await run_async(self.collection.get, ids=doc_ids)
    return stored.get("documents")

async def query(
    self,
    query_embeddings: Optional[list[list[float]]] = None,
    query_texts: Optional[list[str]] = None,
    n_results: int = 10,
    where: Optional[dict] = None,
    where_document: Optional[dict] = None,
    include: Optional["Include"] = None,
    **kwargs,
) -> "QueryResult":
    """Run a query against the wrapped collection.

    Every argument is forwarded to ``self.collection.query`` through
    ``run_async``.

    Args:
        query_embeddings: Embeddings to search with.
        query_texts: Texts to search with.
        n_results: Number of results requested per query.
        where: Metadata filter.
        where_document: Document-content filter.
        include: Result fields to return; defaults to ``["metadatas"]``
            when omitted or ``None``.
        **kwargs: Extra keyword arguments passed through unchanged.

    Returns:
        Whatever ``self.collection.query`` returns.
    """
    # Build the default per call instead of sharing one mutable list object
    # across every invocation of this method.
    if include is None:
        include = ["metadatas"]

    return await run_async(
        self.collection.query,
        query_embeddings=query_embeddings,
        query_texts=query_texts,
        n_results=n_results,
        where=where,
        where_document=where_document,
        include=include,
        **kwargs,
    )
from typing import Literal

import marvin

async def count(self) -> int:
    """Return the result of the wrapped collection's ``count`` call."""
    total = await run_async(self.collection.count)
    return total

async def upsert(self, documents: list[Document]):
documents = get_distinct_documents(documents)
kwargs = dict(
ids=[document.id for document in documents],
documents=[document.text for document in documents],
metadatas=[
document.metadata.model_dump(exclude_none=True) or None
for document in documents
],
embeddings=[document.embedding or [] for document in documents],
if TYPE_CHECKING:
from openai.types import CreateEmbeddingResponse

QueryResultType = Literal["documents", "distances", "metadatas"]

# Resolve the Chroma connection details. ``marvin.settings`` is consulted
# first; if a required attribute is missing there (AttributeError), all three
# values are taken from environment variables with hard-coded fallbacks.
try:
    HOST = getattr(marvin.settings, "chroma_server_host")
    PORT = getattr(marvin.settings, "chroma_server_http_port")
    DEFAULT_COLLECTION_NAME = getattr(
        marvin.settings, "chroma_default_collection_name", "marvin"
    )
except AttributeError:
    HOST = os.getenv("MARVIN_CHROMA_SERVER_HOST", "localhost")  # type: ignore
    # NOTE: this is a ``str`` when the variable is set and the int 8000 otherwise.
    PORT = os.getenv("MARVIN_CHROMA_SERVER_HTTP_PORT", 8000)  # type: ignore
    DEFAULT_COLLECTION_NAME = os.getenv(
        "MARVIN_CHROMA_DEFAULT_COLLECTION_NAME", "marvin"
    )


def create_openai_embeddings(texts: list[str]) -> list[float]:
"""Create OpenAI embeddings for a list of texts."""

try:
import numpy # noqa F401 # type: ignore
except ImportError:
raise ImportError(
"The numpy package is required to create OpenAI embeddings. Please install"
" it with `pip install numpy`."
)
await run_async(self.collection.upsert, **kwargs)
from marvin.client.openai import MarvinClient

embedding: "CreateEmbeddingResponse" = MarvinClient().client.embeddings.create(
input=[text.replace("\n", " ") for text in texts],
model="text-embedding-ada-002",
)

get_result = await run_async(self.collection.get, ids=kwargs["ids"])
return embedding.data[0].embedding

return get_result.get("documents")

async def reset_collection(self):
"""Delete and recreate the collection."""
client = get_client(self.client_type)
await run_async(client.delete_collection, self.collection.name)
self.collection = await run_async(
client.create_collection,
name=self.collection.name,
embedding_function=self.embedding_fn,
class OpenAIEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by ``create_openai_embeddings``."""

    def __call__(self, input: Documents) -> Embeddings:
        """Embed every document in ``input``.

        Args:
            input: The documents to embed. The parameter name is kept as-is
                because it is part of the callable's interface.

        Returns:
            One embedding per document, in the same order as ``input``.
        """
        # ``create_openai_embeddings`` returns a single vector (the first
        # entry of the API response), so wrapping one call in a list yields
        # exactly one embedding no matter how many documents were passed.
        # Embedding each document on its own keeps the output aligned with
        # the input; for a single document the result is unchanged.
        return [create_openai_embeddings([document]) for document in input]


def get_http_client() -> HttpClient:
    """Build a Chroma ``HttpClient`` from the module-level ``HOST`` and ``PORT``."""
    client = HttpClient(host=HOST, port=PORT)
    return client


async def query_chroma(
    query: str,
    collection: str = "marvin",
    n_results: int = 5,
    where: Optional[dict[str, Any]] = None,
    where_document: Optional[dict[str, Any]] = None,
    include: Optional[list[QueryResultType]] = None,
    max_characters: int = 2000,
) -> str:
    """Query a collection of document excerpts for a query.

    Example:
        User: "What are prefect blocks?"
        Assistant: >>> query_chroma("What are prefect blocks?")

    Args:
        query: The text to search for.
        collection: Name of the collection to query; an empty value falls
            back to ``DEFAULT_COLLECTION_NAME``.
        n_results: Number of results to request.
        where: Optional metadata filter.
        where_document: Optional document-content filter.
        include: Result fields to request; defaults to ``["documents"]``.
        max_characters: Maximum length of the returned string.

    Returns:
        The returned document excerpts concatenated and truncated to
        ``max_characters``; an empty string when the result carries no
        documents.
    """

    def _run_query() -> str:
        # Everything in here is synchronous, so it is grouped into one
        # callable that can be handed to a worker thread.
        collection_object = get_http_client().get_or_create_collection(
            name=collection or DEFAULT_COLLECTION_NAME,
            embedding_function=OpenAIEmbeddingFunction(),
        )
        query_result = collection_object.query(
            query_texts=[query],
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=include or ["documents"],
        )
        # "documents" can be absent/None when ``include`` does not ask for
        # it; treat that as "no excerpts" instead of failing on iteration.
        doclists = query_result.get("documents") or []
        return "".join(doc for doclist in doclists for doc in doclist)[
            :max_characters
        ]

    # Running the blocking client calls in a thread keeps the event loop free,
    # so several ``query_chroma`` coroutines awaited together (for example via
    # ``asyncio.gather``) can actually overlap.
    return await asyncio.to_thread(_run_query)


async def multi_query_chroma(
    queries: list[str],
    collection: str = "marvin",
    n_results: int = 5,
    where: Optional[dict[str, Any]] = None,
    where_document: Optional[dict[str, Any]] = None,
    include: Optional[list[QueryResultType]] = None,
    max_characters: int = 2000,
) -> str:
    """Retrieve excerpts to aid in answering multifaceted questions.

    Example:
        User: "What are prefect blocks and tasks?"
        Assistant: >>> multi_query_chroma(
            ["What are prefect blocks?", "What are prefect tasks?"]
        )
        multi_query_chroma -> document excerpts explaining both blocks and tasks

    Each query is delegated to ``query_chroma`` with an equal share of the
    character budget; the per-query results are joined with newlines and the
    combined text is truncated to ``max_characters``.
    """
    pending = [
        query_chroma(
            query=single_query,
            collection=collection,
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=include,
            # Integer share of the overall budget for this one query.
            max_characters=max_characters // len(queries),
        )
        for single_query in queries
    ]
    excerpts = await asyncio.gather(*pending)
    combined = "\n".join(excerpts)
    return combined[:max_characters]


def store_document(
document: str, metadata: dict[str, Any], collection_name: str = "glacial"
) -> GetResult:
"""Store a document in Chroma for future reference.
Args:
document: The document to store.
metadata: The metadata to store with the document.
Returns:
The stored document.
"""
collection = get_http_client().get_or_create_collection(
name=collection_name, embedding_function=OpenAIEmbeddingFunction()
)
doc_id = metadata.get("msg_id", str(uuid.uuid4()))

collection.add(
ids=[doc_id],
documents=[document],
metadatas=[metadata],
)

def ok(self) -> bool:
    """Report whether Chroma answers with a well-formed version string.

    Returns:
        ``True`` when ``self.client.get_version()`` succeeds and its result
        matches ``MAJOR.MINOR.PATCH``; ``False`` when the call raises or the
        version has any other shape.
    """
    logger = raggy.utilities.logging.get_logger()
    try:
        # NOTE(review): this reads ``self.client``; confirm the attribute
        # exists on the enclosing class.
        version = self.client.get_version()
    except Exception as e:
        # Best-effort health check: log the failure and report "not ok".
        # Falling through would read ``version`` before it was ever assigned.
        logger.error(f"Cannot connect to Chroma: {e}")
        return False
    if re.match(r"^\d+\.\d+\.\d+$", version):
        logger.debug(f"Connected to Chroma v{version}")
        return True
    return False

async def __aenter__(self):
    """Flag the instance as being inside an ``async with`` block.

    Returns:
        The instance itself, so it can be bound by ``async with ... as``.
    """
    self._in_context = True
    return self

async def __aexit__(self, exc_type, exc_value, traceback):
    """Clear the in-context flag when the ``async with`` block ends.

    Nothing is returned, so an exception raised inside the block is not
    suppressed.
    """
    self._in_context = False
return collection.get(ids=doc_id)

0 comments on commit 63efa1d

Please sign in to comment.