Skip to content

Commit

Permalink
Add PDF OCR
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-piles committed Dec 12, 2024
1 parent fc89b40 commit 99f8de3
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 32 deletions.
2 changes: 1 addition & 1 deletion docker-compose-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ services:
worker-pdf-layout:
container_name: "worker-pdf-layout"
entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"]
image: ghcr.io/huridocs/pdf-document-layout-analysis:0.0.16
image: ghcr.io/huridocs/pdf-document-layout-analysis:0.0.21
init: true
restart: unless-stopped
ports:
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ services:
worker-pdf-layout-gpu:
container_name: "worker-pdf-layout-no-gpu"
entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"]
image: ghcr.io/huridocs/pdf-document-layout-analysis:0.0.16
image: ghcr.io/huridocs/pdf-document-layout-analysis:0.0.21
init: true
restart: unless-stopped
network_mode: host
Expand Down
11 changes: 9 additions & 2 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from os.path import join

import pymongo
import requests
from fastapi import FastAPI, HTTPException, File, UploadFile
import sys

Expand All @@ -12,7 +13,7 @@
from starlette.responses import PlainTextResponse, FileResponse

from catch_exceptions import catch_exceptions
from configuration import MONGO_HOST, MONGO_PORT, service_logger, OCR_OUTPUT
from configuration import MONGO_HOST, MONGO_PORT, service_logger, OCR_OUTPUT, DOCUMENT_LAYOUT_ANALYSIS_URL
from PdfFile import PdfFile
from get_paragraphs import get_paragraphs
from get_xml import get_xml
Expand Down Expand Up @@ -42,11 +43,17 @@ async def lifespan(app: FastAPI):


@app.get("/")
async def info():
async def root():
service_logger.info("Get PDF paragraphs info endpoint")
return sys.version


@app.get("/info")
@catch_exceptions
async def info():
return requests.get(f"{DOCUMENT_LAYOUT_ANALYSIS_URL}/info").json()


@app.get("/error")
async def error():
service_logger.error("This is a test error from the error endpoint")
Expand Down
2 changes: 1 addition & 1 deletion src/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
import graypy

QUEUES_NAMES = os.environ.get("QUEUES_NAMES", "segmentation development_segmentation")
QUEUES_NAMES = os.environ.get("QUEUES_NAMES", "segmentation development_segmentation ocr development_ocr")

SERVICE_HOST = os.environ.get("SERVICE_HOST", "http://127.0.0.1")
SERVICE_PORT = os.environ.get("SERVICE_PORT", "5051")
Expand Down
31 changes: 12 additions & 19 deletions src/delete_queues.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from sys import maxsize

from redis import exceptions
from rsmq import RedisSMQ

Expand All @@ -10,25 +12,16 @@
def delete_queues():
try:
for queue_name in QUEUES_NAMES.split():
queue = RedisSMQ(
host=REDIS_HOST,
port=REDIS_PORT,
qname=queue_name + "_tasks",
quiet=False,
)

queue.deleteQueue().exceptions(False).execute()
queue.createQueue().exceptions(False).execute()

queue = RedisSMQ(
host=REDIS_HOST,
port=REDIS_PORT,
qname=queue_name + "_results",
quiet=False,
)

queue.deleteQueue().exceptions(False).execute()
queue.createQueue().exceptions(False).execute()
for suffix in ["_tasks", "_results"]:
queue = RedisSMQ(
host=REDIS_HOST,
port=REDIS_PORT,
qname=queue_name + suffix,
quiet=False,
)

queue.deleteQueue().exceptions(False).execute()
queue.createQueue(maxsize=-1, vt=120).exceptions(False).execute()

print("Queues properly deleted")

Expand Down
3 changes: 3 additions & 0 deletions src/extract_segments.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path

from configuration import DOCUMENT_LAYOUT_ANALYSIS_URL, service_logger, USE_FAST, OCR_OUTPUT
Expand Down Expand Up @@ -58,7 +59,9 @@ def ocr_pdf(task: Task) -> bool:

if results and results.status_code == 200:
results_path = Path(OCR_OUTPUT, task.tenant, task.params.filename)
os.makedirs(results_path.parent, exist_ok=True)
results_path.write_bytes(results.content)
path.unlink()
return True

raise RuntimeError(f"Error OCR document: {results.status_code} - {results.text}")
69 changes: 61 additions & 8 deletions src/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,66 @@ def test_one_token_per_page_pdf(self):
self.assertEqual(response_json[0]["page_number"], 1)
self.assertEqual(response_json[1]["page_number"], 2)

def async_ocr(self, pdf_file_name, language) -> list[dict[str, any]]:
namespace = "async_ocr"

with open(f"{configuration.APP_PATH}/test_files/{pdf_file_name}", "rb") as stream:
files = {"file": stream}
requests.post(f"{self.service_url}/upload/{namespace}", files=files)

task = Task(
tenant=namespace,
task="ocr",
params=Params(filename=pdf_file_name, language=language),
)

queue = RedisSMQ(host="127.0.0.1", port="6379", qname="ocr_tasks")
queue.sendMessage().message(str(task.model_dump_json())).execute()

extraction_message = self.get_redis_message()
ocr_response = requests.get(extraction_message.file_url)
segmentation_response = requests.post(f"{self.service_url}", files={"file": ocr_response.content})
return segmentation_response.json()

def test_async_ocr(self):
paragraphs_per_page = self.async_ocr("ocr-sample-english.pdf", language="en")
self.assertEqual(1, len(paragraphs_per_page))
self.assertEqual("Test text OCR", paragraphs_per_page[0]["text"])

def test_async_ocr_specific_language(self):
paragraphs_per_page = self.async_ocr("ocr-sample-french.pdf", language="fr")
self.assertEqual(1, len(paragraphs_per_page))
self.assertEqual("Où puis-je m'en procurer", paragraphs_per_page[0]["text"])

def test_error_ocr(self):
tenant = "end_to_end_test_error"
pdf_file_name = "error_pdf.pdf"
queue = RedisSMQ(host="127.0.0.1", port="6379", qname="segmentation_tasks")

with open(f"{configuration.APP_PATH}/test_files/{pdf_file_name}", "rb") as stream:
files = {"file": stream}
requests.post(f"{self.service_url}/upload/{tenant}", files=files)

task = Task(tenant=tenant, task="ocr", params=Params(filename=pdf_file_name))

queue.sendMessage().message(task.model_dump_json()).execute()

extraction_message = self.get_redis_message()

self.assertEqual(tenant, extraction_message.tenant)
self.assertEqual("ocr", extraction_message.task)
self.assertEqual("error_pdf.pdf", extraction_message.params.filename)
self.assertEqual(False, extraction_message.success)

@staticmethod
def get_redis_message() -> ResultMessage:
queue = RedisSMQ(host="127.0.0.1", port="6379", qname="segmentation_results", quiet=True)

for i in range(80):
time.sleep(3)
message = queue.receiveMessage().exceptions(False).execute()
if message:
queue.deleteMessage(id=message["id"]).execute()
return ResultMessage(**json.loads(message["message"]))
queues_names = ["segmentation", "ocr"]

for i in range(160):
for queue_name in queues_names:
time.sleep(1)
queue = RedisSMQ(host="127.0.0.1", port="6379", qname=f"{queue_name}_results", quiet=True)
message = queue.receiveMessage().exceptions(False).execute()
if message:
queue.deleteMessage(id=message["id"]).execute()
return ResultMessage(**json.loads(message["message"]))
Binary file added src/test_files/ocr-sample-english.pdf
Binary file not shown.
Binary file added src/test_files/ocr-sample-french.pdf
Binary file not shown.

0 comments on commit 99f8de3

Please sign in to comment.