From 1850c9f34931488cbeee4aeaea00e63b02d7ebd5 Mon Sep 17 00:00:00 2001 From: truskovskiyk Date: Wed, 25 Dec 2024 23:00:09 +0100 Subject: [PATCH] rename --- .gitignore | 2 + docker-compose.yml | 25 +++ no-ocr-api/api.py | 148 ++++++++-------- no-ocr-api/evaluate_synthetic_data.py | 160 ++++++++++++++++++ no-ocr-api/llm-inference/llm_serving.py | 148 ++++++++++++++++ .../llm-inference/llm_serving_colpali.py | 110 ++++++++++++ .../llm-inference/llm_serving_load_models.py | 60 +++++++ no-ocr-ui/src/App.tsx | 12 +- .../components/{Collections.tsx => Case.tsx} | 10 +- .../{CreateCollection.tsx => CreateCase.tsx} | 30 ++-- no-ocr-ui/src/components/Navbar.tsx | 6 +- no-ocr-ui/src/components/Search.tsx | 48 +++--- no-ocr-ui/src/components/about/Features.tsx | 2 +- .../{CollectionCard.tsx => CaseCard.tsx} | 26 +-- .../src/components/collections/CaseList.tsx | 49 ++++++ .../components/collections/CollectionList.tsx | 49 ------ no-ocr-ui/src/components/layout/Navbar.tsx | 8 +- .../components/search/CollectionSelect.tsx | 38 ++--- no-ocr-ui/src/types/collection.ts | 5 +- no-ocr-ui/src/types/index.ts | 2 +- 20 files changed, 723 insertions(+), 215 deletions(-) create mode 100644 docker-compose.yml create mode 100644 no-ocr-api/evaluate_synthetic_data.py create mode 100644 no-ocr-api/llm-inference/llm_serving.py create mode 100644 no-ocr-api/llm-inference/llm_serving_colpali.py create mode 100644 no-ocr-api/llm-inference/llm_serving_load_models.py rename no-ocr-ui/src/components/{Collections.tsx => Case.tsx} (54%) rename no-ocr-ui/src/components/{CreateCollection.tsx => CreateCase.tsx} (85%) rename no-ocr-ui/src/components/collections/{CollectionCard.tsx => CaseCard.tsx} (62%) create mode 100644 no-ocr-ui/src/components/collections/CaseList.tsx delete mode 100644 no-ocr-ui/src/components/collections/CollectionList.tsx diff --git a/.gitignore b/.gitignore index b439cc1..2ad571c 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,5 @@ colpali/ data/ .DS_Store no-ocr-api/storage +example/ +no-ocr-api/vllm_cache/ \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..b6a048f --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,25 @@ +version: '3.8' + +services: + ui: + build: ./no-ocr-ui + ports: + - "3000:3000" + + api: + build: ./no-ocr-api + volumes: + - api-storage:/app/storage + ports: + - "5000:5000" + + qdrant: + image: qdrant/qdrant:v1.12.5 + volumes: + - qdrant-storage:/qdrant/storage + ports: + - "6333:6333" + +volumes: + api-storage: + qdrant-storage: \ No newline at end of file diff --git a/no-ocr-api/api.py b/no-ocr-api/api.py index 7765cca..dd6211d 100644 --- a/no-ocr-api/api.py +++ b/no-ocr-api/api.py @@ -108,9 +108,9 @@ def __init__(self, qdrant_uri: str = settings.QDRANT_URI): self.qdrant_client = QdrantClient(qdrant_uri, port=settings.QDRANT_PORT, https=settings.QDRANT_HTTPS) self.colpali_client = ColPaliClient() - def ingest(self, collection_name, dataset): + def ingest(self, case_name, dataset): self.qdrant_client.create_collection( - collection_name=collection_name, + collection_name=case_name, on_disk_payload=True, optimizers_config=models.OptimizersConfigDiff( indexing_threshold=settings.INDEXING_THRESHOLD @@ -155,7 +155,7 @@ def ingest(self, collection_name, dataset): # Upload point to Qdrant try: self.qdrant_client.upsert( - collection_name=collection_name, + collection_name=case_name, points=[point], wait=False, ) @@ -174,7 +174,7 @@ def __init__(self, qdrant_uri: str = settings.QDRANT_URI): self.qdrant_client = 
QdrantClient(qdrant_uri, port=settings.QDRANT_PORT, https=settings.QDRANT_HTTPS) self.colpali_client = ColPaliClient() - def search_images_by_text(self, query_text, collection_name: str, top_k=settings.TOP_K): + def search_images_by_text(self, query_text, case_name: str, top_k=settings.TOP_K): # Use ColPaliClient to query text and get the embedding query_embedding = self.colpali_client.query_text(query_text) @@ -183,7 +183,7 @@ def search_images_by_text(self, query_text, collection_name: str, top_k=settings # Search in Qdrant search_result = self.qdrant_client.query_points( - collection_name=collection_name, query=multivector_query, limit=top_k + collection_name=case_name, query=multivector_query, limit=top_k ) return search_result @@ -328,33 +328,33 @@ def vllm_call( @app.post("/search") def ai_search( user_query: str = Form(...), - collection_name: str = Form(...) + case_name: str = Form(...) ): """ - Given a user query and collection name, search relevant images in the Qdrant index + Given a user query and case name, search relevant images in the Qdrant index and return both the results and an LLM interpretation. """ if not os.path.exists(settings.STORAGE_DIR): raise HTTPException(status_code=404, detail="No collections found.") - collection_info_path = os.path.join(settings.STORAGE_DIR, collection_name, settings.COLLECTION_INFO_FILENAME) - if not os.path.exists(collection_info_path): - raise HTTPException(status_code=404, detail="Collection info not found.") + case_info_path = os.path.join(settings.STORAGE_DIR, case_name, settings.COLLECTION_INFO_FILENAME) + if not os.path.exists(case_info_path): + raise HTTPException(status_code=404, detail="Case info not found.") - with open(collection_info_path, "r") as json_file: - _ = json.load(json_file) # collection_info is not used directly below + with open(case_info_path, "r") as json_file: + _ = json.load(json_file) # case_info is not used directly below search_results = search_client.search_images_by_text( user_query, - collection_name=collection_name, + case_name=case_name, top_k=settings.SEARCH_TOP_K ) if not search_results: return {"message": "No results found."} - dataset_path = os.path.join(settings.STORAGE_DIR, collection_name, settings.HF_DATASET_DIRNAME) + dataset_path = os.path.join(settings.STORAGE_DIR, case_name, settings.HF_DATASET_DIRNAME) if not os.path.exists(dataset_path): - raise HTTPException(status_code=404, detail="Dataset for this collection not found.") + raise HTTPException(status_code=404, detail="Dataset for this case not found.") dataset = load_from_disk(dataset_path) search_results_data = [] @@ -380,105 +380,105 @@ def ai_search( return {"search_results": search_results_data} -@app.post("/create_collection") -def create_new_collection( +@app.post("/create_case") +def create_new_case( files: List[UploadFile] = File(...), - collection_name: str = Form(...) + case_name: str = Form(...) ): """ - Create a new collection, store the uploaded PDFs, and process/ingest them. + Create a new case, store the uploaded PDFs, and process/ingest them. 
""" - if not files or not collection_name: - raise HTTPException(status_code=400, detail="No files or collection name provided.") + if not files or not case_name: + raise HTTPException(status_code=400, detail="No files or case name provided.") - collection_dir = f"{settings.STORAGE_DIR}/{collection_name}" - os.makedirs(collection_dir, exist_ok=True) + case_dir = f"{settings.STORAGE_DIR}/{case_name}" + os.makedirs(case_dir, exist_ok=True) file_names = [] for uploaded_file in files: - file_path = os.path.join(collection_dir, uploaded_file.filename) + file_path = os.path.join(case_dir, uploaded_file.filename) with open(file_path, "wb") as f: f.write(uploaded_file.file.read()) file_names.append(uploaded_file.filename) - collection_info = { - "name": collection_name, + case_info = { + "name": case_name, "status": "processing", "number_of_PDFs": len(files), "files": file_names } - with open(os.path.join(collection_dir, settings.COLLECTION_INFO_FILENAME), "w") as json_file: - json.dump(collection_info, json_file) + with open(os.path.join(case_dir, settings.COLLECTION_INFO_FILENAME), "w") as json_file: + json.dump(case_info, json_file) # Process and ingest - dataset = pdfs_to_hf_dataset(collection_dir) - dataset.save_to_disk(os.path.join(collection_dir, settings.HF_DATASET_DIRNAME)) - ingest_client.ingest(collection_name, dataset) - collection_info['status'] = 'done' + dataset = pdfs_to_hf_dataset(case_dir) + dataset.save_to_disk(os.path.join(case_dir, settings.HF_DATASET_DIRNAME)) + ingest_client.ingest(case_name, dataset) + case_info['status'] = 'done' - with open(os.path.join(collection_dir, settings.COLLECTION_INFO_FILENAME), "w") as json_file: - json.dump(collection_info, json_file) + with open(os.path.join(case_dir, settings.COLLECTION_INFO_FILENAME), "w") as json_file: + json.dump(case_info, json_file) return { - "message": f"Uploaded {len(files)} PDFs to collection '{collection_name}'", - "collection_info": collection_info + "message": f"Uploaded {len(files)} PDFs to case '{case_name}'", + "case_info": case_info } -@app.get("/get_collections") -def get_collections(): +@app.get("/get_cases") +def get_cases(): """ - Return a list of all previously uploaded collections with their metadata. + Return a list of all previously uploaded cases with their metadata. 
""" if not os.path.exists(settings.STORAGE_DIR): - return {"message": "No collections found.", "collections": []} + return {"message": "No cases found.", "cases": []} - collections = os.listdir(settings.STORAGE_DIR) - collection_data = [] + cases = os.listdir(settings.STORAGE_DIR) + case_data = [] - for collection in collections: - collection_info_path = os.path.join(settings.STORAGE_DIR, collection, settings.COLLECTION_INFO_FILENAME) - if os.path.exists(collection_info_path): - with open(collection_info_path, "r") as json_file: - collection_info = json.load(json_file) - collection_data.append(collection_info) + for case in cases: + case_info_path = os.path.join(settings.STORAGE_DIR, case, settings.COLLECTION_INFO_FILENAME) + if os.path.exists(case_info_path): + with open(case_info_path, "r") as json_file: + case_info = json.load(json_file) + case_data.append(case_info) - if not collection_data: - return {"message": "No collection data found.", "collections": []} - return {"collections": collection_data} + if not case_data: + return {"message": "No case data found.", "cases": []} + return {"cases": case_data} -@app.delete("/delete_all_collections") -def delete_all_collections(): +@app.delete("/delete_all_cases") +def delete_all_cases(): """ - Delete all collections from storage and Qdrant. + Delete all cases from storage and Qdrant. """ - # Delete all collections from storage + # Delete all cases from storage if os.path.exists(settings.STORAGE_DIR): - for collection in os.listdir(settings.STORAGE_DIR): - shutil.rmtree(os.path.join(settings.STORAGE_DIR, collection)) + for case in os.listdir(settings.STORAGE_DIR): + shutil.rmtree(os.path.join(settings.STORAGE_DIR, case)) - # Delete all collections from Qdrant - collections = ingest_client.qdrant_client.get_collections().collections - for collection in collections: - ingest_client.qdrant_client.delete_collection(collection.name) + # Delete all cases from Qdrant + cases = ingest_client.qdrant_client.get_collections().collections + for case in cases: + ingest_client.qdrant_client.delete_collection(case.name) - return {"message": "All collections have been deleted from storage and Qdrant."} + return {"message": "All cases have been deleted from storage and Qdrant."} -@app.delete("/delete_collection/{collection_name}") -def delete_collection(collection_name: str): +@app.delete("/delete_case/{case_name}") +def delete_case(case_name: str): """ - Delete a specific collection from storage and Qdrant. + Delete a specific case from storage and Qdrant. 
""" - # Delete the collection from storage - collection_dir = os.path.join(settings.STORAGE_DIR, collection_name) - if os.path.exists(collection_dir): - shutil.rmtree(collection_dir) + # Delete the case from storage + case_dir = os.path.join(settings.STORAGE_DIR, case_name) + if os.path.exists(case_dir): + shutil.rmtree(case_dir) else: - raise HTTPException(status_code=404, detail="Collection not found in storage.") + raise HTTPException(status_code=404, detail="Case not found in storage.") - # Delete the collection from Qdrant + # Delete the case from Qdrant try: - ingest_client.qdrant_client.delete_collection(collection_name) + ingest_client.qdrant_client.delete_collection(case_name) except Exception as e: - raise HTTPException(status_code=500, detail=f"An error occurred while deleting the collection from Qdrant: {str(e)}") + raise HTTPException(status_code=500, detail=f"An error occurred while deleting the case from Qdrant: {str(e)}") - return {"message": f"Collection '{collection_name}' has been deleted from storage and Qdrant."} + return {"message": f"Case '{case_name}' has been deleted from storage and Qdrant."} diff --git a/no-ocr-api/evaluate_synthetic_data.py b/no-ocr-api/evaluate_synthetic_data.py new file mode 100644 index 0000000..4fe679a --- /dev/null +++ b/no-ocr-api/evaluate_synthetic_data.py @@ -0,0 +1,160 @@ +import base64 +import os +import random +from io import BytesIO +from typing import Dict, List + +import PIL +import typer +from colpali_engine.trainer.eval_utils import CustomRetrievalEvaluator +from datasets import Dataset, load_dataset +from openai import OpenAI +from pydantic import BaseModel +from rich import print +from rich.table import Table +from tqdm import tqdm + +from ai_search_demo.qdrant_inexing import SearchClient, pdfs_to_hf_dataset, IngestClient + +# Initialize OpenAI client +client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + +class DataSample(BaseModel): + japanese_query: str + english_query: str + +def generate_synthetic_question(image: PIL.Image.Image) -> DataSample: + # Convert PIL image to base64 string + buffered = BytesIO() + image.save(buffered, format="JPEG") + image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") + + prompt = """ + I am developing a visual retrieval dataset to evaluate my system. + Based on the image I provided, I want you to generate a query that this image will satisfy. + For example, if a user types this query into the search box, this image would be extremely relevant. + Generate the query in Japanese and English. 
+ """ + # Generate synthetic question using OpenAI + chat_completion = client.beta.chat.completions.parse( + model="gpt-4o", + response_format=DataSample, + temperature=1, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, + }, + ], + } + ], + ) + + sample = chat_completion.choices[0].message.parsed + return sample + +def create_synthetic_dataset(input_folder: str, output_folder: str, hub_repo: str, num_samples: int = 10) -> None: + # Step 1: Read all PDFs and extract info + dataset = pdfs_to_hf_dataset(input_folder) + + # Step 2: Randomly sample data points + if num_samples > len(dataset): + indices = random.choices(range(len(dataset)), k=num_samples) + sampled_data = dataset.select(indices) + else: + sampled_data = dataset.shuffle().select(range(num_samples)) + + synthetic_data: List[Dict] = [] + + for index, data_point in enumerate(tqdm(sampled_data, desc="Generating synthetic questions")): + image = data_point['image'] + pdf_name = data_point['pdf_name'] + pdf_page = data_point['pdf_page'] + + # Step 3: Generate synthetic question + sample = generate_synthetic_question(image) + + # Step 4: Store samples in a new dataset + synthetic_data.append({ + "index": index, + "image": image, + "question_en": sample.english_query, + "question_jp": sample.japanese_query, + "pdf_name": pdf_name, + "pdf_page": pdf_page + }) + + # Create a new dataset from synthetic data + synthetic_dataset = Dataset.from_list(synthetic_data) + synthetic_dataset.save_to_disk(output_folder) + + # Save the dataset card + synthetic_dataset.push_to_hub(hub_repo, private=False) + +def evaluate_on_synthetic_dataset(hub_repo: str, collection_name: str = "synthetic-dataset-evaluate-full") -> None: + # Ingest collection with IngestClient + print("Load data") + synthetic_dataset = load_dataset(hub_repo)['train'] + + print("Ingest data to qdrant") + # ingest_client = IngestClient() + # ingest_client.ingest(collection_name, synthetic_dataset) + + run_evaluation(synthetic_dataset=synthetic_dataset, collection_name=collection_name, query_text_key='question_en') + run_evaluation(synthetic_dataset=synthetic_dataset, collection_name=collection_name, query_text_key='question_jp') + +def run_evaluation(synthetic_dataset: Dataset, collection_name: str, query_text_key: str) -> None: + search_client = SearchClient() + relevant_docs: Dict[str, Dict[str, int]] = {} + results: Dict[str, Dict[str, float]] = {} + + for x in synthetic_dataset: + query_id = f"{x['pdf_name']}_{x['pdf_page']}" + relevant_docs[query_id] = {query_id: 1} # The most relevant document is itself + + response = search_client.search_images_by_text(query_text=x[query_text_key], collection_name=collection_name, top_k=10) + + results[query_id] = {} + for point in response.points: + doc_id = f"{point.payload['pdf_name']}_{point.payload['pdf_page']}" + results[query_id][doc_id] = point.score + + mteb_evaluator = CustomRetrievalEvaluator() + + ndcg, _map, recall, precision, naucs = mteb_evaluator.evaluate( + relevant_docs, + results, + mteb_evaluator.k_values, + ) + + mrr = mteb_evaluator.evaluate_custom(relevant_docs, results, mteb_evaluator.k_values, "mrr") + + scores = { + **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()}, + **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()}, + **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()}, + **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()}, + 
**{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr[0].items()}, + **{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()}, + } + + # Use rich to print scores beautifully + table = Table(title=f"Evaluation Scores for {query_text_key}") + table.add_column("Metric", justify="right", style="cyan", no_wrap=True) + table.add_column("Score", style="magenta") + + for metric, score in scores.items(): + table.add_row(metric, f"{score:.4f}") + + print(table) + + +if __name__ == '__main__': + app = typer.Typer() + app.command()(create_synthetic_dataset) + app.command()(evaluate_on_synthetic_dataset) + app() \ No newline at end of file diff --git a/no-ocr-api/llm-inference/llm_serving.py b/no-ocr-api/llm-inference/llm_serving.py new file mode 100644 index 0000000..35c7c5e --- /dev/null +++ b/no-ocr-api/llm-inference/llm_serving.py @@ -0,0 +1,148 @@ +import modal + +vllm_image = modal.Image.debian_slim(python_version="3.12").pip_install( + "vllm==0.6.3post1", "fastapi[standard]==0.115.4" +) + + +MODELS_DIR = "/models" +MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct" + + +try: + volume = modal.Volume.lookup("models", create_if_missing=False) +except modal.exception.NotFoundError: + raise Exception("Download models first with modal run download_llama.py") + + + +app = modal.App("qwen2-vllm") + +N_GPU = 1 # tip: for best results, first upgrade to more powerful GPUs, and only then increase GPU count +TOKEN = "super-secret-token" # auth token. for production use, replace with a modal.Secret + +MINUTES = 60 # seconds +HOURS = 60 * MINUTES + + +@app.function( + image=vllm_image, + gpu=modal.gpu.A100(count=N_GPU), + keep_warm=0, + container_idle_timeout=1 * MINUTES, + timeout=24 * HOURS, + allow_concurrent_inputs=1000, + volumes={MODELS_DIR: volume}, +) +@modal.asgi_app() +def serve(): + import fastapi + import vllm.entrypoints.openai.api_server as api_server + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.entrypoints.logger import RequestLogger + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + from vllm.entrypoints.openai.serving_completion import ( + OpenAIServingCompletion, + ) + from vllm.entrypoints.openai.serving_engine import BaseModelPath + from vllm.usage.usage_lib import UsageContext + + volume.reload() # ensure we have the latest version of the weights + + # create a fastAPI app that uses vLLM's OpenAI-compatible router + web_app = fastapi.FastAPI( + title=f"OpenAI-compatible {MODEL_NAME} server", + description="Run an OpenAI-compatible LLM server with vLLM on modal.com 🚀", + version="0.0.1", + docs_url="/docs", + ) + + # security: CORS middleware for external requests + http_bearer = fastapi.security.HTTPBearer( + scheme_name="Bearer Token", + description="See code for authentication details.", + ) + web_app.add_middleware( + fastapi.middleware.cors.CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # security: inject dependency on authed routes + async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): + if api_key.credentials != TOKEN: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + ) + return {"username": "authenticated_user"} + + router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) + + # wrap vllm's router in auth router + router.include_router(api_server.router) + # add authed vllm to our fastAPI app + 
web_app.include_router(router) + + engine_args = AsyncEngineArgs( + model=MODELS_DIR + "/" + MODEL_NAME, + tensor_parallel_size=N_GPU, + gpu_memory_utilization=0.90, + max_model_len=8096, + enforce_eager=False, # capture the graph for faster inference, but slower cold starts (30s > 20s) + ) + + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER + ) + + model_config = get_model_config(engine) + + request_logger = RequestLogger(max_log_len=2048) + + base_model_paths = [ + BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) + ] + + api_server.chat = lambda s: OpenAIServingChat( + engine, + model_config=model_config, + base_model_paths=base_model_paths, + chat_template=None, + response_role="assistant", + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + ) + api_server.completion = lambda s: OpenAIServingCompletion( + engine, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + ) + + return web_app + + +def get_model_config(engine): + import asyncio + + try: # adapted from vLLM source -- https://github.com/vllm-project/vllm/blob/507ef787d85dec24490069ffceacbd6b161f4f72/vllm/entrypoints/openai/api_server.py#L235C1-L247C1 + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running(): + # If the current is instanced by Ray Serve, + # there is already a running event loop + model_config = event_loop.run_until_complete(engine.get_model_config()) + else: + # When using single vLLM without engine_use_ray + model_config = asyncio.run(engine.get_model_config()) + + return model_config \ No newline at end of file diff --git a/no-ocr-api/llm-inference/llm_serving_colpali.py b/no-ocr-api/llm-inference/llm_serving_colpali.py new file mode 100644 index 0000000..c39e0a6 --- /dev/null +++ b/no-ocr-api/llm-inference/llm_serving_colpali.py @@ -0,0 +1,110 @@ +import modal + +vllm_image = modal.Image.debian_slim(python_version="3.12").pip_install( + "vllm==0.6.3post1", "fastapi[standard]==0.115.4" +).pip_install("colpali-engine") + +MODELS_DIR = "/models" +MODEL_NAME = "vidore/colqwen2-v1.0-merged" + + +try: + volume = modal.Volume.lookup("models", create_if_missing=False) +except modal.exception.NotFoundError: + raise Exception("Download models first with modal run download_llama.py") + + + +app = modal.App("colpali-embedding") + +N_GPU = 1 +TOKEN = "super-secret-token" + +MINUTES = 60 # seconds +HOURS = 60 * MINUTES + + +@app.function( + image=vllm_image, + gpu=modal.gpu.A100(count=N_GPU), + keep_warm=0, + container_idle_timeout=1 * MINUTES, + timeout=24 * HOURS, + allow_concurrent_inputs=1000, + volumes={MODELS_DIR: volume}, +) +@modal.asgi_app() +def serve(): + import fastapi + import torch + from colpali_engine.models import ColQwen2, ColQwen2Processor + from fastapi import APIRouter, Depends, HTTPException, Security + from fastapi.middleware.cors import CORSMiddleware + from fastapi.security import HTTPBearer + + volume.reload() # ensure we have the latest version of the weights + + # create a fastAPI app for serving the ColPali model + web_app = fastapi.FastAPI( + title=f"ColPali {MODEL_NAME} server", + description="Run a ColPali model server with fastAPI on modal.com 🚀", + version="0.0.1", + docs_url="/docs", + ) + + # security: CORS middleware for external requests + http_bearer = HTTPBearer( + scheme_name="Bearer Token", + description="See code 
for authentication details.", + ) + web_app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # security: inject dependency on authed routes + async def is_authenticated(api_key: str = Security(http_bearer)): + if api_key.credentials != TOKEN: + raise HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + ) + return {"username": "authenticated_user"} + + router = APIRouter(dependencies=[Depends(is_authenticated)]) + + # Define the model and processor + model_name = "/models/vidore/colqwen2-v1.0-merged" + colpali_model = ColQwen2.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map="cuda:0", + ).eval() + + colpali_processor = ColQwen2Processor.from_pretrained(model_name) + + # Define a simple endpoint to process text queries + @router.post("/query") + async def query_model(query_text: str): + with torch.no_grad(): + batch_query = colpali_processor.process_queries([query_text]).to(colpali_model.device) + query_embedding = colpali_model(**batch_query) + return {"embedding": query_embedding[0].cpu().float().numpy().tolist()} + + @router.post("/process_image") + async def process_image(image: fastapi.UploadFile): + from PIL import Image + pil_image = Image.open(image.file) + with torch.no_grad(): + batch_image = colpali_processor.process_images([pil_image]).to(colpali_model.device) + image_embedding = colpali_model(**batch_image) + return {"embedding": image_embedding[0].cpu().float().numpy().tolist()} + + + # add authed router to our fastAPI app + web_app.include_router(router) + + return web_app diff --git a/no-ocr-api/llm-inference/llm_serving_load_models.py b/no-ocr-api/llm-inference/llm_serving_load_models.py new file mode 100644 index 0000000..68a4942 --- /dev/null +++ b/no-ocr-api/llm-inference/llm_serving_load_models.py @@ -0,0 +1,60 @@ +import modal + +MODELS_DIR = "/models" + +DEFAULT_NAME = "Qwen/Qwen2.5-7B-Instruct" +DEFAULT_REVISION = "bb46c15ee4bb56c5b63245ef50fd7637234d6f75" + + +volume = modal.Volume.from_name("models", create_if_missing=True) + +image = ( + modal.Image.debian_slim(python_version="3.10") + .pip_install( + [ + "huggingface_hub", + "hf-transfer", + ] + ) + .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) +) + + +MINUTES = 60 +HOURS = 60 * MINUTES + + +app = modal.App( + image=image, secrets=[modal.Secret.from_name("huggingface-secret")] +) + + +@app.function(volumes={MODELS_DIR: volume}, timeout=4 * HOURS) +def download_model(model_name, model_revision, force_download=False): + from huggingface_hub import snapshot_download + + volume.reload() + + snapshot_download( + model_name, + local_dir=MODELS_DIR + "/" + model_name, + ignore_patterns=[ + "*.pt", + "*.bin", + "*.pth", + "original/*", + ], # Ensure safetensors + revision=model_revision, + force_download=force_download, + ) + + volume.commit() + + +@app.local_entrypoint() +def main( + model_name: str = DEFAULT_NAME, + model_revision: str = DEFAULT_REVISION, + force_download: bool = False, +): + download_model.remote(model_name, model_revision, force_download) \ No newline at end of file diff --git a/no-ocr-ui/src/App.tsx b/no-ocr-ui/src/App.tsx index 820c229..09ba4b3 100644 --- a/no-ocr-ui/src/App.tsx +++ b/no-ocr-ui/src/App.tsx @@ -5,8 +5,8 @@ import Navbar from './components/layout/Navbar'; import { LoginForm } from './components/auth/LoginForm'; import { RegisterForm } from './components/auth/RegisterForm'; import { AuthGuard } from 
'./components/auth/AuthGuard'; -import CreateCollection from './components/CreateCollection'; -import Collections from './components/Collections'; +import CreateCase from './components/CreateCase'; +import Cases from './components/Case'; import Search from './components/Search'; import About from './components/About'; @@ -42,8 +42,8 @@ export default function App() {
- } /> - } /> + } /> + } /> } /> } /> @@ -52,8 +52,8 @@ export default function App() { } > - } /> - } /> + } /> + } /> } /> } /> diff --git a/no-ocr-ui/src/components/Collections.tsx b/no-ocr-ui/src/components/Case.tsx similarity index 54% rename from no-ocr-ui/src/components/Collections.tsx rename to no-ocr-ui/src/components/Case.tsx index 88f2a6b..5067c1a 100644 --- a/no-ocr-ui/src/components/Collections.tsx +++ b/no-ocr-ui/src/components/Case.tsx @@ -1,19 +1,19 @@ -import { CollectionList } from './collections/CollectionList'; +import { CaseList } from './collections/CaseList'; -export default function Collections() { +export default function Cases() { return (
-      Collections
+      Cases
-      Manage your PDF collections and their contents
+      Manage your PDF cases and their contents
-      <CollectionList />
+      <CaseList />
); diff --git a/no-ocr-ui/src/components/CreateCollection.tsx b/no-ocr-ui/src/components/CreateCase.tsx similarity index 85% rename from no-ocr-ui/src/components/CreateCollection.tsx rename to no-ocr-ui/src/components/CreateCase.tsx index 6849fa0..92fdf9d 100644 --- a/no-ocr-ui/src/components/CreateCollection.tsx +++ b/no-ocr-ui/src/components/CreateCase.tsx @@ -2,8 +2,8 @@ import React, { useState } from 'react'; import { Upload, Loader2 } from 'lucide-react'; import { noOcrApiUrl } from '../config/api'; -export default function CreateCollection() { - const [collectionName, setCollectionName] = useState(''); +export default function CreateCase() { + const [caseName, setCaseName] = useState(''); const [files, setFiles] = useState(null); const [isUploading, setIsUploading] = useState(false); const [uploadProgress, setUploadProgress] = useState(0); @@ -11,7 +11,7 @@ export default function CreateCollection() { const handleSubmit = async (e: React.FormEvent) => { e.preventDefault(); - if (!files || !collectionName) return; + if (!files || !caseName) return; setIsUploading(true); @@ -28,10 +28,10 @@ export default function CreateCollection() { try { const formData = new FormData(); - formData.append('collection_name', collectionName); + formData.append('case_name', caseName); Array.from(files).forEach(file => formData.append('files', file)); - const response = await fetch(`${noOcrApiUrl}/create_collection`, { + const response = await fetch(`${noOcrApiUrl}/create_case`, { method: 'POST', body: formData, }); @@ -47,7 +47,7 @@ export default function CreateCollection() { setUploadProgress(100); setTimeout(() => { setIsUploading(false); - setCollectionName(''); + setCaseName(''); setFiles(null); setUploadProgress(0); }, 500); @@ -61,20 +61,20 @@ export default function CreateCollection() { return (
-      Create New Collection
+      Create New Case
@@ -124,7 +124,7 @@ export default function CreateCollection() {
diff --git a/no-ocr-ui/src/components/Navbar.tsx b/no-ocr-ui/src/components/Navbar.tsx index 3cdd39a..cc28363 100644 --- a/no-ocr-ui/src/components/Navbar.tsx +++ b/no-ocr-ui/src/components/Navbar.tsx @@ -21,15 +21,15 @@ export default function Navbar() {
- Create Collection + Create Case ([]); const [isSearching, setIsSearching] = useState(false); const [answers, setAnswers] = useState<{ [key: number]: { is_answer: boolean, answer: string } }>({}); - const [collections, setCollections] = useState([]); + const [cases, setCases] = useState([]); const [isModalOpen, setIsModalOpen] = useState(false); const [modalImage, setModalImage] = useState(''); useEffect(() => { - async function fetchCollections() { + async function fetchCases() { try { - const response = await fetch(`${noOcrApiUrl}/get_collections`); + const response = await fetch(`${noOcrApiUrl}/get_cases`); if (!response.ok) throw new Error('Network response was not ok'); const data = await response.json(); - setCollections(data.collections || []); + setCases(data.cases || []); } catch (error) { - console.error('Error fetching collections:', error); + console.error('Error fetching cases:', error); } } - fetchCollections(); + fetchCases(); }, []); const handleSearch = async (e: React.FormEvent) => { e.preventDefault(); - if (!selectedCollection || !searchQuery) return; + if (!selectedCase || !searchQuery) return; setIsSearching(true); setResults([]); @@ -45,7 +45,7 @@ export default function Search() { }, body: new URLSearchParams({ user_query: searchQuery, - collection_name: selectedCollection, + case_name: selectedCase, }), }); @@ -54,7 +54,7 @@ export default function Search() { setResults(data.search_results || []); data.search_results.forEach((result: any, index: number) => { - fetchAnswer(searchQuery, selectedCollection, result.pdf_name, result.pdf_page) + fetchAnswer(searchQuery, selectedCase, result.pdf_name, result.pdf_page) .then(answer => { setAnswers(prevAnswers => ({ ...prevAnswers, [index]: answer })); }); @@ -66,7 +66,7 @@ export default function Search() { } }; - const fetchAnswer = async (userQuery: string, collectionName: string, pdfName: string, pdfPage: number) => { + const fetchAnswer = async (userQuery: string, caseName: string, pdfName: string, pdfPage: number) => { try { const response = await fetch(`${noOcrApiUrl}/vllm_call`, { method: 'POST', @@ -75,7 +75,7 @@ export default function Search() { }, body: new URLSearchParams({ user_query: userQuery, - collection_name: collectionName, + case_name: caseName, pdf_name: pdfName, pdf_page: pdfPage.toString(), }), @@ -106,21 +106,21 @@ export default function Search() {
-
diff --git a/no-ocr-ui/src/components/about/Features.tsx b/no-ocr-ui/src/components/about/Features.tsx index 8468bea..56d93c3 100644 --- a/no-ocr-ui/src/components/about/Features.tsx +++ b/no-ocr-ui/src/components/about/Features.tsx @@ -21,7 +21,7 @@ export function Features() {
{ - if (!window.confirm('Are you sure you want to delete this collection?')) return; + if (!window.confirm('Are you sure you want to delete this case?')) return; setIsDeleting(true); try { - const response = await fetch(`${noOcrApiUrl}/delete_collection/${collection.name}`, { + const response = await fetch(`${noOcrApiUrl}/delete_case/${caseItem.name}`, { method: 'DELETE', }); - if (!response.ok) throw new Error('Failed to delete collection'); - // Collection will be removed from the list by the parent's useEffect + if (!response.ok) throw new Error('Failed to delete case'); + // Case will be removed from the list by the parent's useEffect } catch (error) { - console.error('Error deleting collection:', error); - alert('Failed to delete collection'); + console.error('Error deleting case:', error); + alert('Failed to delete case'); } finally { setIsDeleting(false); } @@ -35,15 +35,15 @@ export function CollectionCard({ collection }: CollectionCardProps) {
-        {collection.name}
+        {caseItem.name}
-        {collection.documentCount} document{collection.documentCount !== 1 ? 's' : ''}
+        {caseItem.documentCount} document{caseItem.documentCount !== 1 ? 's' : ''}
-        Created {formatDate(collection.createdAt)}
+        Created {formatDate(caseItem.createdAt)}
diff --git a/no-ocr-ui/src/components/collections/CaseList.tsx b/no-ocr-ui/src/components/collections/CaseList.tsx new file mode 100644 index 0000000..3bd33ae --- /dev/null +++ b/no-ocr-ui/src/components/collections/CaseList.tsx @@ -0,0 +1,49 @@ +import { useEffect, useState } from 'react'; +import { Case } from '../../types/collection'; +import { CaseCard } from './CaseCard'; +import { LoadingSpinner } from '../shared/LoadingSpinner'; +import { EmptyState } from '../shared/EmptyState'; +import { noOcrApiUrl } from '../../config/api'; + +export function CaseList() { + const [cases, setCases] = useState([]); + const [isLoading, setIsLoading] = useState(true); + + useEffect(() => { + async function fetchCases() { + try { + const response = await fetch(`${noOcrApiUrl}/get_cases`); + if (!response.ok) throw new Error('Network response was not ok'); + const data = await response.json(); + setCases(data.cases || []); + } catch (error) { + console.error('Error fetching cases:', error); + } finally { + setIsLoading(false); + } + } + + fetchCases(); + }, []); + + if (isLoading) return ; + + if (cases.length === 0) { + return ( + + ); + } + + return ( +
+ {cases.map((caseItem) => ( + + ))} +
+ ); +} \ No newline at end of file diff --git a/no-ocr-ui/src/components/collections/CollectionList.tsx b/no-ocr-ui/src/components/collections/CollectionList.tsx deleted file mode 100644 index 24952c9..0000000 --- a/no-ocr-ui/src/components/collections/CollectionList.tsx +++ /dev/null @@ -1,49 +0,0 @@ -import { useEffect, useState } from 'react'; -import { Collection } from '../../types/collection'; -import { CollectionCard } from './CollectionCard'; -import { LoadingSpinner } from '../shared/LoadingSpinner'; -import { EmptyState } from '../shared/EmptyState'; -import { noOcrApiUrl } from '../../config/api'; - -export function CollectionList() { - const [collections, setCollections] = useState([]); - const [isLoading, setIsLoading] = useState(true); - - useEffect(() => { - async function fetchCollections() { - try { - const response = await fetch(`${noOcrApiUrl}/get_collections`); - if (!response.ok) throw new Error('Network response was not ok'); - const data = await response.json(); - setCollections(data.collections || []); - } catch (error) { - console.error('Error fetching collections:', error); - } finally { - setIsLoading(false); - } - } - - fetchCollections(); - }, []); - - if (isLoading) return ; - - if (collections.length === 0) { - return ( - - ); - } - - return ( -
- {collections.map((collection) => ( - - ))} -
- ); -} \ No newline at end of file diff --git a/no-ocr-ui/src/components/layout/Navbar.tsx b/no-ocr-ui/src/components/layout/Navbar.tsx index aad0964..d2cf098 100644 --- a/no-ocr-ui/src/components/layout/Navbar.tsx +++ b/no-ocr-ui/src/components/layout/Navbar.tsx @@ -28,14 +28,14 @@ export default function Navbar() { label="AI Search" /> } - label="Create Collection" + label="Create Case" /> } - label="Collections" + label="Cases" /> )} diff --git a/no-ocr-ui/src/components/search/CollectionSelect.tsx b/no-ocr-ui/src/components/search/CollectionSelect.tsx index d4aacd9..f734e66 100644 --- a/no-ocr-ui/src/components/search/CollectionSelect.tsx +++ b/no-ocr-ui/src/components/search/CollectionSelect.tsx @@ -1,32 +1,32 @@ -import { Collection } from '../../types/collection'; +import { Case } from '../../types/collection'; -interface CollectionSelectProps { - collections: Collection[]; - selectedCollection: string; - onCollectionChange: (collectionId: string) => void; +interface CaseSelectProps { + cases: Case[]; + selectedCase: string; + onCaseChange: (caseId: string) => void; } -export function CollectionSelect({ - collections, - selectedCollection, - onCollectionChange, -}: CollectionSelectProps) { +export function CaseSelect({ + cases, + selectedCase, + onCaseChange, +}: CaseSelectProps) { return (
-
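
The renamed endpoints map one-to-one onto the old collection routes. For orientation only (not part of the patch), here is a minimal Python sketch of how the renamed API could be exercised once the docker-compose stack is running; the base URL (the compose file maps the API service to port 5000) and the sample PDF path are assumptions.

import requests

BASE_URL = "http://localhost:5000"  # assumption: docker-compose.yml publishes the API on port 5000

# Create a case from a local PDF (field names match create_new_case in api.py)
with open("example.pdf", "rb") as pdf:  # hypothetical file
    resp = requests.post(
        f"{BASE_URL}/create_case",
        data={"case_name": "demo-case"},
        files=[("files", ("example.pdf", pdf, "application/pdf"))],
    )
resp.raise_for_status()
print(resp.json()["case_info"]["status"])  # "done" once ingestion has finished

# List all cases
print(requests.get(f"{BASE_URL}/get_cases").json())

# Search inside the case (form-encoded, mirroring the UI's Search component)
search = requests.post(
    f"{BASE_URL}/search",
    data={"user_query": "what is the revenue?", "case_name": "demo-case"},
)
print(search.json().get("search_results", []))

# Clean up
requests.delete(f"{BASE_URL}/delete_case/demo-case")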