diff --git a/experiments/fasttext_client.py b/experiments/fasttext_client.py
index 0515037..21fec9b 100644
--- a/experiments/fasttext_client.py
+++ b/experiments/fasttext_client.py
@@ -1,4 +1,6 @@
 from pathlib import Path
+import warnings
+warnings.filterwarnings("ignore", message="`load_model` does not return.*")
 
 import fasttext
 
diff --git a/experiments/huggingface_client.py b/experiments/huggingface_client.py
new file mode 100644
index 0000000..2ba02a4
--- /dev/null
+++ b/experiments/huggingface_client.py
@@ -0,0 +1,72 @@
+import os
+from pathlib import Path
+import re
+from typing import List
+import warnings
+
+import numpy as np
+from scipy.special import softmax
+import torch
+
+from accelerate import Accelerator, DataLoaderConfiguration
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer
+
+
+warnings.filterwarnings("ignore", category=FutureWarning, module="accelerate.*")
+
+HUGGINGFACE_MODEL_ROOT = Path(os.path.dirname(__file__)) / "distilbert_lc_model_80"
+
+
+def get_latest_model_from_dir(directory):
+    pattern = re.compile(r"checkpoint-\d+")
+    dir_items = os.listdir(directory)
+    checkpoints = sorted(filter(pattern.match, dir_items), key=lambda x: int(x.split('-')[-1]))
+    if not checkpoints:
+        raise ValueError("No checkpoint found in the directory.")
+    latest_checkpoint = checkpoints[-1]
+    return os.path.join(directory, latest_checkpoint)
+
+
+class HuggingfaceLangID:
+    def __init__(self, model_root=HUGGINGFACE_MODEL_ROOT):
+        model_path = get_latest_model_from_dir(model_root)
+        self.lc_model = AutoModelForSequenceClassification.from_pretrained(model_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.evaluator = Trainer(model=self.lc_model)
+
+    def predict_lang_batch(self, texts: List[str], batch_size=100, verbose=False):
+        batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
+        all_predicted_labels = []
+
+        for batch_texts in batches:
+            # Truncate to the model's maximum sequence length rather than to a character count.
+            tokenized_texts = self.tokenizer(batch_texts, padding=True, truncation=True, max_length=self.tokenizer.model_max_length, return_tensors="pt")
+            inputs = {k: v.to(self.evaluator.args.device) for k, v in tokenized_texts.items()}
+            with torch.no_grad():
+                outputs = self.lc_model(**inputs)
+            all_logits = outputs.logits.cpu().numpy()
+
+            predicted_labels = []
+            for logits in all_logits:
+                probs = softmax(logits, axis=-1)
+                # Print languages sorted by probability
+                if verbose:
+                    lang_scores = {self.lc_model.config.id2label[i]: prob for i, prob in enumerate(probs)}
+                    for k, v in sorted(lang_scores.items(), key=lambda x: x[1]):
+                        print(f"{k}\t{v:0.4f}")
+                predicted_index = np.argmax(probs, axis=-1)
+                predicted_label = self.lc_model.config.id2label[predicted_index]
+                predicted_labels.append(predicted_label)
+            all_predicted_labels.extend(predicted_labels)
+
+        return all_predicted_labels
+
+    def predict_lang(self, text: str, verbose=False):
+        return self.predict_lang_batch([text], verbose=verbose)[0]
+
+
+if __name__ == "__main__":
+    LANGUAGE = HuggingfaceLangID()
+    lang = LANGUAGE.predict_lang_batch(["Hello in English", "Bonjour en Francais"])
+    print(f"Prediction: {lang}")
+
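A quick way to sanity-check the checkpoint discovery in huggingface_client.py is to exercise get_latest_model_from_dir on a throwaway directory. The sketch below is illustrative only, not part of the module; it assumes nothing beyond the checkpoint-<step> naming that the regex above expects:

import os
import tempfile

from experiments.huggingface_client import get_latest_model_from_dir

with tempfile.TemporaryDirectory() as d:
    for step in (10, 80, 200):
        os.makedirs(os.path.join(d, f"checkpoint-{step}"))
    # Checkpoints are ordered numerically by step, so checkpoint-200 is chosen
    # (a plain lexicographic sort would pick checkpoint-80 instead).
    assert get_latest_model_from_dir(d).endswith("checkpoint-200")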
diff --git a/experiments/requirements.txt b/experiments/requirements.txt
index fcba15d..0934d67 100644
--- a/experiments/requirements.txt
+++ b/experiments/requirements.txt
@@ -1,3 +1,5 @@
+datasets
+evaluate
 fasttext
 langid
 numpy
@@ -6,3 +8,6 @@ pytest
 scikit-learn
 scipy
 setuptools
+torch
+transformers[torch]
+
diff --git a/experiments/twituser_eval.py b/experiments/twituser_eval.py
index b12ba7f..11579b5 100644
--- a/experiments/twituser_eval.py
+++ b/experiments/twituser_eval.py
@@ -3,7 +3,7 @@
 import langid
 
 from lplangid import language_classifier as lc
-from experiments import fasttext_client
+from experiments import fasttext_client, huggingface_client
 from experiments.classification_report import nullsafe_classification_report
 
 
@@ -13,28 +13,36 @@ def langid_classify(text: str):
 
 def run_twituser_tests():
     rrc_classifier = lc.RRCLanguageClassifier.default_instance()
+    rrc_bibles = lc.RRCLanguageClassifier.many_language_bible_instance()
+    rrc_smallwiki = lc.RRCLanguageClassifier(*lc.prepare_scoring_tables(data_dir=lc.FREQ_DATA_DIR + "_smallwiki"))
     ft_classifier = fasttext_client.FastTextLangID()
+    hg_classifier = huggingface_client.HuggingfaceLangID()
 
     fn_labels = [
-        [rrc_classifier.get_winner, "RRC"],
-        [ft_classifier.predict_lang, "FastText"],
-        [langid_classify, "LangID"],
+        [lambda texts: [rrc_classifier.get_winner(text) for text in texts], "RRC default"],
+        [lambda texts: [rrc_bibles.get_winner(text) for text in texts], "RRC bibles"],
+        [lambda texts: [rrc_smallwiki.get_winner(text) for text in texts], "RRC smallwiki"],
+        [lambda texts: [ft_classifier.predict_lang(text) for text in texts], "FastText"],
+        [lambda texts: [langid_classify(text) for text in texts], "LangID"],
+        [hg_classifier.predict_lang_batch, "HuggingFace"],
     ]
+    # fn_labels = [[hg_classifier.predict_lang, "HuggingFace"]]
+
     for fn, label in fn_labels:
         print(f"Classifying with {label}")
-        y_labels, y_pred = [], []
+        y_labels = []
         with open("twituser_data/twituser") as twituser_data:
+            input_texts = []
             for line in twituser_data:
                 record = json.loads(line)
-
-                if record["lang"] not in rrc_classifier.term_ranks:
-                    continue
-
-                result = fn(record["text"])
-                y_pred.append(result)
+                # if record["lang"] not in rrc_classifier.term_ranks:
+                #     continue
+                input_texts.append(record["text"])
                 y_labels.append(record["lang"])
+        y_pred = fn(input_texts)
+
         print(nullsafe_classification_report(y_labels, y_pred))
 
 
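With this change every entry in fn_labels maps a list of texts to a list of predicted language codes, so each classifier is called once per dataset; single-text classifiers are simply wrapped in a list comprehension, as above. A further classifier would be registered in the same shape (the classifier object below is hypothetical, shown only to illustrate the convention):

# Sketch: `my_classifier` stands in for any object with a per-text prediction method.
fn_labels.append(
    [lambda texts: [my_classifier.predict_lang(text) for text in texts], "MyClassifier"]
)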
diff --git a/lplangid/language_classifier.py b/lplangid/language_classifier.py
index 8dc297e..38706c2 100644
--- a/lplangid/language_classifier.py
+++ b/lplangid/language_classifier.py
@@ -85,11 +85,18 @@ def prepare_scoring_tables(data_dir=FREQ_DATA_DIR) -> Tuple[Dict[str, Dict[str,
 
     for lang_code in set([x.split('_')[0] for x in os.listdir(data_dir) if x.endswith('.csv') and not x.startswith('.')]):
         tf_file = os.path.join(data_dir, f"{lang_code}_term_rank.csv")
-        with open(tf_file) as term_freq_file:
-            term_freqs = cu.read_rank_file(term_freq_file, MAX_WORDS_PER_LANG)
-            all_term_ranks[lang_code] = term_freqs
-        with open(os.path.join(data_dir, f'{lang_code}_char_freq.csv')) as char_freq_file:
-            all_char_freqs[lang_code] = cu.normalize_score_dict(cu.read_freq_file(char_freq_file))
+        if not os.path.isfile(tf_file):
+            all_term_ranks[lang_code] = {}
+        else:
+            with open(tf_file) as term_freq_file:
+                term_freqs = cu.read_rank_file(term_freq_file, MAX_WORDS_PER_LANG)
+                all_term_ranks[lang_code] = term_freqs
+        cf_file = os.path.join(data_dir, f'{lang_code}_char_freq.csv')
+        if not os.path.isfile(cf_file):
+            all_char_freqs[lang_code] = {}
+        else:
+            with open(cf_file) as char_freq_file:
+                all_char_freqs[lang_code] = cu.normalize_score_dict(cu.read_freq_file(char_freq_file))
 
     all_char_weights = invert_char_tables(all_char_freqs)
     logging.debug(f"Prepared term and character ranking tables for languages: {sorted(all_term_ranks.keys())}")
diff --git a/training/add_wiki_language.sh b/training/add_wiki_language.sh
index e308705..9fad1cd 100755
--- a/training/add_wiki_language.sh
+++ b/training/add_wiki_language.sh
@@ -2,6 +2,7 @@
 
 # This script goes through the steps of adding a new language to the RRC classifier based on wiki data.
 # See the README.md file for a more thorough explanation of these steps.
+# Wikipedia archives are often large (several gigabytes), so the basic download step can take several minutes.
 
 if [[ $(basename `pwd`) != "training" ]]
 then
@@ -60,6 +61,6 @@ pushd $target_wiki_dir
 popd
 
 echo "Processing text files to create character frequency and term rank data ..."
-time python process_wiki.py --languages $language
+time python process_wiki_archive.py --languages $language
 
 echo "Finished. Please check that files ${freq_data_dir}/${language}_char_freq.csv and ${freq_data_dir}/${language}_term_rank.csv look to be present and correct."
diff --git a/training/process_bibles.py b/training/process_small_texts.py
similarity index 70%
rename from training/process_bibles.py
rename to training/process_small_texts.py
index 4f0f410..6133908 100644
--- a/training/process_bibles.py
+++ b/training/process_small_texts.py
@@ -6,25 +6,27 @@
 Includes several hard-coded paths from Dominic's machine.
 """
 from collections import defaultdict
+import csv
 import logging
 import numpy as np
 import os
 from pathlib import Path
-from typing import TextIO
+from typing import Dict, TextIO
 import xml.etree.ElementTree as ET
 
 from lplangid import count_utils
 from lplangid import language_classifier as lc
 from lplangid.tokenizer import tokenize_fast as tokenize
-from process_wiki import MIN_WORD_LENGTH, SKIP_WORDS_WITH_DIGITS, WIKI_TEXT_ROOT
+from training.process_wiki_archive import MIN_WORD_LENGTH, SKIP_WORDS_WITH_DIGITS, WIKI_TEXT_ROOT
 
 # The directory with the unzipped files from https://github.com/christos-c/bible-corpus
-BIBLE_XML_DIR = '/Users/widdows/Code/bible-corpus/bibles'
+BIBLE_XML_DIR = str(Path.home() / "Data" / "bibles" / "bible-corpus" / "bibles")
 
 # The directory where these will be extracted to raw text files, in full, train, and test directories.
-BIBLE_TXT_ROOT = '/Users/widdows/Data/BibleTexts'
-SUBDIRS = ['full', 'train', 'test']
+BIBLE_TXT_ROOT = str(Path.home() / "Data" / "bibles" / "BibleTexts")
+SMALLWIKI_TXT_ROOT = Path.home() / "Data" / "WikipediaLindemann"
+SUBDIRS = ["full", "train", "test"]
 
 
 def process_bibles_xml_to_text(corpus_dir=BIBLE_XML_DIR,
@@ -33,7 +35,7 @@ def process_bibles_xml_to_text(corpus_dir=BIBLE_XML_DIR,
     plain_files = [fn for fn in all_files if '-tok' not in fn and '-WEB' not in fn]
     langs = defaultdict(list)
 
-    Path(BIBLE_TXT_ROOT).mkdir(parents=True, exist_ok=True)
+    Path(out_dir_root).mkdir(parents=True, exist_ok=True)
     for subdir in SUBDIRS:
-        Path(os.path.join(BIBLE_TXT_ROOT, subdir)).mkdir(parents=True, exist_ok=True)
+        Path(os.path.join(out_dir_root, subdir)).mkdir(parents=True, exist_ok=True)
 
@@ -47,7 +49,7 @@ def process_bibles_xml_to_text(corpus_dir=BIBLE_XML_DIR,
             logging.warning(f"Already seen language '{lang_id}' ({lang_name}) in files {langs[lang_id]}")
 
         out_full, out_train, out_test = [
-            open(os.path.join(BIBLE_TXT_ROOT, subdir, lang_id + '.txt'), 'w', encoding='utf-8')
+            open(os.path.join(out_dir_root, subdir, lang_id + '.txt'), 'w', encoding='utf-8')
             for subdir in SUBDIRS
         ]
         for i, n in enumerate(root.iter('seg')):
@@ -62,6 +64,34 @@ def process_bibles_xml_to_text(corpus_dir=BIBLE_XML_DIR,
                 logging.warning(f"Problem in file {fn} with element {str(n)}")
 
 
+def process_wiki_lindemann_to_text():
+    orig_dir = SMALLWIKI_TXT_ROOT / "Original"
+    meta_lines = csv.reader(open(orig_dir / "wiki_language_codes.csv"))
+    fn2lang = {row[0]: row[1] for row in meta_lines if len(row) > 1}
+    fn2lang = {k: v for k, v in fn2lang.items() if len(v) <= 3}  # Filter out non-ISO codes such as "simple" (Simple English).
+
+    Path(SMALLWIKI_TXT_ROOT).mkdir(parents=True, exist_ok=True)
+    for subdir in SUBDIRS:
+        Path(os.path.join(SMALLWIKI_TXT_ROOT, subdir)).mkdir(parents=True, exist_ok=True)
+
+    plain_files = [fn for fn in os.listdir(orig_dir) if '.csv' not in fn and '.zip' not in fn]
+
+    for fn in plain_files:
+        if fn not in fn2lang:
+            logging.warning(f"No langmatch for filename {fn}")
+            continue
+
+        text = open(orig_dir / fn).read()
+        out_full, out_train, out_test = [
+            open(os.path.join(SMALLWIKI_TXT_ROOT, subdir, fn2lang[fn]), 'w', encoding='utf-8')
+            for subdir in SUBDIRS
+        ]
+        out_full.write(text)
+        split_point = text.index(" ", (len(text) * 4) // 5)
+        out_train.write(text[:split_point])
+        out_test.write(text[split_point:])
+
+
 def count_text_in_input(filehandle: TextIO):
     term_freq_dict = defaultdict(int)
     char_freq_dict = defaultdict(int)
@@ -79,12 +109,12 @@ def count_text_in_input(filehandle: TextIO):
     return term_freq_dict, char_freq_dict
 
 
-def text_files_to_freq_files(input_dir, output_dir=lc.FREQ_DATA_DIR + '_bible'):
+def text_files_to_freq_files(input_dir: str, file_to_lang_map: Dict[str, str], output_dir):
     if not os.path.isdir(output_dir):
         os.mkdir(output_dir)
 
     for infile in os.listdir(input_dir):
-        lang = infile.split('.')[0]
+        lang = file_to_lang_map[infile]
         with open(os.path.join(input_dir, infile)) as filehandle:
             term_freq_dict, char_freq_dict = count_text_in_input(filehandle)
 
@@ -138,13 +168,15 @@ def run_wikipedia_tests(num_trials=10000, restrict_to_wiki_langs=False):
           f'Precision: {correct/attempted:0.3f}. Recall: {correct/num_trials}')
 
 
-def main():
+def main_bibles():
     # Some of these steps are optional, depending on what you're trying to do.
-    retrain = False
+    retrain = True
     if retrain:
         logging.basicConfig(level=logging.INFO)
         process_bibles_xml_to_text()
-        text_files_to_freq_files(os.path.join(BIBLE_TXT_ROOT, SUBDIRS[1]))
+        texts_dir = os.path.join(BIBLE_TXT_ROOT, SUBDIRS[1])
+        file_to_lang_map = {fn: fn.split('.')[0] for fn in os.listdir(texts_dir)}
+        text_files_to_freq_files(texts_dir, file_to_lang_map, output_dir=lc.FREQ_DATA_DIR + '_bible')
 
     # print("Restricting to Wiki languages:")
     # run_wikipedia_tests(restrict_to_wiki_langs=True)
@@ -152,5 +184,12 @@
     # run_wikipedia_tests()
 
 
+def main_wiki():
+    process_wiki_lindemann_to_text()
+    texts_dir = os.path.join(SMALLWIKI_TXT_ROOT, SUBDIRS[1])
+    file_to_lang_map = {fn: fn.split('.')[0] for fn in os.listdir(texts_dir)}
+    text_files_to_freq_files(texts_dir, file_to_lang_map, output_dir=lc.FREQ_DATA_DIR + '_smallwiki')
+
+
 if __name__ == '__main__':
-    main()
+    main_wiki()
diff --git a/training/process_wiki.py b/training/process_wiki_archive.py
similarity index 100%
rename from training/process_wiki.py
rename to training/process_wiki_archive.py
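For the new small-wiki path, main_wiki() in process_small_texts.py reads from the SMALLWIKI_TXT_ROOT directory defined above. The sketch below restates those expectations and the assumed invocation (the data layout comes from the paths in the file; running as a module from the repository root is an assumption so that the training package imports resolve):

from pathlib import Path

# Expected inputs for process_wiki_lindemann_to_text():
#   ~/Data/WikipediaLindemann/Original/wiki_language_codes.csv  (maps filename -> language code)
#   ~/Data/WikipediaLindemann/Original/<one plain-text file per language>
# It writes full/train/test splits under ~/Data/WikipediaLindemann/, and
# text_files_to_freq_files() then writes tables to lc.FREQ_DATA_DIR + '_smallwiki',
# which twituser_eval.py loads for the "RRC smallwiki" classifier.
orig_dir = Path.home() / "Data" / "WikipediaLindemann" / "Original"
assert (orig_dir / "wiki_language_codes.csv").is_file(), "metadata CSV not found"

# Assumed invocation, from the repository root:  python -m training.process_small_texts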