Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update document name in models #57

Merged
merged 4 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.1
rev: v0.8.0
hooks:
# Run the linter.
- id: ruff
Expand Down
38 changes: 20 additions & 18 deletions cohere/compass/clients/compass.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ def __init__(
"create_index": "/api/v1/indexes/{index_name}",
"list_indexes": "/api/v1/indexes",
"delete_index": "/api/v1/indexes/{index_name}",
"delete_document": "/api/v1/indexes/{index_name}/documents/{doc_id}",
"get_document": "/api/v1/indexes/{index_name}/documents/{doc_id}",
"delete_document": "/api/v1/indexes/{index_name}/documents/{document_id}",
"get_document": "/api/v1/indexes/{index_name}/documents/{document_id}",
"put_documents": "/api/v1/indexes/{index_name}/documents",
"search_documents": "/api/v1/indexes/{index_name}/documents/_search",
"search_chunks": "/api/v1/indexes/{index_name}/documents/_search_chunks",
"add_context": "/api/v1/indexes/{index_name}/documents/add_context/{doc_id}",
"add_context": "/api/v1/indexes/{index_name}/documents/add_context/{document_id}",
"refresh": "/api/v1/indexes/{index_name}/_refresh",
"upload_documents": "/api/v1/indexes/{index_name}/documents/_upload",
"edit_group_authorization": "/api/v1/indexes/{index_name}/group_authorization",
Expand Down Expand Up @@ -185,31 +185,31 @@ def delete_index(self, *, index_name: str):
index_name=index_name,
)

def delete_document(self, *, index_name: str, doc_id: str):
def delete_document(self, *, index_name: str, document_id: str):
"""
Delete a document from Compass
:param index_name: the name of the index
:doc_id: the id of the document
:document_id: the id of the document
:return: the response from the Compass API
"""
return self._send_request(
api_name="delete_document",
doc_id=doc_id,
document_id=document_id,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
index_name=index_name,
)

def get_document(self, *, index_name: str, doc_id: str):
def get_document(self, *, index_name: str, document_id: str):
"""
Get a document from Compass
:param index_name: the name of the index
:doc_id: the id of the document
:document_id: the id of the document
:return: the response from the Compass API
"""
return self._send_request(
api_name="get_document",
doc_id=doc_id,
document_id=document_id,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
index_name=index_name,
Expand All @@ -231,7 +231,7 @@ def add_context(
self,
*,
index_name: str,
doc_id: str,
document_id: str,
context: dict[str, Any],
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
Expand All @@ -240,15 +240,15 @@ def add_context(
Update the content field of an existing document with additional context

:param index_name: the name of the index
:param doc_id: the document to modify
:param document_id: the document to modify
:param context: A dictionary of key:value pairs to insert into the content field of a document
:param max_retries: the maximum number of times to retry a doc insertion
:param sleep_retry_seconds: number of seconds to go to sleep before retrying a doc insertion
"""

return self._send_request(
api_name="add_context",
doc_id=doc_id,
document_id=document_id,
data=context,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
Expand Down Expand Up @@ -373,7 +373,7 @@ def put_request(
compass_doc for compass_doc, _ in request_data
]
put_docs_input = PutDocumentsInput(
docs=[input_doc for _, input_doc in request_data],
documents=[input_doc for _, input_doc in request_data],
authorized_groups=authorized_groups,
merge_groups_on_conflict=merge_groups_on_conflict,
)
Expand Down Expand Up @@ -401,7 +401,7 @@ def put_request(
)
errors.append(
{
doc.metadata.doc_id: f"{doc.metadata.filename}: {results.error}"
doc.metadata.document_id: f"{doc.metadata.filename}: {results.error}"
}
)
else:
Expand Down Expand Up @@ -570,9 +570,11 @@ def _get_request_blocks(
num_chunks = 0
for _, doc in enumerate(docs, 1):
if doc.status != CompassDocumentStatus.Success:
logger.error(f"Document {doc.metadata.doc_id} has errors: {doc.errors}")
logger.error(
f"Document {doc.metadata.document_id} has errors: {doc.errors}"
)
for error in doc.errors:
errors.append({doc.metadata.doc_id: list(error.values())[0]})
errors.append({doc.metadata.document_id: list(error.values())[0]})
else:
num_chunks += (
len(doc.chunks)
Expand All @@ -588,8 +590,8 @@ def _get_request_blocks(
(
doc,
Document(
doc_id=doc.metadata.doc_id,
parent_doc_id=doc.metadata.parent_doc_id,
document_id=doc.metadata.document_id,
parent_document_id=doc.metadata.parent_document_id,
path=doc.metadata.filename,
content=doc.content,
chunks=[Chunk(**c.model_dump()) for c in doc.chunks],
Expand Down
37 changes: 36 additions & 1 deletion cohere/compass/clients/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def process_file(
docs: list[CompassDocument] = []
for doc in res.json()["docs"]:
if not doc.get("errors", []):
compass_doc = CompassDocument(**doc)
compass_doc = self._adapt_doc_id_compass_doc(doc)
additional_metadata = CompassParserClient._get_metadata(
doc=compass_doc, custom_context=custom_context
)
Expand All @@ -268,3 +268,38 @@ def process_file(
logger.error(f"Error processing file: {res.text}")

return docs

@staticmethod
def _adapt_doc_id_compass_doc(doc: Dict[Any, Any]) -> CompassDocument:
"""
Adapt the doc_id to document_id
"""

metadata = doc["metadata"]
if "document_id" not in metadata:
metadata["document_id"] = metadata.pop("doc_id")
metadata["parent_document_id"] = metadata.pop("parent_doc_id")

chunks = doc["chunks"]
for chunk in chunks:
if "parent_document_id" not in chunk:
chunk["parent_document_id"] = chunk.pop("parent_doc_id")
if "document_id" not in chunk:
chunk["document_id"] = chunk.pop("doc_id")
if "path" not in chunk:
chunk["path"] = doc["metadata"]["filename"]

res = CompassDocument(
filebytes=doc["filebytes"],
metadata=metadata,
content=doc["content"],
content_type=doc["content_type"],
elements=doc["elements"],
chunks=chunks,
index_fields=doc["index_fields"],
errors=doc["errors"],
ignore_metadata_errors=doc["ignore_metadata_errors"],
markdown=doc["markdown"],
)

return res
26 changes: 15 additions & 11 deletions cohere/compass/models/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ class CompassDocumentMetadata(ValidatedModel):
Compass document metadata
"""

doc_id: str = ""
document_id: str = ""
filename: str = ""
meta: list[Any] = field(default_factory=list)
parent_doc_id: str = ""
parent_document_id: str = ""


class CompassDocumentChunkAsset(BaseModel):
Expand All @@ -30,14 +30,15 @@ class CompassDocumentChunkAsset(BaseModel):
class CompassDocumentChunk(BaseModel):
chunk_id: str
sort_id: str
doc_id: str
parent_doc_id: str
document_id: str
parent_document_id: str
content: Dict[str, Any]
origin: Optional[Dict[str, Any]] = None
assets: Optional[list[CompassDocumentChunkAsset]] = None
path: Optional[str] = ""

def parent_doc_is_split(self):
return self.doc_id != self.parent_doc_id
return self.document_id != self.parent_document_id


class CompassDocumentStatus(str, Enum):
Expand Down Expand Up @@ -140,23 +141,26 @@ class DocumentChunkAsset(BaseModel):
class Chunk(BaseModel):
chunk_id: str
sort_id: int
parent_document_id: str
path: str = ""
content: Dict[str, Any]
origin: Optional[Dict[str, Any]] = None
assets: Optional[list[DocumentChunkAsset]] = None
parent_doc_id: str
assets: Optional[List[DocumentChunkAsset]] = None
asset_ids: Optional[List[str]] = None


class Document(BaseModel):
"""
A document that can be indexed in Compass (i.e., a list of indexable chunks)
"""

doc_id: str
document_id: str
path: str
parent_doc_id: str
parent_document_id: str
content: Dict[str, Any]
chunks: List[Chunk]
index_fields: List[str] = field(default_factory=list)
index_fields: Optional[List[str]] = None
authorized_groups: Optional[List[str]] = None


class ParseableDocument(BaseModel):
Expand All @@ -183,6 +187,6 @@ class PutDocumentsInput(BaseModel):
A Compass request to put a list of Document
"""

docs: List[Document]
documents: List[Document]
authorized_groups: Optional[List[str]] = None
merge_groups_on_conflict: bool = False
Loading
Loading