Skip to content

Commit

Permalink
allow configuration of source and target languages
Browse files Browse the repository at this point in the history
  • Loading branch information
heinpa committed Aug 20, 2024
1 parent d26dfc1 commit 84eb6c2
Show file tree
Hide file tree
Showing 11 changed files with 292 additions and 131 deletions.
2 changes: 1 addition & 1 deletion qanary-component-MT-Python-NLLB/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.7
FROM python:3.10

COPY requirements.txt ./

Expand Down
6 changes: 4 additions & 2 deletions qanary-component-MT-Python-NLLB/boot.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

export $(grep -v '^#' .env | xargs)

echo Downloading the model
python -c 'from transformers import AutoModelForSeq2SeqLM, AutoTokenizer; model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M") ; tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")'
echo Downloading the models

python -c "from utils.model_utils import load_models_and_tokenizers; load_models_and_tokenizers(); "

echo Downloading the model finished

echo The port number is: $SERVER_PORT
Expand Down
2 changes: 1 addition & 1 deletion qanary-component-MT-Python-NLLB/component/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from component.mt_nllb import mt_nllb_bp
from flask import Flask

version = "0.1.3"
version = "0.2.0"

# default config file
configfile = "app.conf"
Expand Down
171 changes: 72 additions & 99 deletions qanary-component-MT-Python-NLLB/component/mt_nllb.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,23 @@
from langdetect import detect
import logging
import os
from flask import Blueprint, jsonify, request
from qanary_helpers.qanary_queries import get_text_question_in_graph, insert_into_triplestore
from qanary_helpers.language_queries import get_translated_texts_in_triplestore, get_texts_with_detected_language_in_triplestore, question_text_with_language, create_annotation_of_question_translation
from utils.model_utils import load_models_and_tokenizers
from utils.lang_utils import translation_options, LANG_CODE_MAP

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)

mt_nllb_bp = Blueprint("mt_nllb_bp", __name__, template_folder="templates")

SERVICE_NAME_COMPONENT = os.environ["SERVICE_NAME_COMPONENT"]
SOURCE_LANG = os.environ["SOURCE_LANGUAGE"]
TARGET_LANG = os.environ["TARGET_LANGUAGE"]

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
lang_code_map = {
'en': 'eng_Latn',
'de': 'deu_Latn',
'ru': 'rus_Cyrl',
'fr': 'fra_Latn',
'es': 'spa_Latn',
'pt': 'por_Latn'
}

model, tokenizer = load_models_and_tokenizers()

@mt_nllb_bp.route("/annotatequestion", methods=["POST"])
def qanary_service():
"""the POST endpoint required for a Qanary service"""

triplestore_endpoint = request.json["values"]["urn:qanary#endpoint"]
triplestore_ingraph = request.json["values"]["urn:qanary#inGraph"]
triplestore_outgraph = request.json["values"]["urn:qanary#outGraph"]
logging.info("endpoint: %s, inGraph: %s, outGraph: %s" % \
(triplestore_endpoint, triplestore_ingraph, triplestore_outgraph))

text = get_text_question_in_graph(triplestore_endpoint=triplestore_endpoint,
graph=triplestore_ingraph)[0]["text"]
question_uri = get_text_question_in_graph(triplestore_endpoint=triplestore_endpoint,
graph=triplestore_ingraph)[0]["uri"]
logging.info(f"Question text: {text}")

if SOURCE_LANG != None and len(SOURCE_LANG.strip()) > 0:
lang = SOURCE_LANG
logging.info("Using custom SOURCE_LANGUAGE")
else:
lang = detect(text)
logging.info("No SOURCE_LANGUAGE specified, using langdetect!")
logging.info(f"source language: {lang}")


## MAIN FUNCTIONALITY
tokenizer.src_lang = lang_code_map[lang]
def translate_input(text: str, source_lang: str, target_lang: str) -> str:
logging.info(f"translating \"{text}\" from \"{source_lang}\" to \"{target_lang}\"")
tokenizer.src_lang = LANG_CODE_MAP[source_lang]
logging.info(f"source language mapped code: {tokenizer.src_lang}")
batch = tokenizer(text, return_tensors="pt")

Expand All @@ -64,66 +29,74 @@ def qanary_service():
# Perform the translation and decode the output
generated_tokens = model.generate(
**batch,
forced_bos_token_id=tokenizer.lang_code_to_id[lang_code_map[TARGET_LANG]])
forced_bos_token_id=tokenizer.convert_tokens_to_ids(LANG_CODE_MAP[target_lang]))
result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
translation = result.replace("\"", "\\\"") #keep quotation marks that are part of the translation
logging.info(f"result: \"{translation}\"")
return translation


def find_source_texts_in_triplestore(triplestore_endpoint: str, graph_uri: str, lang: str) -> list[question_text_with_language]:
    """Collect texts in the given language that can serve as translation sources.

    Two kinds of candidates are gathered from the triplestore:
    1. question texts whose language was detected by a language-detection (LD)
       component and matches *lang*,
    2. texts that were already translated into *lang* by a machine-translation
       (MT) component.

    Parameters:
        triplestore_endpoint -- URL of the triplestore endpoint to query
        graph_uri -- URI of the graph holding the annotations
        lang -- source language code (e.g. "en") to filter by

    Returns:
        list of question_text_with_language objects; empty if nothing matches
    """
    source_texts = []

    # texts whose detected language matches (annotated by an LD component)
    source_texts.extend(
        get_texts_with_detected_language_in_triplestore(triplestore_endpoint, graph_uri, lang))

    # texts already translated into the requested language (MT annotations)
    source_texts.extend(
        get_translated_texts_in_triplestore(triplestore_endpoint, graph_uri, lang))

    # an empty result is not an error, but it is worth surfacing in the logs
    if not source_texts:
        logging.warning(f"No source texts with language {lang} could be found in the triplestore!")

    return source_texts


@mt_nllb_bp.route("/annotatequestion", methods=["POST"])
def qanary_service():
"""the POST endpoint required for a Qanary service"""

triplestore_endpoint = request.json["values"]["urn:qanary#endpoint"]
triplestore_ingraph = request.json["values"]["urn:qanary#inGraph"]
triplestore_outgraph = request.json["values"]["urn:qanary#outGraph"]
logging.info("endpoint: %s, inGraph: %s, outGraph: %s" % \
(triplestore_endpoint, triplestore_ingraph, triplestore_outgraph))

# building SPARQL query TODO: verify this annotation AnnotationOfQuestionTranslation ??
SPARQLqueryAnnotationOfQuestionTranslation = """
PREFIX qa: <http://www.wdaqua.eu/qa#>
PREFIX oa: <http://www.w3.org/ns/openannotation/core/>
PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
INSERT {{
GRAPH <{uuid}> {{
?a a qa:AnnotationOfQuestionTranslation ;
oa:hasTarget <{qanary_question_uri}> ;
oa:hasBody "{translation_result}"@{target_lang} ;
oa:annotatedBy <urn:qanary:{app_name}> ;
oa:annotatedAt ?time .
}}
}}
WHERE {{
BIND (IRI(str(RAND())) AS ?a) .
BIND (now() as ?time)
}}""".format(
uuid=triplestore_ingraph,
qanary_question_uri=question_uri,
translation_result=result.replace("\"", "\\\""), #keep quotation marks that are part of the translation
target_lang=TARGET_LANG,
app_name=SERVICE_NAME_COMPONENT
)

SPARQLqueryAnnotationOfQuestionLanguage = """
PREFIX qa: <http://www.wdaqua.eu/qa#>
PREFIX oa: <http://www.w3.org/ns/openannotation/core/>
PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
INSERT {{
GRAPH <{uuid}> {{
?b a qa:AnnotationOfQuestionLanguage ;
oa:hasTarget <{qanary_question_uri}> ;
oa:hasBody "{src_lang}"^^xsd:string ;
oa:annotatedBy <urn:qanary:{app_name}> ;
oa:annotatedAt ?time .
}}
}}
WHERE {{
BIND (IRI(str(RAND())) AS ?b) .
BIND (now() as ?time)
}}""".format(
uuid=triplestore_ingraph,
qanary_question_uri=question_uri,
src_lang=lang,
app_name=SERVICE_NAME_COMPONENT
)

logging.info(f'SPARQL: {SPARQLqueryAnnotationOfQuestionTranslation}')
logging.info(f'SPARQL: {SPARQLqueryAnnotationOfQuestionLanguage}')
# inserting new data to the triplestore
insert_into_triplestore(triplestore_endpoint, SPARQLqueryAnnotationOfQuestionTranslation)
insert_into_triplestore(triplestore_endpoint, SPARQLqueryAnnotationOfQuestionLanguage)
text_question_in_graph = get_text_question_in_graph(triplestore_endpoint=triplestore_endpoint, graph=triplestore_ingraph)
question_text = text_question_in_graph[0]['text']
logging.info(f'Original question text: {question_text}')


# Collect texts to be translated (group by source language)

for source_lang in translation_options.keys():
source_texts = find_source_texts_in_triplestore(
triplestore_endpoint=triplestore_endpoint,
graph_uri=triplestore_ingraph,
lang=source_lang
)

# translate source texts into specified target languages
for target_lang in translation_options[source_lang]:
for source_text in source_texts:
translation = translate_input(source_text.get_text(), source_lang, target_lang)
if len(translation.strip()) > 0:
SPARQLqueryAnnotationOfQuestionTranslation = create_annotation_of_question_translation(
graph_uri=triplestore_ingraph,
question_uri=source_text.get_uri(),
translation=translation,
translation_language=target_lang,
app_name=SERVICE_NAME_COMPONENT
)
insert_into_triplestore(triplestore_endpoint, SPARQLqueryAnnotationOfQuestionTranslation)
else:
logging.error(f"result is empty string!")

return jsonify(request.get_json())

Expand Down
10 changes: 4 additions & 6 deletions qanary-component-MT-Python-NLLB/pytest.ini
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
[pytest]
log_cli = 0
log_cli = 1
log_cli_level = INFO
log_cli_format = %(asctime)s [%(levelname)8s] [%(filename)s:%(lineno)s] %(message)s
log_cli_date_format=%Y-%m-%d %H:%M:%S
env =
SERVER_PORT=40120
SERVICE_PORT=40120
SERVICE_HOST=http://public-component-host
SPRING_BOOT_ADMIN_URL=http://qanary-pipeline-host:40111
SERVER_HOST=http://public-component-host
SPRING_BOOT_ADMIN_CLIENT_INSTANCE_SERVICE-BASE-URL=http://public-component-host:40120
SPRING_BOOT_ADMIN_USERNAME=admin
SPRING_BOOT_ADMIN_PASSWORD=admin
SERVICE_NAME_COMPONENT=MT-NLLB
SERVICE_NAME_COMPONENT=MT-NLLB-Component
SERVICE_DESCRIPTION_COMPONENT=Translates question to English
SOURCE_LANGUAGE=de
TARGET_LANGUAGE=en
19 changes: 8 additions & 11 deletions qanary-component-MT-Python-NLLB/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
Flask
langdetect==1.0.9
mock==3.0.5
python-dotenv==0.21.1
Flask==3.0.3
pytest==8.3.2
pytest-env==1.1.3
qanary_helpers==0.2.2
transformers==4.41.0
sentencepiece==0.1.97
torch==2.3.0
gunicorn==20.1.0
protobuf==3.20.*
pytest
pytest-env
SentencePiece==0.2.0
SPARQLWrapper==2.0.0
torch==2.4.0
transformers==4.44.0
qanary-helpers==0.2.2
70 changes: 70 additions & 0 deletions qanary-component-MT-Python-NLLB/tests/test_lang_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import logging
from unittest import mock
from unittest import TestCase
import os
import importlib

class TestLangUtils(TestCase):
    """Tests for SOURCE_LANGUAGE/TARGET_LANGUAGE handling in utils.lang_utils.

    utils.lang_utils builds its ``translation_options`` mapping at import time
    from environment variables, so every test patches ``os.environ`` and
    reloads the module to re-evaluate that module-level configuration.
    """

    logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)

    @mock.patch.dict(os.environ, {'SOURCE_LANGUAGE': 'fr'})
    def test_only_one_source_language(self):
        import utils.lang_utils
        importlib.reload(utils.lang_utils)  # re-run module-level env parsing
        from utils.lang_utils import translation_options
        assert 'fr' in translation_options.keys()
        assert len(translation_options.keys()) == 1


    @mock.patch.dict(os.environ, {'TARGET_LANGUAGE': 'ru'})
    def test_only_one_target_language(self):
        import utils.lang_utils
        importlib.reload(utils.lang_utils)
        from utils.lang_utils import translation_options
        # all 5 non-russian source languages should support 'ru'
        assert len(translation_options.items()) == 5
        # but each item should only contain the one target language!
        assert ('en', ['ru']) in translation_options.items()
        assert ('de', ['ru']) in translation_options.items()
        assert ('es', ['ru']) in translation_options.items()
        assert ('fr', ['ru']) in translation_options.items()
        assert ('pt', ['ru']) in translation_options.items()


    @mock.patch.dict(os.environ, {'SOURCE_LANGUAGE': 'en', 'TARGET_LANGUAGE': 'es'})
    def test_specific_source_and_target_language(self):
        import utils.lang_utils
        importlib.reload(utils.lang_utils)
        from utils.lang_utils import translation_options
        assert translation_options == {'en': ['es']}


    @mock.patch.dict(os.environ, {'SOURCE_LANGUAGE': 'zh'})
    def test_unsupported_source_language_raises_error(self):
        # the test must FAIL if no error is raised; a bare try/except
        # would silently pass even when the validation is broken
        with self.assertRaises(ValueError):
            import utils.lang_utils
            importlib.reload(utils.lang_utils)


    @mock.patch.dict(os.environ, {'SOURCE_LANGUAGE': 'en', 'TARGET_LANGUAGE': 'zh'})
    def test_unsupported_target_for_source_language_raises_error(self):
        with self.assertRaises(ValueError):
            import utils.lang_utils
            importlib.reload(utils.lang_utils)


    @mock.patch.dict(os.environ, {'TARGET_LANGUAGE': 'zh'})
    def test_unsupported_target_language_raises_error(self):
        with self.assertRaises(ValueError):
            import utils.lang_utils
            importlib.reload(utils.lang_utils)
Loading

0 comments on commit 84eb6c2

Please sign in to comment.