Commit

Extended training pipeline to many simple small files, supporting small wiki as well as bibles. Also huggingface transformer eval.
dwiddows committed Apr 1, 2024
1 parent 1c68797 commit 05042fc
Showing 8 changed files with 164 additions and 30 deletions.
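
The pieces added here fit together roughly as follows. A minimal sketch (not part of the commit), assuming the "_smallwiki" frequency tables have already been built by training/process_small_texts.py and that a trained checkpoint exists under experiments/distilbert_lc_model_80:

from lplangid import language_classifier as lc
from experiments import huggingface_client

# RRC classifier scored against the new small-wiki frequency tables.
rrc_smallwiki = lc.RRCLanguageClassifier(*lc.prepare_scoring_tables(data_dir=lc.FREQ_DATA_DIR + "_smallwiki"))
# Fine-tuned DistilBERT classifier, loaded from its latest checkpoint directory.
hf_classifier = huggingface_client.HuggingfaceLangID()

texts = ["Hello in English", "Bonjour en Francais"]
print([rrc_smallwiki.get_winner(text) for text in texts])  # one prediction per text
print(hf_classifier.predict_lang_batch(texts))             # batch interface, one label per text
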
2 changes: 2 additions & 0 deletions experiments/fasttext_client.py
@@ -1,4 +1,6 @@
from pathlib import Path
import warnings
warnings.filterwarnings("ignore", message="`load_model` does not return.*")

import fasttext

72 changes: 72 additions & 0 deletions experiments/huggingface_client.py
@@ -0,0 +1,72 @@
import os
from pathlib import Path
import re
from typing import List
import warnings

import numpy as np
from scipy.special import softmax
import torch

from accelerate import Accelerator, DataLoaderConfiguration
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer


warnings.filterwarnings("ignore", category=FutureWarning, module="accelerate.*")

HUGGINGFACE_MODEL_ROOT = Path(os.path.dirname(__file__)) / "distilbert_lc_model_80"  # expected to contain Trainer checkpoint-N subdirectories


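# Trainer writes numbered "checkpoint-N" subdirectories; get_latest_model_from_dir picks the one with the highest step count.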
def get_latest_model_from_dir(directory):
pattern = re.compile(r"checkpoint-\d+")
dir_items = os.listdir(directory)
checkpoints = sorted(filter(pattern.match, dir_items), key=lambda x: int(x.split('-')[-1]))
if not checkpoints:
raise ValueError("No checkpoint found in the directory.")
latest_checkpoint = checkpoints[-1]
return os.path.join(directory, latest_checkpoint)


class HuggingfaceLangID:
def __init__(self, model_root=HUGGINGFACE_MODEL_ROOT):
model_path = get_latest_model_from_dir(model_root)
self.lc_model = AutoModelForSequenceClassification.from_pretrained(model_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.evaluator = Trainer(model=self.lc_model)

def predict_lang_batch(self, texts: List[str], batch_size=100, verbose=False):
batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
all_predicted_labels = []

for batch_texts in batches:
tokenized_texts = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")  # pad to the longest text in the batch; truncate at the model's maximum length
inputs = {k: v.to(self.evaluator.args.device) for k, v in tokenized_texts.items()}
with torch.no_grad():
outputs = self.lc_model(**inputs)
all_logits = outputs.logits.cpu().numpy()

predicted_labels = []
for logits in all_logits:
probs = softmax(logits, axis=-1)
# Print sorted languages by probability
if verbose:
lang_scores = {self.lc_model.config.id2label[i]: prob for i, prob in enumerate(probs)}
for k, v in sorted(lang_scores.items(), key=lambda x: x[1]):
print(f"{k}\t{v:0.4f}")
predicted_index = np.argmax(probs, axis=-1)
predicted_label = self.lc_model.config.id2label[predicted_index]
predicted_labels.append(predicted_label)
all_predicted_labels.extend(predicted_labels)

return all_predicted_labels

def predict_lang(self, text: str, verbose=False):
return self.predict_lang_batch([text], verbose=verbose)[0]


if __name__ == "__main__":
LANGUAGE = HuggingfaceLangID()
langs = LANGUAGE.predict_lang_batch(["Hello in English", "Bonjour en Francais"])
print(f"Predictions: {langs}")


5 changes: 5 additions & 0 deletions experiments/requirements.txt
@@ -1,3 +1,5 @@
datasets
evaluate
fasttext
langid
numpy
@@ -6,3 +8,6 @@ pytest
scikit-learn
scipy
setuptools
torch
transformers[torch]

30 changes: 19 additions & 11 deletions experiments/twituser_eval.py
@@ -3,7 +3,7 @@
import langid

from lplangid import language_classifier as lc
from experiments import fasttext_client
from experiments import fasttext_client, huggingface_client
from experiments.classification_report import nullsafe_classification_report


@@ -13,28 +13,36 @@ def langid_classify(text: str):

def run_twituser_tests():
rrc_classifier = lc.RRCLanguageClassifier.default_instance()
rrc_bibles = lc.RRCLanguageClassifier.many_language_bible_instance()
rrc_smallwiki = lc.RRCLanguageClassifier(*lc.prepare_scoring_tables(data_dir=lc.FREQ_DATA_DIR + "_smallwiki"))
ft_classifier = fasttext_client.FastTextLangID()
hg_classifier = huggingface_client.HuggingfaceLangID()

fn_labels = [
[rrc_classifier.get_winner, "RRC"],
[ft_classifier.predict_lang, "FastText"],
[langid_classify, "LangID"],
[lambda texts: [rrc_classifier.get_winner(text) for text in texts], "RRC default"],
[lambda texts: [rrc_bibles.get_winner(text) for text in texts], "RRC bibles"],
[lambda texts: [rrc_smallwiki.get_winner(text) for text in texts], "RRC smallwiki"],
[lambda texts: [ft_classifier.predict_lang(text) for text in texts], "FastText"],
[lambda texts: [langid_classify(text) for text in texts], "LangID"],
[hg_classifier.predict_lang_batch, "HuggingFace"],
]

# fn_labels = [[hg_classifier.predict_lang, "HuggingFace"]]

for fn, label in fn_labels:
print(f"Classifying with {label}")
y_labels, y_pred = [], []
y_labels = []
with open("twituser_data/twituser") as twituser_data:
input_texts = []
for line in twituser_data:
record = json.loads(line)

if record["lang"] not in rrc_classifier.term_ranks:
continue

result = fn(record["text"])
y_pred.append(result)
# if record["lang"] not in rrc_classifier.term_ranks:
# continue
input_texts.append(record["text"])
y_labels.append(record["lang"])

y_pred = fn(input_texts)

print(nullsafe_classification_report(y_labels, y_pred))


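
Each entry in fn_labels now maps a list of texts to a list of predicted labels, so batch-capable models such as the HuggingFace classifier avoid per-text overhead. A per-text classifier can still be plugged in by wrapping it; a sketch using a hypothetical wrap_single helper:

def wrap_single(classify_one):
    # Adapt a one-text-at-a-time classifier to the list-in, list-out interface used by fn_labels.
    return lambda texts: [classify_one(text) for text in texts]

# For example, inside run_twituser_tests():
#     fn_labels.append([wrap_single(langid_classify), "LangID (wrapped)"])
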
17 changes: 12 additions & 5 deletions lplangid/language_classifier.py
@@ -85,11 +85,18 @@ def prepare_scoring_tables(data_dir=FREQ_DATA_DIR) -> Tuple[Dict[str, Dict[str,
for lang_code in set([x.split('_')[0] for x in os.listdir(data_dir)
if x.endswith('.csv') and not x.startswith('.')]):
tf_file = os.path.join(data_dir, f"{lang_code}_term_rank.csv")
with open(tf_file) as term_freq_file:
term_freqs = cu.read_rank_file(term_freq_file, MAX_WORDS_PER_LANG)
all_term_ranks[lang_code] = term_freqs
with open(os.path.join(data_dir, f'{lang_code}_char_freq.csv')) as char_freq_file:
all_char_freqs[lang_code] = cu.normalize_score_dict(cu.read_freq_file(char_freq_file))
if not os.path.isfile(tf_file):
all_term_ranks[lang_code] = {}
else:
with open(tf_file) as term_freq_file:
term_freqs = cu.read_rank_file(term_freq_file, MAX_WORDS_PER_LANG)
all_term_ranks[lang_code] = term_freqs
cf_file = os.path.join(data_dir, f'{lang_code}_char_freq.csv')
if not os.path.isfile(cf_file):
all_char_freqs[lang_code] = {}
else:
with open(cf_file) as char_freq_file:
all_char_freqs[lang_code] = cu.normalize_score_dict(cu.read_freq_file(char_freq_file))
all_char_weights = invert_char_tables(all_char_freqs)

logging.debug(f"Prepared term and character ranking tables for languages: {sorted(all_term_ranks.keys())}")
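
With this change a frequency-data directory no longer needs both CSV files for every language: a missing term-rank or char-freq file now gives that language an empty table instead of raising an exception. A minimal sketch, assuming the "_bible" frequency tables have been generated by training/process_small_texts.py:

from lplangid import language_classifier as lc

# Languages with only one of the two CSV files still load, with an empty table for the missing one.
rrc_bibles = lc.RRCLanguageClassifier.many_language_bible_instance()
print(rrc_bibles.get_winner("In the beginning God created the heaven and the earth."))
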
3 changes: 2 additions & 1 deletion training/add_wiki_language.sh
@@ -2,6 +2,7 @@

# This script goes through the steps of adding a new language to the RRC classifier based on wiki data.
# See the README.md file for a more thorough explanation of these steps.
# Wikipedia archives are often large (several gigabytes), so the basic download step can take several minutes.

if [[ $(basename `pwd`) != "training" ]]
then
@@ -60,6 +61,6 @@ pushd $target_wiki_dir
popd

echo "Processing text files to create character frequency and term rank data ..."
time python process_wiki.py --languages $language
time python process_wiki_archive.py --languages $language

echo "Finished. Please check that files ${freq_data_dir}/${language}_char_freq.csv and ${freq_data_dir}/${language}_term_rank.csv look to be present and correct."
65 changes: 52 additions & 13 deletions training/process_bibles.py → training/process_small_texts.py
@@ -6,25 +6,27 @@
Includes several hard-coded paths from Dominic's machine.
"""
from collections import defaultdict
import csv
import logging
import numpy as np
import os
from pathlib import Path
from typing import TextIO
from typing import Dict, TextIO

import xml.etree.ElementTree as ET

from lplangid import count_utils
from lplangid import language_classifier as lc
from lplangid.tokenizer import tokenize_fast as tokenize
from process_wiki import MIN_WORD_LENGTH, SKIP_WORDS_WITH_DIGITS, WIKI_TEXT_ROOT
from training.process_wiki_archive import MIN_WORD_LENGTH, SKIP_WORDS_WITH_DIGITS, WIKI_TEXT_ROOT

# The directory with the unzipped files from https://github.com/christos-c/bible-corpus
BIBLE_XML_DIR = '/Users/widdows/Code/bible-corpus/bibles'
BIBLE_XML_DIR = str(Path.home() / "Data" / "bibles" / "bible-corpus" / "bibles")

# The directory where these will be extracted to raw text files, in full, train, and test directories.
BIBLE_TXT_ROOT = '/Users/widdows/Data/BibleTexts'
SUBDIRS = ['full', 'train', 'test']
BIBLE_TXT_ROOT = str(Path.home() / "Data" / "bibles" / "BibleTexts")
SMALLWIKI_TXT_ROOT = Path.home() / "Data" / "WikipediaLindemann"
SUBDIRS = ["full", "train", "test"]


def process_bibles_xml_to_text(corpus_dir=BIBLE_XML_DIR,
@@ -33,7 +35,7 @@ def process_bibles_xml_to_text(corpus_dir=BIBLE_XML_DIR,
plain_files = [fn for fn in all_files if '-tok' not in fn and '-WEB' not in fn]
langs = defaultdict(list)

Path(BIBLE_TXT_ROOT).mkdir(parents=True, exist_ok=True)
Path(out_dir_root).mkdir(parents=True, exist_ok=True)
for subdir in SUBDIRS:
Path(os.path.join(out_dir_root, subdir)).mkdir(parents=True, exist_ok=True)

@@ -47,7 +49,7 @@ def process_bibles_xml_to_text(corpus_dir=BIBLE_XML_DIR,
logging.warning(f"Already seen language '{lang_id}' ({lang_name}) in files {langs[lang_id]}")

out_full, out_train, out_test = [
open(os.path.join(BIBLE_TXT_ROOT, subdir, lang_id + '.txt'), 'w', encoding='utf-8')
open(os.path.join(out_dir_root, subdir, lang_id + '.txt'), 'w', encoding='utf-8')
for subdir in SUBDIRS
]
for i, n in enumerate(root.iter('seg')):
@@ -62,6 +64,34 @@
logging.warning(f"Problem in file {fn} with element {str(n)}")


def process_wiki_lindemann_to_text():
orig_dir = SMALLWIKI_TXT_ROOT / "Original"
meta_lines = csv.reader(open(orig_dir / "wiki_language_codes.csv"))
fn2lang = {row[0]: row[1] for row in meta_lines if len(row) > 1}
fn2lang = {k: v for k, v in fn2lang.items() if len(v) <= 3}  # Filter out non-ISO codes such as "simple" (Simple English).

Path(SMALLWIKI_TXT_ROOT).mkdir(parents=True, exist_ok=True)
for subdir in SUBDIRS:
Path(os.path.join(SMALLWIKI_TXT_ROOT, subdir)).mkdir(parents=True, exist_ok=True)

plain_files = [fn for fn in os.listdir(orig_dir) if '.csv' not in fn and '.zip' not in fn]

for fn in plain_files:
if fn not in fn2lang:
logging.warning(f"No langmatch for filename {fn}")
continue

text = open(orig_dir / fn).read()
out_full, out_train, out_test = [
open(os.path.join(SMALLWIKI_TXT_ROOT, subdir, fn2lang[fn]), 'w', encoding='utf-8')
for subdir in SUBDIRS
]
out_full.write(text)
split_point = text.index(" ", (len(text) * 4) // 5)
out_train.write(text[:split_point])
out_test.write(text[split_point:])
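
For reference (not part of the commit): the split above keeps roughly the first 80% of each text for training and breaks at a word boundary rather than mid-word. A tiny worked example:

text = "word " * 100                                  # 500 characters
split_point = text.index(" ", (len(text) * 4) // 5)   # first space at or after index 400 -> 404
train, test = text[:split_point], text[split_point:]  # ~80% for training, the remainder for testing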


def count_text_in_input(filehandle: TextIO):
term_freq_dict = defaultdict(int)
char_freq_dict = defaultdict(int)
@@ -79,12 +109,12 @@ def count_text_in_input(filehandle: TextIO):
return term_freq_dict, char_freq_dict


def text_files_to_freq_files(input_dir, output_dir=lc.FREQ_DATA_DIR + '_bible'):
def text_files_to_freq_files(input_dir: str, file_to_lang_map: Dict[str, str], output_dir):
if not os.path.isdir(output_dir):
os.mkdir(output_dir)

for infile in os.listdir(input_dir):
lang = infile.split('.')[0]
lang = file_to_lang_map[infile]
with open(os.path.join(input_dir, infile)) as filehandle:
term_freq_dict, char_freq_dict = count_text_in_input(filehandle)

@@ -138,19 +168,28 @@ def run_wikipedia_tests(num_trials=10000, restrict_to_wiki_langs=False):
f'Precision: {correct/attempted:0.3f}. Recall: {correct/num_trials}')


def main():
def main_bibles():
# Some of these steps are optional, depending on what you're trying to do.
retrain = False
retrain = True
if retrain:
logging.basicConfig(level=logging.INFO)
process_bibles_xml_to_text()
text_files_to_freq_files(os.path.join(BIBLE_TXT_ROOT, SUBDIRS[1]))
texts_dir = os.path.join(BIBLE_TXT_ROOT, SUBDIRS[1])
file_to_lang_map = {fn: fn.split('.')[0] for fn in os.listdir(texts_dir)}
text_files_to_freq_files(texts_dir, file_to_lang_map, output_dir=lc.FREQ_DATA_DIR + '_bible')

# print("Restricting to Wiki languages:")
# run_wikipedia_tests(restrict_to_wiki_langs=True)
# print("Selecting from all available languages:")
# run_wikipedia_tests()


def main_wiki():
process_wiki_lindemann_to_text()
texts_dir = os.path.join(SMALLWIKI_TXT_ROOT, SUBDIRS[1])
file_to_lang_map = {fn: fn.split('.')[0] for fn in os.listdir(texts_dir)}
text_files_to_freq_files(texts_dir, file_to_lang_map, output_dir=lc.FREQ_DATA_DIR + '_smallwiki')


if __name__ == '__main__':
main()
main_wiki()
File renamed without changes.
