Skip to content

Commit

Permalink
allow for multiple source and target languages
Browse files Browse the repository at this point in the history
  • Loading branch information
heinpa committed Aug 13, 2024
1 parent 011db0a commit f511be8
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 55 deletions.
3 changes: 2 additions & 1 deletion qanary-component-MT-Python-HelsinkiNLP/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.7
FROM python:3.10

COPY requirements.txt ./

Expand All @@ -7,6 +7,7 @@ RUN pip install -r requirements.txt; exit 0
RUN pip install gunicorn

COPY component component
COPY utils utils
COPY run.py boot.sh ./

RUN chmod +x boot.sh
Expand Down
8 changes: 5 additions & 3 deletions qanary-component-MT-Python-HelsinkiNLP/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ SPRING_BOOT_ADMIN_CLIENT_INSTANCE_SERVICE-BASE-URL=http://public-component-host:
SPRING_BOOT_ADMIN_USERNAME=admin
SPRING_BOOT_ADMIN_PASSWORD=admin
SERVICE_NAME_COMPONENT=MT-Helsinki-NLP
SERVICE_DESCRIPTION_COMPONENT=Translates question to English
SOURCE_LANGUAGE=de
SERVICE_DESCRIPTION_COMPONENT=Translates questions
SOURCE_LANGUAGE_DEFAULT=de
TARGET_LANGUAGE_DEFAULT=en
```

The parameters description:
Expand All @@ -68,7 +69,8 @@ The parameters description:
* `SPRING_BOOT_ADMIN_CLIENT_INSTANCE_SERVICE-BASE-URL` -- the URL of your Qanary component (has to be visible to the Qanary pipeline)
* `SERVICE_NAME_COMPONENT` -- the name of your Qanary component (for better identification)
* `SERVICE_DESCRIPTION_COMPONENT` -- the description of your Qanary component
* `SOURCE_LANGUAGE` -- (optional) the source language of the text (the component will use langdetect if no source language is given)
* `SOURCE_LANGUAGE_DEFAULT` -- the default source language of the translation
* `TARGET_LANGUAGE_DEFAULT` -- the default target language of the translation

4. Build the Docker image:

Expand Down
7 changes: 4 additions & 3 deletions qanary-component-MT-Python-HelsinkiNLP/boot.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

export $(grep -v '^#' .env | xargs)

echo Downloading the model
python -c "from transformers.models.marian.modeling_marian import MarianMTModel; from transformers.models.marian.tokenization_marian import MarianTokenizer; supported_langs = ['ru', 'es', 'de', 'fr']; models = {lang: MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-{lang}-en'.format(lang=lang)) for lang in supported_langs}; tokenizers = {lang: MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-{lang}-en'.format(lang=lang)) for lang in supported_langs}"
echo Downloading the model finished
echo Downloading the models

python -c "from utils.model_utils import load_models_and_tokenizers; SUPPORTED_LANGS = { 'en': ['de', 'fr', 'ru', 'es'], 'de': ['en', 'fr', 'es'], 'fr': ['en', 'de', 'ru', 'es'], 'ru': ['en', 'fr', 'es'], 'es': ['en', 'de', 'fr', 'es'], }; load_models_and_tokenizers(SUPPORTED_LANGS); "

echo Downloading the model finished

echo The port number is: $SERVER_PORT
echo The Qanary pipeline URL is: $SPRING_BOOT_ADMIN_URL
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from component.mt_helsinki_nlp import mt_helsinki_nlp_bp
from flask import Flask

version = "0.1.2"
version = "0.2.0"

# default config file (use -c parameter on command line specify a custom config file)
configfile = "app.conf"
Expand Down
84 changes: 50 additions & 34 deletions qanary-component-MT-Python-HelsinkiNLP/component/mt_helsinki_nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,57 @@
import os
from flask import Blueprint, jsonify, request
from qanary_helpers.qanary_queries import get_text_question_in_graph, insert_into_triplestore
from transformers.models.marian.modeling_marian import MarianMTModel
from transformers.models.marian.tokenization_marian import MarianTokenizer
from utils.model_utils import load_models_and_tokenizers

logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
mt_helsinki_nlp_bp = Blueprint('mt_helsinki_nlp_bp', __name__, template_folder='templates')

SERVICE_NAME_COMPONENT = os.environ['SERVICE_NAME_COMPONENT']
SOURCE_LANG = os.environ["SOURCE_LANGUAGE"]
TARGET_LANG = "en" # currently only used for annotation
# TODO: no target language is set, because only 'en' is supported
# TODO: determine supported target langs and download models for that

supported_langs = ['ru', 'es', 'de', 'fr']
langid.set_languages(supported_langs)
models = {lang: MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-{lang}-en'.format(lang=lang)) for lang in supported_langs}
tokenizers = {lang: MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-{lang}-en'.format(lang=lang)) for lang in supported_langs}
# Default language pair used when no explicit source/target is provided
# per request (see detect_source_and_target_language).
SOURCE_LANG_DEFAULT = os.environ["SOURCE_LANGUAGE_DEFAULT"]
TARGET_LANG_DEFAULT = os.environ["TARGET_LANGUAGE_DEFAULT"]
# Supported translation directions: maps each source language code to the
# target language codes for which a Helsinki-NLP opus-mt model is loaded.
SUPPORTED_LANGS = {
    # source: targets
    'en': ['de', 'fr', 'ru', 'es'],
    'de': ['en', 'fr', 'es'],
    'fr': ['en', 'de', 'ru', 'es'],
    'ru': ['en', 'fr', 'es'],
    # NOTE(review): 'es' listing itself as a target ('es'->'es') looks
    # unintended — confirm whether 'ru' was meant here.
    'es': ['en', 'de', 'fr', 'es'],
}
# Fail fast at startup if the configured defaults name an unsupported pair.
if SOURCE_LANG_DEFAULT not in SUPPORTED_LANGS.keys():
    raise ValueError(f"default source language \"{SOURCE_LANG_DEFAULT}\" is not supported!")
if TARGET_LANG_DEFAULT not in SUPPORTED_LANGS[SOURCE_LANG_DEFAULT]:
    raise ValueError(f"default target language \"{TARGET_LANG_DEFAULT}\" is not supported for default source language \"{SOURCE_LANG_DEFAULT}\"!")

# Restrict langid to the languages we can actually translate from.
langid.set_languages(SUPPORTED_LANGS.keys())

# Nested dicts indexed as [source_lang][target_lang]; loaded eagerly at
# module import so the first request does not pay the download cost.
models, tokenizers = load_models_and_tokenizers(SUPPORTED_LANGS)

def detect_source_and_target_language(text: str):
    """Determine the (source, target) language pair for translating *text*.

    Currently this simply returns the configured defaults
    (SOURCE_LANG_DEFAULT, TARGET_LANG_DEFAULT). The *text* parameter is
    kept so the interface can later support dynamic detection — e.g.
    triplestore annotations, or automatic detection via
    ``langid.classify(text)`` — without changing any caller.
    """
    # NOTE(review): langid-based fallback detection was dropped in favor of
    # fixed defaults; re-enable per-question detection here if needed.
    return SOURCE_LANG_DEFAULT, TARGET_LANG_DEFAULT


def translate_input(text: str, source_lang: str, target_lang: str) -> str:
    """Translate *text* from *source_lang* to *target_lang*.

    Uses the pre-loaded MarianMT model/tokenizer pair for the requested
    direction; the pair must exist in the module-level ``models`` and
    ``tokenizers`` dicts (i.e. be listed in SUPPORTED_LANGS).
    """
    tokenizer = tokenizers[source_lang][target_lang]
    model = models[source_lang][target_lang]

    encoded = tokenizer([text], return_tensors="pt", padding=True)
    # Truncate the tokenized input to the model's maximum length of 512.
    encoded["input_ids"] = encoded["input_ids"][:, :512]
    encoded["attention_mask"] = encoded["attention_mask"][:, :512]

    # Perform the translation and decode the first (only) sequence.
    generated = model.generate(**encoded)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]


@mt_helsinki_nlp_bp.route("/annotatequestion", methods=['POST'])
Expand All @@ -34,27 +69,8 @@ def qanary_service():
question_uri = get_text_question_in_graph(triplestore_endpoint=triplestore_endpoint, graph=triplestore_ingraph)[0]['uri']
logging.info(f'Question Text: {text}')

if SOURCE_LANG != None and len(SOURCE_LANG.strip()) > 0:
lang = SOURCE_LANG
logging.info("Using custom SOURCE_LANGUAGE")
else:
lang, prob = langid.classify(text)
logging.info("No SOURCE_LANGUAGE specified, using langid!")
logging.info(f"source language: {lang}")
if lang not in supported_langs:
raise RuntimeError(f"source language {lang} is not supported!")



batch = tokenizers[lang]([text], return_tensors="pt", padding=True)

# Make sure that the tokenized text does not exceed the maximum
# allowed size of 512
batch["input_ids"] = batch["input_ids"][:, :512]
batch["attention_mask"] = batch["attention_mask"][:, :512]
# Perform the translation and decode the output
translation = models[lang].generate(**batch)
result = tokenizers[lang].batch_decode(translation, skip_special_tokens=True)[0]
source_lang, target_lang = detect_source_and_target_language(text)
result = translate_input(text, source_lang, target_lang)

# building SPARQL query TODO: verify this annotation AnnotationOfQuestionTranslation ??
SPARQLqueryAnnotationOfQuestionTranslation = """
Expand All @@ -80,7 +96,7 @@ def qanary_service():
uuid=triplestore_ingraph,
qanary_question_uri=question_uri,
translation_result=result.replace("\"", "\\\""), #keep quotation marks that are part of the translation
target_lang=TARGET_LANG,
target_lang=target_lang,
app_name=SERVICE_NAME_COMPONENT
)

Expand All @@ -105,7 +121,7 @@ def qanary_service():
""".format(
uuid=triplestore_ingraph,
qanary_question_uri=question_uri,
src_lang=lang,
src_lang=source_lang,
app_name=SERVICE_NAME_COMPONENT
)

Expand Down
3 changes: 2 additions & 1 deletion qanary-component-MT-Python-HelsinkiNLP/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ env =
SERVICE_PORT=41062
SERVICE_NAME_COMPONENT=MT-Helsinki-NLP-Component
SERVICE_DESCRIPTION_COMPONENT=MT tool that uses pre-trained models by Helsinki NLP implemented in transformers library
SOURCE_LANGUAGE=de
SOURCE_LANGUAGE_DEFAULT=de
TARGET_LANGUAGE_DEFAULT=en
19 changes: 10 additions & 9 deletions qanary-component-MT-Python-HelsinkiNLP/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
transformers
torch
SentencePiece
qanary-helpers
SPARQLWrapper
Flask
langid
pytest
pytest-env
Flask==3.0.3
langid==1.1.6
pytest==8.3.2
pytest-env==1.1.3
qanary_helpers==0.2.2
SentencePiece==0.2.0
SPARQLWrapper==2.0.0
torch==2.4.0
transformers==4.44.0
qanary-helpers==0.2.2
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class TestComponent(TestCase):

logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)

questions = list([{"uri": "urn:test-uri", "text": "was ist ein Test?"}])
questions = list([{"uri": "urn:test-uri", "text": "Was ist die Hauptstadt von Deutschland?"}])
endpoint = "urn:qanary#test-endpoint"
in_graph = "urn:qanary#test-inGraph"
out_graph = "urn:qanary#test-outGraph"
Expand All @@ -32,8 +32,6 @@ class TestComponent(TestCase):
}

def test_qanary_service(self):


logging.info("port: %s" % (os.environ["SERVICE_PORT"]))
assert os.environ["SERVICE_NAME_COMPONENT"] == "MT-Helsinki-NLP-Component"

Expand Down Expand Up @@ -77,3 +75,23 @@ def test_qanary_service(self):

# then the response is not empty
assert response_json != None


def test_translate_input(self):
    """Check that translate_input returns the expected translation for a
    few known language pairs (de→en, en→fr, en→ru)."""
    # Each case: input text, expected translation, and the direction.
    # NOTE(review): exact-match assertions are brittle against model
    # updates — the pinned transformers/model versions keep this stable.
    translations = [
        {"text": "Was ist die Hauptstadt von Deutschland?",
         "translation": "What is the capital of Germany?",
         "source_lang": "de", "target_lang": "en"},
        {"text": "What is the capital of Germany?",
         "translation": "Quelle est la capitale de l'Allemagne?",
         "source_lang": "en", "target_lang": "fr"},
        {"text": "What is the capital of Germany?",
         "translation": "Какая столица Германии?",
         "source_lang": "en", "target_lang": "ru"},
    ]

    for translation in translations:
        expected = translation["translation"]
        actual = translate_input(translation["text"], translation["source_lang"], translation["target_lang"])
        assert expected == actual

Empty file.
13 changes: 13 additions & 0 deletions qanary-component-MT-Python-HelsinkiNLP/utils/model_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from transformers.models.marian.modeling_marian import MarianMTModel
from transformers.models.marian.tokenization_marian import MarianTokenizer


def load_models_and_tokenizers(supported_langs: dict):
    """Load a MarianMT model and tokenizer for every supported pair.

    *supported_langs* maps a source language code to the list of target
    language codes it can translate to. Downloads (or loads from cache)
    the ``Helsinki-NLP/opus-mt-{src}-{tgt}`` checkpoint for each pair.

    Returns a ``(models, tokenizers)`` tuple of nested dicts, both
    indexed as ``[source_lang][target_lang]``.
    """
    models = {
        src: {
            tgt: MarianMTModel.from_pretrained(f'Helsinki-NLP/opus-mt-{src}-{tgt}')
            for tgt in targets
        }
        for src, targets in supported_langs.items()
    }
    tokenizers = {
        src: {
            tgt: MarianTokenizer.from_pretrained(f'Helsinki-NLP/opus-mt-{src}-{tgt}')
            for tgt in targets
        }
        for src, targets in supported_langs.items()
    }
    return models, tokenizers

0 comments on commit f511be8

Please sign in to comment.