From 88ff4559f2a70ee04779a967f7abd8db60df2fc1 Mon Sep 17 00:00:00 2001 From: Konstantin Schulz Date: Tue, 20 Aug 2024 14:38:50 +0200 Subject: [PATCH] added lemmatization evaluation for greCy on UD test data --- .env.template | 1 + .gitignore | 1 + README.md | 5 ++++ config.py | 5 ++++ data/.gitignore | 2 ++ lemma.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++++ metrics.py | 19 ++++++++++++ ner.py | 31 +++++++------------ requirements.txt | 22 ++++++++++---- 9 files changed, 137 insertions(+), 26 deletions(-) create mode 100644 .env.template create mode 100644 config.py create mode 100644 data/.gitignore create mode 100644 lemma.py create mode 100644 metrics.py diff --git a/.env.template b/.env.template new file mode 100644 index 0000000..e57671b --- /dev/null +++ b/.env.template @@ -0,0 +1 @@ +MORPHEUS_PATH=/morpheus-perseids diff --git a/.gitignore b/.gitignore index 9f11b75..3ad1afd 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .idea/ +.env diff --git a/README.md b/README.md index 5466291..55895ae 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,7 @@ # SEFLAG **S**ystematic **E**valuation **F**ramework for NLP models and datasets in **L**atin and **A**ncient **G**reek + +## Evaluation Results +### Lemmatization +#### greCy on UD test data +{'accuracy': 0.8942049121548943} diff --git a/config.py b/config.py new file mode 100644 index 0000000..384eec5 --- /dev/null +++ b/config.py @@ -0,0 +1,5 @@ +import os + + +class Config: + data_dir: str = os.path.abspath("data") diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..7207526 --- /dev/null +++ b/data/.gitignore @@ -0,0 +1,2 @@ +proiel/ +lemmatization_test/ diff --git a/lemma.py b/lemma.py new file mode 100644 index 0000000..8636e8e --- /dev/null +++ b/lemma.py @@ -0,0 +1,77 @@ +import os.path +import subprocess +import betacode.conv +import conllu +import spacy +from conllu import SentenceList +from dotenv import load_dotenv +from spacy import Language +from spacy.tokens import Doc +from tqdm import tqdm +import xml.etree.ElementTree as ET +from config import Config +from metrics import accuracy + +# need this for greCy to work properly +Doc.set_extension("trf_data", default=None) +beta_to_uni: dict[str, str] = dict() +uni_to_beta: dict[str, str] = dict() + + +def convert_labels(lemmata_predicted: list[str], lemmata_true: list[str]) -> tuple[list[int], list[int]]: + """ Converts lemmatization results from strings to integers, which is useful for calculating metrics. """ + all_lemmata: set[str] = set(lemmata_predicted + lemmata_true) + lemma_to_idx: dict[str, int] = {lemma: idx for idx, lemma in enumerate(all_lemmata)} + predictions_int: list[int] = [lemma_to_idx[x] for x in lemmata_predicted] + references_int: list[int] = [lemma_to_idx[x] for x in lemmata_true] + return predictions_int, references_int + + +def morpheus(text: str) -> list[str]: + """ Runs Morpheus and uses it to lemmatize a given word form. """ + if text not in uni_to_beta: + # Morpheus only accepts beta code; need to replace special sigma notation with ordinary one + uni_to_beta[text] = betacode.conv.uni_to_beta(text).replace("s1", "s") + # make sure that MORPHEUS_PATH is present in your .env file and points to the correct folder + # see https://github.com/perseids-tools/morpheus-perseids for installation instructions + load_dotenv() + env: dict = os.environ.copy() + env["MORPHLIB"] = "stemlib" + cp: subprocess.CompletedProcess = subprocess.run( + ["bin/morpheus", uni_to_beta[text]], capture_output=True, env=env, cwd=env["MORPHEUS_PATH"]) + output: bytes = cp.stdout + xml: str = output.decode("utf-8") + root: ET.Element = ET.fromstring(xml) + headwords: list[ET.Element] = root.findall(".//hdwd") + lemmata: list[str] = [x.text for x in headwords] + for lemma in lemmata: + if lemma not in beta_to_uni: + beta_to_uni[lemma] = betacode.conv.beta_to_uni(lemma) + return [beta_to_uni[x] for x in lemmata] + + +def run_evaluation(): + data_dir: str = os.path.join(Config.data_dir, 'lemmatization_test') + sl: SentenceList = SentenceList() + for file in [x for x in os.listdir(data_dir) if x.endswith(".conllu")]: + file_path: str = os.path.join(data_dir, file) + with open(file_path, 'r') as f: + new_sl: SentenceList = conllu.parse(f.read()) + sl += new_sl + nlp: Language = spacy.load( + "grc_proiel_trf", # grc_proiel_trf grc_odycy_joint_trf + exclude=["morphologizer", "parser", "tagger", "transformer"], # + ) + lemmata_predicted: list[str] = [] + lemmata_true: list[str] = [] + for sent in tqdm(sl): + words: list[str] = [tok["form"] for tok in sent] + new_lemmata_true: list[str] = [tok["lemma"] for tok in sent] + lemmata_true += new_lemmata_true + doc: Doc = nlp(Doc(vocab=nlp.vocab, words=words)) + lemmata_predicted += [x.lemma_ for x in doc] + predictions_int, references_int = convert_labels(lemmata_predicted, lemmata_true) + accuracy(predictions_int, references_int) + + +# run_evaluation() diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000..d6f6500 --- /dev/null +++ b/metrics.py @@ -0,0 +1,19 @@ +import evaluate +from evaluate import EvaluationModule + + +def accuracy(y_pred: list[int], y_true: list[int]): + """ Calculates the accuracy of the predicted results against the ground truth. """ + evaluation_module: EvaluationModule = evaluate.load("accuracy") + print(evaluation_module.compute(references=y_true, predictions=y_pred)) + + +def precision_recall_f1(predictions: list[int], references: list[int]): + """ Calculates various metrics for the given predicted and true labels. """ + averages: list[str] = ["weighted", "micro", "macro"] + metrics: list[str] = ["precision", "recall", "f1"] + for metric in metrics: + evaluation_module: EvaluationModule = evaluate.load(metric) + for average in averages: + print(average, + evaluation_module.compute(predictions=predictions, references=references, average=average)) diff --git a/ner.py b/ner.py index 5e988cd..0e39901 100644 --- a/ner.py +++ b/ner.py @@ -1,10 +1,7 @@ import csv import os.path - -import evaluate import yaml from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets -from evaluate import EvaluationModule from tqdm import tqdm import la_core_web_lg from flair.data import Sentence @@ -12,6 +9,9 @@ from spacy import Language from spacy.tokens import Doc +from config import Config +from metrics import precision_recall_f1 + class Mappings: def __init__(self, source_file_path: str): @@ -46,19 +46,9 @@ def annotate_latin_texts(words: list[str]) -> list[str]: return values -def calculate_metrics(predictions, references): - """ Calculates various metrics for the given predictions and references (i.e., ground truth). """ - averages: list[str] = ["weighted", "micro", "macro"] - metrics: list[str] = ["precision", "recall", "f1"] - for metric in metrics: - evaluation_module: EvaluationModule = evaluate.load(metric) - for average in averages: - print(average, evaluation_module.compute(predictions=predictions, references=references, average=average)) - - -def map_labels(original_labels, mapping): +def map_labels(original_labels, mapping: dict[str, int]) -> list[int]: """ Applies a mapping to string labels, thereby converting them to integers. """ - labels_mapped = [] + labels_mapped: list[int] = [] for label in original_labels: for key in mapping: if (len(key) <= 1 and key == label) or (len(key) > 1 and key in label): @@ -77,16 +67,15 @@ def run_evaluation(folder_path: str, reference_column_name: str, word_column_nam dataset_selection = dataset_combined # .select(list(range(100))) print(dataset_selection) references_raw = dataset_selection[reference_column_name] - references_mapped = map_labels(references_raw, references_mapping) + references_mapped: list[int] = map_labels(references_raw, references_mapping) words: list[str] = dataset_selection[word_column_name] ner_labels: list[str] = annotation_fn(words) - predictions = map_labels(ner_labels, predictions_mapping) - calculate_metrics(predictions, references_mapped) + predictions: list[int] = map_labels(ner_labels, predictions_mapping) + precision_recall_f1(predictions, references_mapped) mappings: Mappings = Mappings("mappings.yaml") -data_dir: str = os.path.abspath("data") -greek_data_path: str = os.path.join(data_dir, "yousef_et_al_dataset") -latin_data_path: str = os.path.join(data_dir, "Herodotos_dataset") +greek_data_path: str = os.path.join(Config.data_dir, "yousef_et_al_dataset") +latin_data_path: str = os.path.join(Config.data_dir, "Herodotos_dataset") run_evaluation(greek_data_path, "entity", "word", annotate_greek_texts, mappings.per_loc_misc, mappings.per_loc_misc) run_evaluation(latin_data_path, "Label", "Word", annotate_latin_texts, mappings.prs_geo_grp, mappings.per_loc_norp) diff --git a/requirements.txt b/requirements.txt index 27270d1..a04850e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ async-lru==2.0.4 attrs==23.2.0 Babel==2.14.0 beautifulsoup4==4.12.3 +betacode==1.0 bleach==6.1.0 blis==0.7.11 boto3==1.34.69 @@ -46,6 +47,8 @@ fsspec==2024.2.0 ftfy==6.2.0 gdown==5.1.0 gensim==4.3.2 +grc_proiel_trf==3.7.5 +grecy==1.0 h11==0.14.0 httpcore==1.0.4 httpx==0.27.0 @@ -65,11 +68,11 @@ jsonpointer==2.4 jsonschema==4.21.1 jsonschema-specifications==2023.12.1 jupyter==1.0.0 +jupyter_client==8.6.1 jupyter-console==6.6.3 +jupyter_core==5.7.2 jupyter-events==0.10.0 jupyter-lsp==2.2.4 -jupyter_client==8.6.1 -jupyter_core==5.7.2 jupyter_server==2.13.0 jupyter_server_terminals==0.5.3 jupyterlab==4.1.5 @@ -77,7 +80,7 @@ jupyterlab_pygments==0.3.0 jupyterlab_server==2.25.4 jupyterlab_widgets==3.0.10 kiwisolver==1.4.5 -la-core-web-lg @ https://huggingface.co/latincy/la_core_web_lg/resolve/main/la_core_web_lg-any-py3-none-any.whl#sha256=b20b559f395cb42193bf092d57c81d3ba4430bc9e65715675aaa7e0eeb66b7bb +la-core-web-lg==3.7.4 langcodes==3.4.0 langdetect==1.0.9 language_data==1.2.0 @@ -98,6 +101,7 @@ nbconvert==7.16.3 nbformat==5.10.3 nest-asyncio==1.6.0 networkx==3.2.1 +nltk==3.8.1 notebook==7.1.2 notebook_shim==0.2.4 numpy==1.26.4 @@ -120,6 +124,7 @@ pandocfilters==1.5.1 parso==0.8.3 pexpect==4.9.0 pillow==10.2.0 +pip==24.0 platformdirs==4.2.0 pptree==3.1 preshed==3.0.9 @@ -135,9 +140,11 @@ pycparser==2.21 pydantic==2.7.1 pydantic_core==2.18.2 Pygments==2.17.2 +pygtrie==2.5.0 pyparsing==3.1.2 PySocks==1.7.1 python-dateutil==2.9.0.post0 +python-dotenv==1.0.1 python-json-logger==2.0.7 pytorch_revgrad==0.2.0 pytz==2024.1 @@ -160,13 +167,17 @@ segtok==1.5.11 semver==3.0.2 Send2Trash==1.8.2 sentencepiece==0.1.99 +seqeval==1.2.2 +setuptools==69.5.1 six==1.16.0 smart-open==6.4.0 sniffio==1.3.1 soupsieve==2.5 -spacy==3.7.4 +spacy==3.7.5 +spacy-alignments==0.9.1 spacy-legacy==3.0.12 spacy-loggers==1.0.5 +spacy-transformers==1.3.5 sqlitedict==2.1.0 srsly==2.4.8 stack-data==0.6.3 @@ -182,7 +193,7 @@ tornado==6.4 tqdm==4.66.2 traitlets==5.14.2 transformer-smaller-training-vocab==0.3.3 -transformers==4.39.1 +transformers==4.36.2 triton==2.2.0 typer==0.9.4 types-python-dateutil==2.9.0.20240316 @@ -196,6 +207,7 @@ weasel==0.3.4 webcolors==1.13 webencodings==0.5.1 websocket-client==1.7.0 +wheel==0.43.0 widgetsnbextension==4.0.10 Wikipedia-API==0.6.0 wrapt==1.16.0