Skip to content

Commit

Permalink
added lemmatization evaluation for greCy on UD test data
Browse files Browse the repository at this point in the history
  • Loading branch information
konstantinschulz committed Aug 20, 2024
1 parent 94e4fef commit 88ff455
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 26 deletions.
1 change: 1 addition & 0 deletions .env.template
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
MORPHEUS_PATH=/morpheus-perseids
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.idea/
.env
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
# SEFLAG
**S**ystematic **E**valuation **F**ramework for NLP models and datasets in **L**atin and **A**ncient **G**reek

## Evaluation Results
### Lemmatization
#### greCy on UD test data
{'accuracy': 0.8942049121548943}
5 changes: 5 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os


class Config:
data_dir: str = os.path.abspath("data")
2 changes: 2 additions & 0 deletions data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
proiel/
lemmatization_test/
77 changes: 77 additions & 0 deletions lemma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os.path
import subprocess
import betacode.conv
import conllu
import spacy
from conllu import SentenceList
from dotenv import load_dotenv
from spacy import Language
from spacy.tokens import Doc
from tqdm import tqdm
import xml.etree.ElementTree as ET
from config import Config
from metrics import accuracy

# need this for greCy to work properly
Doc.set_extension("trf_data", default=None)
beta_to_uni: dict[str, str] = dict()
uni_to_beta: dict[str, str] = dict()


def convert_labels(lemmata_predicted: list[str], lemmata_true: list[str]) -> tuple[list[int], list[int]]:
""" Converts lemmatization results from strings to integers, which is useful for calculating metrics. """
all_lemmata: set[str] = set(lemmata_predicted + lemmata_true)
lemma_to_idx: dict[str, int] = {lemma: idx for idx, lemma in enumerate(all_lemmata)}
predictions_int: list[int] = [lemma_to_idx[x] for x in lemmata_predicted]
references_int: list[int] = [lemma_to_idx[x] for x in lemmata_true]
return predictions_int, references_int


def morpheus(text: str) -> list[str]:
""" Runs Morpheus and uses it to lemmatize a given word form. """
if text not in uni_to_beta:
# Morpheus only accepts beta code; need to replace special sigma notation with ordinary one
uni_to_beta[text] = betacode.conv.uni_to_beta(text).replace("s1", "s")
# make sure that MORPHEUS_PATH is present in your .env file and points to the correct folder
# see https://github.com/perseids-tools/morpheus-perseids for installation instructions
load_dotenv()
env: dict = os.environ.copy()
env["MORPHLIB"] = "stemlib"
cp: subprocess.CompletedProcess = subprocess.run(
["bin/morpheus", uni_to_beta[text]], capture_output=True, env=env, cwd=env["MORPHEUS_PATH"])
output: bytes = cp.stdout
xml: str = output.decode("utf-8")
root: ET.Element = ET.fromstring(xml)
headwords: list[ET.Element] = root.findall(".//hdwd")
lemmata: list[str] = [x.text for x in headwords]
for lemma in lemmata:
if lemma not in beta_to_uni:
beta_to_uni[lemma] = betacode.conv.beta_to_uni(lemma)
return [beta_to_uni[x] for x in lemmata]


def run_evaluation():
data_dir: str = os.path.join(Config.data_dir, 'lemmatization_test')
sl: SentenceList = SentenceList()
for file in [x for x in os.listdir(data_dir) if x.endswith(".conllu")]:
file_path: str = os.path.join(data_dir, file)
with open(file_path, 'r') as f:
new_sl: SentenceList = conllu.parse(f.read())
sl += new_sl
nlp: Language = spacy.load(
"grc_proiel_trf", # grc_proiel_trf grc_odycy_joint_trf
exclude=["morphologizer", "parser", "tagger", "transformer"], #
)
lemmata_predicted: list[str] = []
lemmata_true: list[str] = []
for sent in tqdm(sl):
words: list[str] = [tok["form"] for tok in sent]
new_lemmata_true: list[str] = [tok["lemma"] for tok in sent]
lemmata_true += new_lemmata_true
doc: Doc = nlp(Doc(vocab=nlp.vocab, words=words))
lemmata_predicted += [x.lemma_ for x in doc]
predictions_int, references_int = convert_labels(lemmata_predicted, lemmata_true)
accuracy(predictions_int, references_int)


# run_evaluation()
19 changes: 19 additions & 0 deletions metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import evaluate
from evaluate import EvaluationModule


def accuracy(y_pred: list[int], y_true: list[int]):
""" Calculates the accuracy of the predicted results against the ground truth. """
evaluation_module: EvaluationModule = evaluate.load("accuracy")
print(evaluation_module.compute(references=y_true, predictions=y_pred))


def precision_recall_f1(predictions: list[int], references: list[int]):
""" Calculates various metrics for the given predicted and true labels. """
averages: list[str] = ["weighted", "micro", "macro"]
metrics: list[str] = ["precision", "recall", "f1"]
for metric in metrics:
evaluation_module: EvaluationModule = evaluate.load(metric)
for average in averages:
print(average,
evaluation_module.compute(predictions=predictions, references=references, average=average))
31 changes: 10 additions & 21 deletions ner.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import csv
import os.path

import evaluate
import yaml
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
from evaluate import EvaluationModule
from tqdm import tqdm
import la_core_web_lg
from flair.data import Sentence
from flair.models import SequenceTagger
from spacy import Language
from spacy.tokens import Doc

from config import Config
from metrics import precision_recall_f1


class Mappings:
def __init__(self, source_file_path: str):
Expand Down Expand Up @@ -46,19 +46,9 @@ def annotate_latin_texts(words: list[str]) -> list[str]:
return values


def calculate_metrics(predictions, references):
""" Calculates various metrics for the given predictions and references (i.e., ground truth). """
averages: list[str] = ["weighted", "micro", "macro"]
metrics: list[str] = ["precision", "recall", "f1"]
for metric in metrics:
evaluation_module: EvaluationModule = evaluate.load(metric)
for average in averages:
print(average, evaluation_module.compute(predictions=predictions, references=references, average=average))


def map_labels(original_labels, mapping):
def map_labels(original_labels, mapping: dict[str, int]) -> list[int]:
""" Applies a mapping to string labels, thereby converting them to integers. """
labels_mapped = []
labels_mapped: list[int] = []
for label in original_labels:
for key in mapping:
if (len(key) <= 1 and key == label) or (len(key) > 1 and key in label):
Expand All @@ -77,16 +67,15 @@ def run_evaluation(folder_path: str, reference_column_name: str, word_column_nam
dataset_selection = dataset_combined # .select(list(range(100)))
print(dataset_selection)
references_raw = dataset_selection[reference_column_name]
references_mapped = map_labels(references_raw, references_mapping)
references_mapped: list[int] = map_labels(references_raw, references_mapping)
words: list[str] = dataset_selection[word_column_name]
ner_labels: list[str] = annotation_fn(words)
predictions = map_labels(ner_labels, predictions_mapping)
calculate_metrics(predictions, references_mapped)
predictions: list[int] = map_labels(ner_labels, predictions_mapping)
precision_recall_f1(predictions, references_mapped)


mappings: Mappings = Mappings("mappings.yaml")
data_dir: str = os.path.abspath("data")
greek_data_path: str = os.path.join(data_dir, "yousef_et_al_dataset")
latin_data_path: str = os.path.join(data_dir, "Herodotos_dataset")
greek_data_path: str = os.path.join(Config.data_dir, "yousef_et_al_dataset")
latin_data_path: str = os.path.join(Config.data_dir, "Herodotos_dataset")
run_evaluation(greek_data_path, "entity", "word", annotate_greek_texts, mappings.per_loc_misc, mappings.per_loc_misc)
run_evaluation(latin_data_path, "Label", "Word", annotate_latin_texts, mappings.prs_geo_grp, mappings.per_loc_norp)
22 changes: 17 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ async-lru==2.0.4
attrs==23.2.0
Babel==2.14.0
beautifulsoup4==4.12.3
betacode==1.0
bleach==6.1.0
blis==0.7.11
boto3==1.34.69
Expand Down Expand Up @@ -46,6 +47,8 @@ fsspec==2024.2.0
ftfy==6.2.0
gdown==5.1.0
gensim==4.3.2
grc_proiel_trf==3.7.5
grecy==1.0
h11==0.14.0
httpcore==1.0.4
httpx==0.27.0
Expand All @@ -65,19 +68,19 @@ jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter_client==8.6.1
jupyter-console==6.6.3
jupyter_core==5.7.2
jupyter-events==0.10.0
jupyter-lsp==2.2.4
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.13.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.4
jupyterlab_widgets==3.0.10
kiwisolver==1.4.5
la-core-web-lg @ https://huggingface.co/latincy/la_core_web_lg/resolve/main/la_core_web_lg-any-py3-none-any.whl#sha256=b20b559f395cb42193bf092d57c81d3ba4430bc9e65715675aaa7e0eeb66b7bb
la-core-web-lg==3.7.4
langcodes==3.4.0
langdetect==1.0.9
language_data==1.2.0
Expand All @@ -98,6 +101,7 @@ nbconvert==7.16.3
nbformat==5.10.3
nest-asyncio==1.6.0
networkx==3.2.1
nltk==3.8.1
notebook==7.1.2
notebook_shim==0.2.4
numpy==1.26.4
Expand All @@ -120,6 +124,7 @@ pandocfilters==1.5.1
parso==0.8.3
pexpect==4.9.0
pillow==10.2.0
pip==24.0
platformdirs==4.2.0
pptree==3.1
preshed==3.0.9
Expand All @@ -135,9 +140,11 @@ pycparser==2.21
pydantic==2.7.1
pydantic_core==2.18.2
Pygments==2.17.2
pygtrie==2.5.0
pyparsing==3.1.2
PySocks==1.7.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
pytorch_revgrad==0.2.0
pytz==2024.1
Expand All @@ -160,13 +167,17 @@ segtok==1.5.11
semver==3.0.2
Send2Trash==1.8.2
sentencepiece==0.1.99
seqeval==1.2.2
setuptools==69.5.1
six==1.16.0
smart-open==6.4.0
sniffio==1.3.1
soupsieve==2.5
spacy==3.7.4
spacy==3.7.5
spacy-alignments==0.9.1
spacy-legacy==3.0.12
spacy-loggers==1.0.5
spacy-transformers==1.3.5
sqlitedict==2.1.0
srsly==2.4.8
stack-data==0.6.3
Expand All @@ -182,7 +193,7 @@ tornado==6.4
tqdm==4.66.2
traitlets==5.14.2
transformer-smaller-training-vocab==0.3.3
transformers==4.39.1
transformers==4.36.2
triton==2.2.0
typer==0.9.4
types-python-dateutil==2.9.0.20240316
Expand All @@ -196,6 +207,7 @@ weasel==0.3.4
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
wheel==0.43.0
widgetsnbextension==4.0.10
Wikipedia-API==0.6.0
wrapt==1.16.0
Expand Down

0 comments on commit 88ff455

Please sign in to comment.