From ee685a41832b40eafb9954a67dd190fc1f73cb8c Mon Sep 17 00:00:00 2001
From: dwiddows
Date: Tue, 9 Apr 2024 23:21:43 +0000
Subject: [PATCH] Eval options and improvements

---
 experiments/classification_report.py |   1 +
 experiments/distilbert_expts.ipynb   | 573 +++++++++++++++++++++++++++
 experiments/eval_tool.py             | 100 +++++
 experiments/huggingface_client.py    |  32 +-
 experiments/twituser_eval.py         |  54 ---
 5 files changed, 696 insertions(+), 64 deletions(-)
 create mode 100644 experiments/distilbert_expts.ipynb
 create mode 100644 experiments/eval_tool.py
 delete mode 100644 experiments/twituser_eval.py

diff --git a/experiments/classification_report.py b/experiments/classification_report.py
index 34faea5..ddf28cb 100644
--- a/experiments/classification_report.py
+++ b/experiments/classification_report.py
@@ -17,6 +17,7 @@ def nullsafe_classification_report(y_label: List[str], y_pred: List[str]):
     num_pred_labels = len({y for y in y_pred if y in label_set})
     y_pred = [y if y in label_set else dummy_val for y in y_pred]
     report = classification_report(y_label, y_pred, output_dict=True, zero_division=0.0)
+    if dummy_val in report: del report[dummy_val]
     report["macro avg"]["precision"] = report["macro avg"]["precision"] * (num_pred_labels + 1) / num_pred_labels
diff --git a/experiments/distilbert_expts.ipynb b/experiments/distilbert_expts.ipynb
new file mode 100644
index 0000000..4f10b76
--- /dev/null
+++ b/experiments/distilbert_expts.ipynb
@@ -0,0 +1,573 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/widdows/.local/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "from datasets import Dataset, DatasetDict\n",
+    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments\n",
+    "from sklearn.metrics import precision_recall_fscore_support\n",
+    "import numpy as np\n",
+    "\n",
+    "import warnings\n",
+    "warnings.filterwarnings(\"ignore\")\n",
+    "\n",
+    "from eval_tool import sample_from_big_string, sample_texts_from_dir\n",
+    "\n",
+    "\n",
+    "HUGGINGFACE_MODEL = \"distilbert/distilbert-base-multilingual-cased\"\n",
+    "OUTPUT_DIR = \"distilmbert_lc_model_80_b\"\n",
+    "num_train_per_lang = 80\n",
+    "train_len = 256\n",
+    "num_test_per_lang = 20\n",
+    "test_len = 256\n",
+    "\n",
+    "\n",
+    "TOKENIZER = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL)\n",
+    "def preprocess_function(examples):\n",
+    "    return TOKENIZER(examples[\"text\"], max_length=512, truncation=True)\n",
+    "\n",
+    "\n",
+    "def compute_metrics(eval_pred):\n",
+    "    logits, labels = eval_pred\n",
+    "    predictions = np.argmax(logits, axis=-1)\n",
+    "    precision, recall, fscore, _ = precision_recall_fscore_support(labels, predictions, average='weighted')\n",
+    "    return {\n",
+    "        'precision': precision,\n",
+    "        'recall': recall,\n",
+    "        'fscore': fscore\n",
+    "    }\n",
+    "\n",
+    "data_collator = DataCollatorWithPadding(tokenizer=TOKENIZER)"
+   ]
+  },
+  {
"cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:9585 training characters for language 'bug'. We need 20480.\n", + "WARNING:root:2394 training characters for language 'bug'. We need 5120.\n", + "WARNING:root:Skipping\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:10895 training characters for language 'iu'. We need 20480.\n", + "WARNING:root:2720 training characters for language 'iu'. We need 5120.\n", + "WARNING:root:Skipping\n", + "WARNING:root:2256 training characters for language 'chy'. We need 20480.\n", + "WARNING:root:557 training characters for language 'chy'. We need 5120.\n", + "WARNING:root:Skipping\n", + "WARNING:root:12836 training characters for language 'bi'. We need 20480.\n", + "WARNING:root:3206 training characters for language 'bi'. We need 5120.\n", + "WARNING:root:Skipping\n", + "WARNING:root:13542 training characters for language 'ty'. We need 20480.\n", + "WARNING:root:3386 training characters for language 'ty'. We need 5120.\n", + "WARNING:root:Skipping\n", + "WARNING:root:5236 training characters for language 'ik'. We need 20480.\n", + "WARNING:root:1301 training characters for language 'ik'. We need 5120.\n", + "WARNING:root:Skipping\n", + "WARNING:root:13646 training characters for language 'sg'. We need 20480.\n", + "WARNING:root:3409 training characters for language 'sg'. We need 5120.\n", + "WARNING:root:Skipping\n", + "WARNING:root:2507 training characters for language 'cr'. We need 20480.\n", + "WARNING:root:621 training characters for language 'cr'. We need 5120.\n", + "WARNING:root:Skipping\n", + "Map: 100%|██████████| 22560/22560 [00:03<00:00, 7501.15 examples/s]\n", + "Map: 100%|██████████| 5640/5640 [00:00<00:00, 8118.87 examples/s]\n" + ] + } + ], + "source": [ + "import logging\n", + "import os\n", + "from pathlib import Path\n", + "\n", + "# Small Wikipedia corpus from https://lukelindemann.com/wiki_corpus.html, preprocessed using lplangid training script.\n", + "wiki_root = Path.home() / \"Data\" / \"WikipediaLindemann\"\n", + "language_codes = os.listdir(wiki_root / \"train\")\n", + "label2id = {lang: idx for idx, lang in enumerate(language_codes)}\n", + "id2label = {idx: lang for lang, idx in label2id.items()}\n", + "\n", + "train_texts, train_labels, test_texts, test_labels = [], [], [], []\n", + "\n", + "for lang in language_codes:\n", + " train_fh = open(wiki_root / \"train\" / lang, encoding='utf-8')\n", + " test_fh = open(wiki_root / \"test\" / lang, encoding='utf-8')\n", + " train_contents = train_fh.read()\n", + " test_contents = test_fh.read()\n", + "\n", + " # Check there is enough data, otherwise skip.\n", + " if len(train_contents) < train_len * num_train_per_lang or len(test_contents) < test_len * num_test_per_lang:\n", + " logging.warning(f\"{len(train_contents)} training characters for language '{lang}'. We need {train_len * num_train_per_lang}.\")\n", + " logging.warning(f\"{len(test_contents)} training characters for language '{lang}'. 
We need {test_len * num_test_per_lang}.\")\n", + " logging.warning(\"Skipping\")\n", + " continue\n", + "\n", + " train_texts.extend(sample_from_big_string(train_contents, train_len, num_train_per_lang))\n", + " train_labels.extend([label2id[lang]] * num_train_per_lang)\n", + " test_texts.extend(sample_from_big_string(train_contents, test_len, num_test_per_lang))\n", + " test_labels.extend([label2id[lang]] * num_test_per_lang)\n", + "\n", + "train_dataset = Dataset.from_dict({\"text\": train_texts, \"label\": train_labels})\n", + "test_dataset = Dataset.from_dict({\"text\": test_texts, \"label\": test_labels})\n", + "\n", + "# Bundle into a DatasetDict\n", + "dataset_dict = DatasetDict({\n", + " \"train\": train_dataset,\n", + " \"test\": test_dataset\n", + "})\n", + "tokenized_data = dataset_dict.map(preprocess_function, batched=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model = AutoModelForSequenceClassification.from_pretrained(\n", + " HUGGINGFACE_MODEL,\n", + " num_labels=len(label2id), id2label=id2label, label2id=label2id\n", + " )\n", + "\n", + "training_args = TrainingArguments(\n", + " output_dir=OUTPUT_DIR,\n", + " learning_rate=2e-5,\n", + " per_device_train_batch_size=16,\n", + " per_device_eval_batch_size=16,\n", + " num_train_epochs=10,\n", + " weight_decay=0.01,\n", + " evaluation_strategy=\"epoch\",\n", + " save_strategy=\"epoch\",\n", + " load_best_model_at_end=True,\n", + " push_to_hub=False,\n", + ")\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=tokenized_data[\"train\"],\n", + " eval_dataset=tokenized_data[\"test\"],\n", + " tokenizer=TOKENIZER,\n", + " data_collator=data_collator,\n", + " compute_metrics=compute_metrics,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting training based on distilbert/distilbert-base-multilingual-cased ... outputting to distilbert_lc_model_80_b\n", + "Labels: 282 Num train: 80 (len 256). Num test: 20 (len 256).\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [14100/14100 1:03:18, Epoch 10/10]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation LossPrecisionRecallFscore
12.8356001.0510480.8568250.8602840.842108
20.5238000.2933000.9150680.9200350.908701
30.2583000.1670550.9391260.9457450.938646
40.1728000.1443050.9532320.9457450.939434
50.1370000.1063130.9584310.9579790.953299
60.1195000.0965190.9623200.9622340.958400
70.0893000.0807630.9670460.9684400.964432
80.0892000.0727170.9742600.9707450.967565
90.0823000.0692960.9740260.9719860.969491
100.0802000.0676430.9691850.9730500.969544

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done training.\n" + ] + } + ], + "source": [ + "retrain = True\n", + "if retrain:\n", + " print(f\"Starting training based on {HUGGINGFACE_MODEL} ... outputting to {trainer.args.output_dir}\")\n", + " print(f\"Labels: {len(set(test_labels))} Num train: {num_train_per_lang} (len {train_len}). Num test: {num_test_per_lang} (len {test_len}).\")\n", + " trainer.train()\n", + " print(\"Done training.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [705/705 00:19]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "eval_loss: 0.0676, eval_precision: 0.9692, eval_recall: 0.9730, eval_fscore: 0.9695, eval_runtime: 20.5156, eval_samples_per_second: 274.9120, eval_steps_per_second: 34.3640\n" + ] + } + ], + "source": [ + "lc_model = AutoModelForSequenceClassification.from_pretrained(\n", + " OUTPUT_DIR + \"/checkpoint-14100\",\n", + " num_labels=len(label2id), id2label=id2label, label2id=label2id)\n", + "\n", + "evaluator = Trainer(\n", + " model=lc_model,\n", + " eval_dataset=tokenized_data[\"test\"],\n", + " data_collator=data_collator,\n", + " compute_metrics=compute_metrics,\n", + ")\n", + "\n", + "test_results = evaluator.evaluate(tokenized_data[\"test\"])\n", + "print(\", \".join([f\"{k}: {v:0.4f}\" for k, v in test_results.items()]))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|██████████| 5800/5800 [00:00<00:00, 48306.12 examples/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [725/725 00:04]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|██████████| 5800/5800 [00:00<00:00, 16108.81 examples/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [725/725 00:05]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|██████████| 5800/5800 [00:00<00:00, 23825.13 examples/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [725/725 00:06]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|██████████| 5800/5800 [00:00<00:00, 14059.37 examples/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [725/725 00:11]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|██████████| 5800/5800 [00:00<00:00, 8162.97 examples/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [725/725 00:19]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results for length 16\n", + "eval_loss: 2.9565, eval_precision: 0.5560, eval_recall: 0.4667, eval_fscore: 0.4619, eval_runtime: 4.8304, eval_samples_per_second: 1200.7290, eval_steps_per_second: 150.0910\n", + "Results for length 32\n", + "eval_loss: 2.0615, eval_precision: 0.6743, eval_recall: 0.6347, eval_fscore: 0.6260, eval_runtime: 5.0815, eval_samples_per_second: 1141.3880, eval_steps_per_second: 142.6730\n", + "Results for length 64\n", + "eval_loss: 1.4544, eval_precision: 0.7739, eval_recall: 0.7717, eval_fscore: 0.7599, eval_runtime: 6.7701, eval_samples_per_second: 856.7070, eval_steps_per_second: 107.0880\n", + "Results for length 128\n", + "eval_loss: 1.1765, eval_precision: 0.8394, eval_recall: 0.8419, eval_fscore: 0.8306, eval_runtime: 11.0415, eval_samples_per_second: 525.2910, eval_steps_per_second: 65.6610\n", + "Results for length 256\n", + "eval_loss: 0.8697, eval_precision: 0.8738, eval_recall: 0.8803, eval_fscore: 0.8669, eval_runtime: 19.7455, eval_samples_per_second: 293.7380, eval_steps_per_second: 36.7170\n" + ] + } + ], + "source": [ + "all_results = {}\n", + "for eval_strlen in [16, 32, 64, 128, 256]:\n", + " eval_texts, eval_labels = sample_texts_from_dir(Path(wiki_root) / \"test\", eval_strlen, num_test_per_lang)\n", + "\n", + " eval_dataset = Dataset.from_dict({\"text\": eval_texts, \"label\": [label2id[l] for l in eval_labels]})\n", + " tokenized_eval_data = eval_dataset.map(preprocess_function, batched=True)\n", + "\n", + " evaluator = Trainer(\n", + " model=lc_model,\n", + " eval_dataset=tokenized_eval_data,\n", + " data_collator=data_collator,\n", + " compute_metrics=compute_metrics,\n", + " )\n", + "\n", + " eval_results = evaluator.evaluate(tokenized_eval_data)\n", + " all_results[eval_strlen] = eval_results\n", + "\n", + "\n", + "for eval_strlen, eval_results in all_results.items():\n", + " print(f\"Results for length {eval_strlen}\")\n", + " print(\", \".join([f\"{k}: {v:0.4f}\" for k, v in eval_results.items()]))\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv_llm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiments/eval_tool.py b/experiments/eval_tool.py new file mode 100644 index 0000000..33e9bf4 --- /dev/null +++ b/experiments/eval_tool.py @@ -0,0 +1,100 @@ +import json +from pathlib import Path +import os +import time + +import pandas as pd +import langid + +from lplangid import language_classifier as lc +from experiments import fasttext_client, huggingface_client +from experiments.classification_report import nullsafe_classification_report + +wiki_root = Path.home() / "Data" / "WikipediaLindemann" +bibles_root = Path.home() / "Data" / "bibles" / "BibleTexts/" + +def langid_classify(text: str): + return langid.classify(text)[0] + + +def load_twituser_test_data(): + with open("twituser_data/twituser") as twituser_data: + records = [json.loads(line) for line in twituser_data] + texts = [record["text"] for record in records] + labels = [record["lang"] for record in records] + return texts, labels + + +def sample_from_big_string(big_str: 
str, text_len: int, num_samples: int):
+    texts = []
+    for i in range(num_samples):
+        # Sample from evenly spaced regions of the string, starting each sample at a word boundary.
+        region_start = (i * len(big_str)) // num_samples
+        start = big_str.find(" ", region_start)
+        if start == -1:
+            continue
+        end = big_str.find(" ", start + text_len)
+        if end == -1:
+            end = len(big_str)
+        texts.append(big_str[start:end])
+    return texts
+
+
+def sample_texts_from_dir(text_dir, min_length, samples_per_file):
+    texts, labels = [], []
+    for lang in os.listdir(text_dir):
+        with open(Path(text_dir) / lang, encoding='utf-8') as lang_fh:
+            lang_contents = lang_fh.read()
+        lang_samples = sample_from_big_string(lang_contents, min_length, samples_per_file)
+        texts.extend(lang_samples)
+        if lang.endswith(".txt"):
+            lang = lang[:-4]
+        labels.extend([lang] * len(lang_samples))
+    return texts, labels
+
+
+def run_tests():
+    rrc_bibles = lc.RRCLanguageClassifier.many_language_bible_instance()
+    rrc_smallwiki = lc.RRCLanguageClassifier(*lc.prepare_scoring_tables(data_dir=lc.FREQ_DATA_DIR + "_smallwiki"))
+    ft_classifier = fasttext_client.FastTextLangID()
+    hg_classifier = huggingface_client.HuggingfaceLangID()
+    hg_classifier_xlm = huggingface_client.HuggingfaceLangID(huggingface_client.HUGGINGFACE_XLM_MODEL_PATH)
+
+    fn_tags = [
+        [lambda texts: [rrc_bibles.get_winner(text) for text in texts], "RRC bibles"],
+        [lambda texts: [rrc_smallwiki.get_winner(text) for text in texts], "RRC smallwiki"],
+        [lambda texts: [ft_classifier.predict_lang(text) for text in texts], "FastText"],
+        [lambda texts: [langid_classify(text) for text in texts], "LangID"],
+        [hg_classifier.predict_lang_batch, "DistilMBert Lang ID"],
+        [hg_classifier_xlm.predict_lang_batch, "XLM Roberta Lang ID"],
+    ]
+
+    strlens = [16, 64, 256]
+    test_texts = []
+    for strlen in strlens:
+        texts, y_labels = sample_texts_from_dir(Path(wiki_root) / "test", strlen, 20)
+        test_texts.append([strlen, texts, y_labels])
+
+    tagged_reports = []
+    for fn, tag in fn_tags:
+        for strlen, texts, y_labels in test_texts:
+            y_pred = fn(texts)
+            df_report = nullsafe_classification_report(y_labels, y_pred)
+            df_report = df_report.loc[["macro avg", "weighted avg"]][["precision", "recall", "f1-score"]]
+
+            # Build a MultiIndex so each pair of rows is keyed by classifier tag, string length,
+            # and 'macro avg' / 'weighted avg'.
+            index_tuples = [(tag, strlen, 'macro avg'), (tag, strlen, 'weighted avg')]
+            multi_index = pd.MultiIndex.from_tuples(index_tuples, names=['Classifier', 'StrLen', 'Metric'])
+            df_report = pd.DataFrame(df_report.values, index=multi_index, columns=df_report.columns)
+            tagged_reports.append(df_report)
+
+    final_df = pd.concat(tagged_reports)
+    print(final_df)
+    final_df.to_csv("results/wiki_results_df.csv")
+
+
+def main():
+    run_tests()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/experiments/huggingface_client.py b/experiments/huggingface_client.py
index 2ba02a4..72c1676 100644
--- a/experiments/huggingface_client.py
+++ b/experiments/huggingface_client.py
@@ -8,27 +8,37 @@
 from scipy.special import softmax
 import torch
-from accelerate import Accelerator, DataLoaderConfiguration
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer

 warnings.filterwarnings("ignore", category=FutureWarning, module="accelerate.*")

-HUGGINGFACE_MODEL_ROOT = Path(os.path.dirname(__file__)) / "distilbert_lc_model_80"
+HUGGINGFACE_DEFAULT_MODEL_ROOT = Path(os.path.dirname(__file__)) / "distilmbert_lc_model_800"
+HUGGINGFACE_XLM_MODEL_PATH = "papluca/xlm-roberta-base-language-detection"

-def get_latest_model_from_dir(directory):
+
+def 
get_latest_model_from_dir(model_path): + # If it's not a directory, hopefully it's a huggingface tag ... + if not os.path.exists(model_path): + return model_path + + # If it's a directory with a config.json, it's probably what we're looking for. + if os.path.isfile(Path(model_path) / "config.json"): + return model_path + + # Otherwise look for the last checkpoint in the directory. pattern = re.compile(r"checkpoint-\d+") - dir_items = os.listdir(directory) + dir_items = os.listdir(model_path) checkpoints = sorted(filter(pattern.match, dir_items), key=lambda x: int(x.split('-')[-1])) if not checkpoints: raise ValueError("No checkpoint found in the directory.") latest_checkpoint = checkpoints[-1] - return os.path.join(directory, latest_checkpoint) + return os.path.join(model_path, latest_checkpoint) class HuggingfaceLangID: - def __init__(self, model_root=HUGGINGFACE_MODEL_ROOT): + def __init__(self, model_root=HUGGINGFACE_DEFAULT_MODEL_ROOT): model_path = get_latest_model_from_dir(model_root) self.lc_model = AutoModelForSequenceClassification.from_pretrained(model_path) self.tokenizer = AutoTokenizer.from_pretrained(model_path) @@ -65,8 +75,10 @@ def predict_lang(self, text: str, verbose=False): if __name__ == "__main__": - LANGUAGE = HuggingfaceLangID() - lang = LANGUAGE.predict_lang_batch(["Hello in English", "Bonjour en Francais"]) - print(f"Prediction: {lang}") - + hg_classifier = HuggingfaceLangID() + lang = hg_classifier.predict_lang_batch(["Hello in English", "Bonjour en Francais"]) + print(f"Default prediction: {lang}") + hg_xlm_classifier = HuggingfaceLangID(HUGGINGFACE_XLM_MODEL_PATH) + lang = hg_xlm_classifier.predict_lang_batch(["Hello in English", "Bonjour en Francais"]) + print(f"XLM prediction: {lang}") diff --git a/experiments/twituser_eval.py b/experiments/twituser_eval.py deleted file mode 100644 index 11579b5..0000000 --- a/experiments/twituser_eval.py +++ /dev/null @@ -1,54 +0,0 @@ -import json - -import langid - -from lplangid import language_classifier as lc -from experiments import fasttext_client, huggingface_client -from experiments.classification_report import nullsafe_classification_report - - -def langid_classify(text: str): - return langid.classify(text)[0] - - -def run_twituser_tests(): - rrc_classifier = lc.RRCLanguageClassifier.default_instance() - rrc_bibles= lc.RRCLanguageClassifier.many_language_bible_instance() - rrc_smallwiki = lc.RRCLanguageClassifier(*lc.prepare_scoring_tables(data_dir=lc.FREQ_DATA_DIR + "_smallwiki")) - ft_classifier = fasttext_client.FastTextLangID() - hg_classifier = huggingface_client.HuggingfaceLangID() - - fn_labels = [ - [lambda texts: [rrc_classifier.get_winner(text) for text in texts], "RRC default"], - [lambda texts: [rrc_bibles.get_winner(text) for text in texts], "RRC bibles"], - [lambda texts: [rrc_smallwiki.get_winner(text) for text in texts], "RRC smallwiki"], - [lambda texts: [ft_classifier.predict_lang(text) for text in texts], "FastText"], - [lambda texts: [langid_classify(text) for text in texts], "LangID"], - [hg_classifier.predict_lang_batch, "HuggingFace"], - ] - - # fn_labels = [[hg_classifier.predict_lang, "HuggingFace"]] - - for fn, label in fn_labels: - print(f"Classifying with {label}") - y_labels = [] - with open("twituser_data/twituser") as twituser_data: - input_texts = [] - for line in twituser_data: - record = json.loads(line) - # if record["lang"] not in rrc_classifier.term_ranks: - # continue - input_texts.append(record["text"]) - y_labels.append(record["lang"]) - - y_pred = fn(input_texts) - - 
print(nullsafe_classification_report(y_labels, y_pred)) - - -def main(): - run_twituser_tests() - - -if __name__ == "__main__": - main()