From 447f0cab6329e09f18af52d9b4aa2b1ab1f51aa3 Mon Sep 17 00:00:00 2001
From: heinpa
Date: Thu, 22 Aug 2024 18:09:43 +0200
Subject: [PATCH] add fastapi support to MBart component

---
Review notes (text in this section is ignored by `git am`):

- boot.sh: switch the shebang to bash. The new required-variable check uses
  bash-only features (`declare -a`, `[[ ... ]]`, `${!param}` indirect
  expansion, arrays) that do not work under a POSIX `sh` such as dash.
- component/__init__.py: import `mt_mbart_nlp`, the module this patch adds
  `router`, `translate_to_one` and `translate_to_all` to; the patch does not
  create or rename any `mt_mbart` module.
- component/__init__.py: the `/translate_to_one` and `/translate_to_all`
  handlers called their own names and therefore recursed forever; they now
  delegate to the functions in `mt_mbart_nlp`.
- component/__init__.py: drop `from flask import Flask`; Flask is removed
  from requirements.txt by this same patch and `Flask` is no longer used.
- The blob hashes on the `index` lines are those of the original commit.

 qanary-component-MT-Python-MBart/boot.sh      | 35 +++++++++++----
 .../component/__init__.py                     | 35 +++++++++++-----
 .../component/mt_mbart_nlp.py                 | 42 ++++++++++++-------
 .../requirements.txt                          |  3 +-
 4 files changed, 80 insertions(+), 35 deletions(-)

diff --git a/qanary-component-MT-Python-MBart/boot.sh b/qanary-component-MT-Python-MBart/boot.sh
index e40993cc..ccda09fa 100755
--- a/qanary-component-MT-Python-MBart/boot.sh
+++ b/qanary-component-MT-Python-MBart/boot.sh
@@ -1,6 +1,26 @@
-#!/bin/sh
+#!/bin/bash
+export $(grep -v "^#" < .env)
 
-export $(grep -v '^#' .env | xargs)
+# check required parameters
+declare -a required_vars=(
+"SPRING_BOOT_ADMIN_URL"
+"SERVER_HOST"
+"SERVER_PORT"
+"SPRING_BOOT_ADMIN_USERNAME"
+"SPRING_BOOT_ADMIN_PASSWORD"
+"SERVICE_NAME_COMPONENT"
+"SERVICE_DESCRIPTION_COMPONENT"
+# TODO: other?
+)
+
+for param in ${required_vars[@]};
+do
+    if [[ -z ${!param} ]]; then
+        echo "Required variable \"$param\" is not set!"
+        echo "The required variables are: ${required_vars[@]}"
+        exit 4
+    fi
+done
 
 echo Downloading the models
 
@@ -8,10 +28,7 @@ python -c "from utils.model_utils import load_models_and_tokenizers; load_models
 
 echo Downloading the model finished
 
-echo SERVER_PORT: $SERVER_PORT
-echo Qanary pipeline at SPRING_BOOT_ADMIN_URL: $SPRING_BOOT_ADMIN_URL
-
-if [ -n $SERVER_PORT ]
-then
-    exec gunicorn -b :$SERVER_PORT --access-logfile - --error-logfile - run:app # refer to the gunicorn documentation for more options
-fi
+echo The port number is: $SERVER_PORT
+echo The host is: $SERVER_HOST
+echo The Qanary pipeline URL is: $SPRING_BOOT_ADMIN_URL
+exec uvicorn run:app --host 0.0.0.0 --port $SERVER_PORT --log-level warning
diff --git a/qanary-component-MT-Python-MBart/component/__init__.py b/qanary-component-MT-Python-MBart/component/__init__.py
index 6cd66870..bac36347 100644
--- a/qanary-component-MT-Python-MBart/component/__init__.py
+++ b/qanary-component-MT-Python-MBart/component/__init__.py
@@ -1,5 +1,6 @@
-from component.mt_mbart_nlp import mt_mbart_nlp_bp
-from flask import Flask
+from component import mt_mbart_nlp
+from fastapi import FastAPI
+from fastapi.responses import RedirectResponse, Response, JSONResponse
 
 version = "0.2.0"
 
@@ -8,19 +9,35 @@
 
 # service status information
 healthendpoint = "/health"
-
 aboutendpoint = "/about"
+translateendpoint = "/translate"
+# TODO: add languages endpoint?
 
 # init Flask app and add externalized service information
-app = Flask(__name__)
-app.register_blueprint(mt_mbart_nlp_bp)
+app = FastAPI(docs_url="/swagger-ui.html")
+app.include_router(mt_mbart_nlp.router)
+
+
+@app.get("/")
+async def main():
+    return RedirectResponse("/about")
+
 
-@app.route(healthendpoint, methods=["GET"])
+@app.get(healthendpoint)
 def health():
     """required health endpoint for callback of Spring Boot Admin server"""
-    return "alive"
+    return Response("alive", media_type="text/plain")
 
-@app.route(aboutendpoint, methods=["GET"])
+@app.get(aboutendpoint)
 def about():
     """required about endpoint for callback of Srping Boot Admin server"""
-    return "about" # TODO: replace this with a service description from configuration
+    return Response("Translates questions into English", media_type="text/plain")
+
+@app.get(translateendpoint+"_to_one", description="", tags=["Translate"])
+def translate_to_one(text: str, source_lang: str, target_lang: str):
+    return JSONResponse(mt_mbart_nlp.translate_to_one(text, source_lang, target_lang))
+
+@app.get(translateendpoint+"_to_all", description="", tags=["Translate"])
+def translate_to_all(text: str, source_lang: str):
+    return JSONResponse(mt_mbart_nlp.translate_to_all(text, source_lang))
+
diff --git a/qanary-component-MT-Python-MBart/component/mt_mbart_nlp.py b/qanary-component-MT-Python-MBart/component/mt_mbart_nlp.py
index afd88cff..d5d79692 100644
--- a/qanary-component-MT-Python-MBart/component/mt_mbart_nlp.py
+++ b/qanary-component-MT-Python-MBart/component/mt_mbart_nlp.py
@@ -1,14 +1,15 @@
 import logging
 import os
-from flask import Blueprint, jsonify, request
 from qanary_helpers.qanary_queries import get_text_question_in_graph, insert_into_triplestore
 from qanary_helpers.language_queries import get_translated_texts_in_triplestore, get_texts_with_detected_language_in_triplestore, question_text_with_language, create_annotation_of_question_translation
 from utils.model_utils import load_models_and_tokenizers
 from utils.lang_utils import translation_options, LANG_CODE_MAP
+from fastapi import APIRouter, Request
+from fastapi.responses import JSONResponse
 
 logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
 
-mt_mbart_nlp_bp = Blueprint("mt_mbart_nlp_bp", __name__, template_folder="templates")
+router = APIRouter()
 
 SERVICE_NAME_COMPONENT = os.environ["SERVICE_NAME_COMPONENT"]
 
@@ -38,13 +39,30 @@ def translate_input(text:str, source_lang: str, target_lang: str) -> str:
     return translation
 
 
-@mt_mbart_nlp_bp.route("/annotatequestion", methods=["POST"])
-def qanary_service():
+def translate_to_one(text: str, source_lang: str, target_lang: str):
+    translation = translate_input(text, source_lang, target_lang)
+    return {target_lang: translation}
+
+
+def translate_to_all(text: str, source_lang: str):
+    translations = list()
+    for target_lang in translation_options[source_lang]:
+        translation = translate_input(text, source_lang, target_lang)
+        translations.append({
+            target_lang: translation
+        })
+    return translations
+
+
+@router.post("/annotatequestion", description="", tags=["Qanary"])
+async def qanary_service(request: Request):
     """the POST endpoint required for a Qanary service"""
 
-    triplestore_endpoint = request.json["values"]["urn:qanary#endpoint"]
-    triplestore_ingraph = request.json["values"]["urn:qanary#inGraph"]
-    triplestore_outgraph = request.json["values"]["urn:qanary#outGraph"]
+    request_json = await request.json()
+
+    triplestore_endpoint = request_json["values"]["urn:qanary#endpoint"]
+    triplestore_ingraph = request_json["values"]["urn:qanary#inGraph"]
+    triplestore_outgraph = request_json["values"]["urn:qanary#outGraph"]
     logging.info("endpoint: %s, inGraph: %s, outGraph: %s" % \
                  (triplestore_endpoint, triplestore_ingraph, triplestore_outgraph))
 
@@ -78,7 +96,7 @@ def qanary_service():
             else:
                 logging.error(f"result is empty string!")
 
-    return jsonify(request.get_json())
+    return JSONResponse(request_json)
 
 
 def find_source_texts_in_triplestore(triplestore_endpoint: str, graph_uri: str, lang: str) -> list[question_text_with_language]:
@@ -101,11 +119,3 @@ def find_source_texts_in_triplestore(triplestore_endpoint: str, graph_uri: str,
         logging.warning(f"No source texts with language {lang} could be found In the triplestore!")
 
     return source_texts
-
-
-@mt_mbart_nlp_bp.route("/", methods=["GET"])
-def index():
-    """examplary GET endpoint"""
-
-    logging.info("host_url: %s" % (request.host_url))
-    return "Python MT MBart Qanary component"
diff --git a/qanary-component-MT-Python-MBart/requirements.txt b/qanary-component-MT-Python-MBart/requirements.txt
index 3fdfe222..7cde1026 100644
--- a/qanary-component-MT-Python-MBart/requirements.txt
+++ b/qanary-component-MT-Python-MBart/requirements.txt
@@ -1,4 +1,4 @@
-Flask==3.0.3
+fastapi==0.109.1
 pytest==8.3.2
 pytest-env==1.1.3
 qanary_helpers==0.2.2
@@ -7,3 +7,4 @@ SPARQLWrapper==2.0.0
 torch==2.4.0
 transformers==4.44.0
 qanary-helpers==0.2.2
+uvicorn==0.30.1