Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
truskovskiyk committed Dec 25, 2024
1 parent 1bd8977 commit 1850c9f
Show file tree
Hide file tree
Showing 20 changed files with 723 additions and 215 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,5 @@ colpali/
data/
.DS_Store
no-ocr-api/storage
example/
no-ocr-api/vllm_cache/
25 changes: 25 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
version: '3.8'

services:
ui:
build: ./no-ocr-ui
ports:
- "3000:3000"

api:
build: ./no-ocr-api
volumes:
- api-storage:/app/storage
ports:
- "5000:5000"

qdrant:
image: qdrant/qdrant:v1.12.5
volumes:
- qdrant-storage:/qdrant/storage
ports:
- "6333:6333"

volumes:
api-storage:
qdrant-storage:
148 changes: 74 additions & 74 deletions no-ocr-api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ def __init__(self, qdrant_uri: str = settings.QDRANT_URI):
self.qdrant_client = QdrantClient(qdrant_uri, port=settings.QDRANT_PORT, https=settings.QDRANT_HTTPS)
self.colpali_client = ColPaliClient()

def ingest(self, collection_name, dataset):
def ingest(self, case_name, dataset):
self.qdrant_client.create_collection(
collection_name=collection_name,
collection_name=case_name,
on_disk_payload=True,
optimizers_config=models.OptimizersConfigDiff(
indexing_threshold=settings.INDEXING_THRESHOLD
Expand Down Expand Up @@ -155,7 +155,7 @@ def ingest(self, collection_name, dataset):
# Upload point to Qdrant
try:
self.qdrant_client.upsert(
collection_name=collection_name,
collection_name=case_name,
points=[point],
wait=False,
)
Expand All @@ -174,7 +174,7 @@ def __init__(self, qdrant_uri: str = settings.QDRANT_URI):
self.qdrant_client = QdrantClient(qdrant_uri, port=settings.QDRANT_PORT, https=settings.QDRANT_HTTPS)
self.colpali_client = ColPaliClient()

def search_images_by_text(self, query_text, collection_name: str, top_k=settings.TOP_K):
def search_images_by_text(self, query_text, case_name: str, top_k=settings.TOP_K):
# Use ColPaliClient to query text and get the embedding
query_embedding = self.colpali_client.query_text(query_text)

Expand All @@ -183,7 +183,7 @@ def search_images_by_text(self, query_text, collection_name: str, top_k=settings

# Search in Qdrant
search_result = self.qdrant_client.query_points(
collection_name=collection_name, query=multivector_query, limit=top_k
collection_name=case_name, query=multivector_query, limit=top_k
)

return search_result
Expand Down Expand Up @@ -328,33 +328,33 @@ def vllm_call(
@app.post("/search")
def ai_search(
user_query: str = Form(...),
collection_name: str = Form(...)
case_name: str = Form(...)
):
"""
Given a user query and collection name, search relevant images in the Qdrant index
Given a user query and case name, search relevant images in the Qdrant index
and return both the results and an LLM interpretation.
"""
if not os.path.exists(settings.STORAGE_DIR):
raise HTTPException(status_code=404, detail="No collections found.")

collection_info_path = os.path.join(settings.STORAGE_DIR, collection_name, settings.COLLECTION_INFO_FILENAME)
if not os.path.exists(collection_info_path):
raise HTTPException(status_code=404, detail="Collection info not found.")
case_info_path = os.path.join(settings.STORAGE_DIR, case_name, settings.COLLECTION_INFO_FILENAME)
if not os.path.exists(case_info_path):
raise HTTPException(status_code=404, detail="Case info not found.")

with open(collection_info_path, "r") as json_file:
_ = json.load(json_file) # collection_info is not used directly below
with open(case_info_path, "r") as json_file:
_ = json.load(json_file) # case_info is not used directly below

search_results = search_client.search_images_by_text(
user_query,
collection_name=collection_name,
case_name=case_name,
top_k=settings.SEARCH_TOP_K
)
if not search_results:
return {"message": "No results found."}

dataset_path = os.path.join(settings.STORAGE_DIR, collection_name, settings.HF_DATASET_DIRNAME)
dataset_path = os.path.join(settings.STORAGE_DIR, case_name, settings.HF_DATASET_DIRNAME)
if not os.path.exists(dataset_path):
raise HTTPException(status_code=404, detail="Dataset for this collection not found.")
raise HTTPException(status_code=404, detail="Dataset for this case not found.")

dataset = load_from_disk(dataset_path)
search_results_data = []
Expand All @@ -380,105 +380,105 @@ def ai_search(

return {"search_results": search_results_data}

@app.post("/create_collection")
def create_new_collection(
@app.post("/create_case")
def create_new_case(
files: List[UploadFile] = File(...),
collection_name: str = Form(...)
case_name: str = Form(...)
):
"""
Create a new collection, store the uploaded PDFs, and process/ingest them.
Create a new case, store the uploaded PDFs, and process/ingest them.
"""
if not files or not collection_name:
raise HTTPException(status_code=400, detail="No files or collection name provided.")
if not files or not case_name:
raise HTTPException(status_code=400, detail="No files or case name provided.")

collection_dir = f"{settings.STORAGE_DIR}/{collection_name}"
os.makedirs(collection_dir, exist_ok=True)
case_dir = f"{settings.STORAGE_DIR}/{case_name}"
os.makedirs(case_dir, exist_ok=True)

file_names = []
for uploaded_file in files:
file_path = os.path.join(collection_dir, uploaded_file.filename)
file_path = os.path.join(case_dir, uploaded_file.filename)
with open(file_path, "wb") as f:
f.write(uploaded_file.file.read())
file_names.append(uploaded_file.filename)

collection_info = {
"name": collection_name,
case_info = {
"name": case_name,
"status": "processing",
"number_of_PDFs": len(files),
"files": file_names
}
with open(os.path.join(collection_dir, settings.COLLECTION_INFO_FILENAME), "w") as json_file:
json.dump(collection_info, json_file)
with open(os.path.join(case_dir, settings.COLLECTION_INFO_FILENAME), "w") as json_file:
json.dump(case_info, json_file)

# Process and ingest
dataset = pdfs_to_hf_dataset(collection_dir)
dataset.save_to_disk(os.path.join(collection_dir, settings.HF_DATASET_DIRNAME))
ingest_client.ingest(collection_name, dataset)
collection_info['status'] = 'done'
dataset = pdfs_to_hf_dataset(case_dir)
dataset.save_to_disk(os.path.join(case_dir, settings.HF_DATASET_DIRNAME))
ingest_client.ingest(case_name, dataset)
case_info['status'] = 'done'

with open(os.path.join(collection_dir, settings.COLLECTION_INFO_FILENAME), "w") as json_file:
json.dump(collection_info, json_file)
with open(os.path.join(case_dir, settings.COLLECTION_INFO_FILENAME), "w") as json_file:
json.dump(case_info, json_file)

return {
"message": f"Uploaded {len(files)} PDFs to collection '{collection_name}'",
"collection_info": collection_info
"message": f"Uploaded {len(files)} PDFs to case '{case_name}'",
"case_info": case_info
}

@app.get("/get_collections")
def get_collections():
@app.get("/get_cases")
def get_cases():
"""
Return a list of all previously uploaded collections with their metadata.
Return a list of all previously uploaded cases with their metadata.
"""
if not os.path.exists(settings.STORAGE_DIR):
return {"message": "No collections found.", "collections": []}
return {"message": "No cases found.", "cases": []}

collections = os.listdir(settings.STORAGE_DIR)
collection_data = []
cases = os.listdir(settings.STORAGE_DIR)
case_data = []

for collection in collections:
collection_info_path = os.path.join(settings.STORAGE_DIR, collection, settings.COLLECTION_INFO_FILENAME)
if os.path.exists(collection_info_path):
with open(collection_info_path, "r") as json_file:
collection_info = json.load(json_file)
collection_data.append(collection_info)
for case in cases:
case_info_path = os.path.join(settings.STORAGE_DIR, case, settings.COLLECTION_INFO_FILENAME)
if os.path.exists(case_info_path):
with open(case_info_path, "r") as json_file:
case_info = json.load(json_file)
case_data.append(case_info)

if not collection_data:
return {"message": "No collection data found.", "collections": []}
return {"collections": collection_data}
if not case_data:
return {"message": "No case data found.", "cases": []}
return {"cases": case_data}

@app.delete("/delete_all_collections")
def delete_all_collections():
@app.delete("/delete_all_cases")
def delete_all_cases():
"""
Delete all collections from storage and Qdrant.
Delete all cases from storage and Qdrant.
"""
# Delete all collections from storage
# Delete all cases from storage
if os.path.exists(settings.STORAGE_DIR):
for collection in os.listdir(settings.STORAGE_DIR):
shutil.rmtree(os.path.join(settings.STORAGE_DIR, collection))
for case in os.listdir(settings.STORAGE_DIR):
shutil.rmtree(os.path.join(settings.STORAGE_DIR, case))

# Delete all collections from Qdrant
collections = ingest_client.qdrant_client.get_collections().collections
for collection in collections:
ingest_client.qdrant_client.delete_collection(collection.name)
# Delete all cases from Qdrant
cases = ingest_client.qdrant_client.get_collections().collections
for case in cases:
ingest_client.qdrant_client.delete_collection(case.name)

return {"message": "All collections have been deleted from storage and Qdrant."}
return {"message": "All cases have been deleted from storage and Qdrant."}

@app.delete("/delete_collection/{collection_name}")
def delete_collection(collection_name: str):
@app.delete("/delete_case/{case_name}")
def delete_case(case_name: str):
"""
Delete a specific collection from storage and Qdrant.
Delete a specific case from storage and Qdrant.
"""
# Delete the collection from storage
collection_dir = os.path.join(settings.STORAGE_DIR, collection_name)
if os.path.exists(collection_dir):
shutil.rmtree(collection_dir)
# Delete the case from storage
case_dir = os.path.join(settings.STORAGE_DIR, case_name)
if os.path.exists(case_dir):
shutil.rmtree(case_dir)
else:
raise HTTPException(status_code=404, detail="Collection not found in storage.")
raise HTTPException(status_code=404, detail="Case not found in storage.")

# Delete the collection from Qdrant
# Delete the case from Qdrant
try:
ingest_client.qdrant_client.delete_collection(collection_name)
ingest_client.qdrant_client.delete_collection(case_name)
except Exception as e:
raise HTTPException(status_code=500, detail=f"An error occurred while deleting the collection from Qdrant: {str(e)}")
raise HTTPException(status_code=500, detail=f"An error occurred while deleting the case from Qdrant: {str(e)}")

return {"message": f"Collection '{collection_name}' has been deleted from storage and Qdrant."}
return {"message": f"Case '{case_name}' has been deleted from storage and Qdrant."}
Loading

0 comments on commit 1850c9f

Please sign in to comment.