diff --git a/docker-compose-test.yml b/docker-compose-test.yml index 80dc49f..ea27f71 100755 --- a/docker-compose-test.yml +++ b/docker-compose-test.yml @@ -43,7 +43,7 @@ services: worker-pdf-layout: container_name: "worker-pdf-layout" entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"] - image: ghcr.io/huridocs/pdf-document-layout-analysis:0.0.16 + image: ghcr.io/huridocs/pdf-document-layout-analysis:0.0.21 init: true restart: unless-stopped ports: diff --git a/docker-compose.yml b/docker-compose.yml index 68d3046..f052239 100755 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -40,7 +40,7 @@ services: worker-pdf-layout-gpu: container_name: "worker-pdf-layout-no-gpu" entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"] - image: ghcr.io/huridocs/pdf-document-layout-analysis:0.0.16 + image: ghcr.io/huridocs/pdf-document-layout-analysis:0.0.21 init: true restart: unless-stopped network_mode: host diff --git a/src/app.py b/src/app.py index 8559ec9..6eea5e7 100755 --- a/src/app.py +++ b/src/app.py @@ -3,6 +3,7 @@ from os.path import join import pymongo +import requests from fastapi import FastAPI, HTTPException, File, UploadFile import sys @@ -12,7 +13,7 @@ from starlette.responses import PlainTextResponse, FileResponse from catch_exceptions import catch_exceptions -from configuration import MONGO_HOST, MONGO_PORT, service_logger, OCR_OUTPUT +from configuration import MONGO_HOST, MONGO_PORT, service_logger, OCR_OUTPUT, DOCUMENT_LAYOUT_ANALYSIS_URL from PdfFile import PdfFile from get_paragraphs import get_paragraphs from get_xml import get_xml @@ -42,11 +43,17 @@ async def lifespan(app: FastAPI): @app.get("/") -async def info(): +async def root(): service_logger.info("Get PDF paragraphs info endpoint") return sys.version +@app.get("/info") +@catch_exceptions +async def info(): + return requests.get(f"{DOCUMENT_LAYOUT_ANALYSIS_URL}/info").json() + + @app.get("/error") async def error(): service_logger.error("This is a test error from the error endpoint") diff --git a/src/configuration.py b/src/configuration.py index 7e2b0d9..e571a57 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -4,7 +4,7 @@ from pathlib import Path import graypy -QUEUES_NAMES = os.environ.get("QUEUES_NAMES", "segmentation development_segmentation") +QUEUES_NAMES = os.environ.get("QUEUES_NAMES", "segmentation development_segmentation ocr development_ocr") SERVICE_HOST = os.environ.get("SERVICE_HOST", "http://127.0.0.1") SERVICE_PORT = os.environ.get("SERVICE_PORT", "5051") diff --git a/src/delete_queues.py b/src/delete_queues.py index 9cf2c8c..3381194 100644 --- a/src/delete_queues.py +++ b/src/delete_queues.py @@ -1,3 +1,5 @@ +from sys import maxsize + from redis import exceptions from rsmq import RedisSMQ @@ -10,25 +12,16 @@ def delete_queues(): try: for queue_name in QUEUES_NAMES.split(): - queue = RedisSMQ( - host=REDIS_HOST, - port=REDIS_PORT, - qname=queue_name + "_tasks", - quiet=False, - ) - - queue.deleteQueue().exceptions(False).execute() - queue.createQueue().exceptions(False).execute() - - queue = RedisSMQ( - host=REDIS_HOST, - port=REDIS_PORT, - qname=queue_name + "_results", - quiet=False, - ) - - queue.deleteQueue().exceptions(False).execute() - queue.createQueue().exceptions(False).execute() + for suffix in ["_tasks", "_results"]: + queue = RedisSMQ( + host=REDIS_HOST, + port=REDIS_PORT, + qname=queue_name + suffix, + quiet=False, + ) + + queue.deleteQueue().exceptions(False).execute() + queue.createQueue(maxsize=-1, vt=120).exceptions(False).execute() print("Queues properly deleted") diff --git a/src/extract_segments.py b/src/extract_segments.py index 43a0050..da9994e 100644 --- a/src/extract_segments.py +++ b/src/extract_segments.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from configuration import DOCUMENT_LAYOUT_ANALYSIS_URL, service_logger, USE_FAST, OCR_OUTPUT @@ -58,7 +59,9 @@ def ocr_pdf(task: Task) -> bool: if results and results.status_code == 200: results_path = Path(OCR_OUTPUT, task.tenant, task.params.filename) + os.makedirs(results_path.parent, exist_ok=True) results_path.write_bytes(results.content) + path.unlink() return True raise RuntimeError(f"Error OCR document: {results.status_code} - {results.text}") diff --git a/src/test_end_to_end.py b/src/test_end_to_end.py index 2cc4372..487b6c4 100644 --- a/src/test_end_to_end.py +++ b/src/test_end_to_end.py @@ -88,13 +88,66 @@ def test_one_token_per_page_pdf(self): self.assertEqual(response_json[0]["page_number"], 1) self.assertEqual(response_json[1]["page_number"], 2) + def async_ocr(self, pdf_file_name, language) -> list[dict[str, any]]: + namespace = "async_ocr" + + with open(f"{configuration.APP_PATH}/test_files/{pdf_file_name}", "rb") as stream: + files = {"file": stream} + requests.post(f"{self.service_url}/upload/{namespace}", files=files) + + task = Task( + tenant=namespace, + task="ocr", + params=Params(filename=pdf_file_name, language=language), + ) + + queue = RedisSMQ(host="127.0.0.1", port="6379", qname="ocr_tasks") + queue.sendMessage().message(str(task.model_dump_json())).execute() + + extraction_message = self.get_redis_message() + ocr_response = requests.get(extraction_message.file_url) + segmentation_response = requests.post(f"{self.service_url}", files={"file": ocr_response.content}) + return segmentation_response.json() + + def test_async_ocr(self): + paragraphs_per_page = self.async_ocr("ocr-sample-english.pdf", language="en") + self.assertEqual(1, len(paragraphs_per_page)) + self.assertEqual("Test text OCR", paragraphs_per_page[0]["text"]) + + def test_async_ocr_specific_language(self): + paragraphs_per_page = self.async_ocr("ocr-sample-french.pdf", language="fr") + self.assertEqual(1, len(paragraphs_per_page)) + self.assertEqual("Où puis-je m'en procurer", paragraphs_per_page[0]["text"]) + + def test_error_ocr(self): + tenant = "end_to_end_test_error" + pdf_file_name = "error_pdf.pdf" + queue = RedisSMQ(host="127.0.0.1", port="6379", qname="segmentation_tasks") + + with open(f"{configuration.APP_PATH}/test_files/{pdf_file_name}", "rb") as stream: + files = {"file": stream} + requests.post(f"{self.service_url}/upload/{tenant}", files=files) + + task = Task(tenant=tenant, task="ocr", params=Params(filename=pdf_file_name)) + + queue.sendMessage().message(task.model_dump_json()).execute() + + extraction_message = self.get_redis_message() + + self.assertEqual(tenant, extraction_message.tenant) + self.assertEqual("ocr", extraction_message.task) + self.assertEqual("error_pdf.pdf", extraction_message.params.filename) + self.assertEqual(False, extraction_message.success) + @staticmethod def get_redis_message() -> ResultMessage: - queue = RedisSMQ(host="127.0.0.1", port="6379", qname="segmentation_results", quiet=True) - - for i in range(80): - time.sleep(3) - message = queue.receiveMessage().exceptions(False).execute() - if message: - queue.deleteMessage(id=message["id"]).execute() - return ResultMessage(**json.loads(message["message"])) + queues_names = ["segmentation", "ocr"] + + for i in range(160): + for queue_name in queues_names: + time.sleep(1) + queue = RedisSMQ(host="127.0.0.1", port="6379", qname=f"{queue_name}_results", quiet=True) + message = queue.receiveMessage().exceptions(False).execute() + if message: + queue.deleteMessage(id=message["id"]).execute() + return ResultMessage(**json.loads(message["message"])) diff --git a/src/test_files/ocr-sample-english.pdf b/src/test_files/ocr-sample-english.pdf new file mode 100644 index 0000000..6975963 Binary files /dev/null and b/src/test_files/ocr-sample-english.pdf differ diff --git a/src/test_files/ocr-sample-french.pdf b/src/test_files/ocr-sample-french.pdf new file mode 100644 index 0000000..dfd0eac Binary files /dev/null and b/src/test_files/ocr-sample-french.pdf differ