From 60101e0f276ec05f420a5dbbf4936d43d09de046 Mon Sep 17 00:00:00 2001
From: Paul Lerner
Date: Mon, 12 Feb 2024 18:40:04 +0100
Subject: [PATCH] feat: viz.analyze: specify subset + fix predictions and viz_dists

---
 neot/viz/analyze.py | 33 ++++++++++++++++++---------------
 1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/neot/viz/analyze.py b/neot/viz/analyze.py
index 5a835bd..efe3850 100644
--- a/neot/viz/analyze.py
+++ b/neot/viz/analyze.py
@@ -17,19 +17,18 @@ from ..morph.labels import MorphLabel
 
 
 
-
 def viz_f1(data, pred, metrics):
-    for i, item in enumerate(data["train"][:100]):
+    for i, item in enumerate(data[:100]):
         f1 = metrics["f1s"][i]
         if f1 > 0.0 and f1 < 1.0:
-            print(f1, pred['predictions'][i], item["fr"]["text"])
+            print(f1, pred['predictions'][i][0], item["fr"]["text"])
 
 
 def viz_wrong(data, pred, metrics):
-    for i, item in enumerate(data["train"][:100]):
+    for i, item in enumerate(data[:100]):
         f1 = metrics["f1s"][i]
         if f1 <= 0.0:
-            print(f1, pred['predictions'][i], item["fr"]["text"])
+            print(f1, pred['predictions'][i][0], item["fr"]["text"])
 
 
 def dist_f1(metrics, output):
@@ -43,7 +42,7 @@ def gather_results(data, metrics, tokenizer):
     fr_ova = {c.name: {True: [], False: []} for c in MorphLabel}
     en_ova = {c.name: {True: [], False: []} for c in MorphLabel}
     per_dom = []
-    for i, item in enumerate(data["train"]):
+    for i, item in enumerate(data):
         p_fr = item["fr"]["morph_label"]
         p_en = item["en"]["morph_label"]
         em = metrics["ems"][i]
@@ -101,8 +100,11 @@ def viz_dist(results, x, output):
     fig.savefig(output / f"{x}_wrt_EM_dist.pdf")
 
 
-def viz_dists(results, **kwargs):
-    for x in ["Morph. Diff.", "Edit dist.", "Term fertility", "Word fertility", "# words"]:
+def viz_dists(results, tokenizer=None, **kwargs):
+    distributions = ["Morph. Diff.", "Edit dist.", "# words"]
+    if tokenizer is not None:
+        distributions += ["Term fertility", "Word fertility"]
+    for x in distributions:
         viz_dist(results, x, **kwargs)
 
 
@@ -113,14 +115,15 @@ def viz_ova(fr_ova, en_ova):
 
 def tag(pred, tagger):
     poses = []
-    stripped_preds = [p.strip() for p in pred["predictions"]]
+    stripped_preds = [p[0].strip() for p in pred["predictions"]]
     for doc in tagger.pipe(stripped_preds, batch_size=2048):
         poses.append([t.pos_ for t in doc])
 
     pred["pos"] = poses
 
 
-def main(data: Path, pred_path: Path, tokenizer: str = None, output: Path = None, tagger: str = None):
+def main(data: Path, pred_path: Path, tokenizer: str = None, output: Path = None, tagger: str = None,
+         subset: str = "test"):
     with open(data, "rt") as file:
         data = json.load(file)
 
@@ -137,15 +140,15 @@ def main(data: Path, pred_path: Path, tokenizer: str = None, output: Path = None
         tokenizer = AutoTokenizer.from_pretrained(tokenizer, add_prefix_space=True)
 
     metrics = pred["metrics"]
-    viz_f1(data, pred, metrics)
-    viz_wrong(data, pred, metrics)
-    results, per_dom, fr_ova, en_ova = gather_results(data, metrics, tokenizer)
+    viz_f1(data[subset], pred, metrics)
+    viz_wrong(data[subset], pred, metrics)
+    results, per_dom, fr_ova, en_ova = gather_results(data[subset], metrics, tokenizer)
     viz_ova(fr_ova, en_ova)
     if output is not None:
        output.mkdir(exist_ok=True)
        dist_f1(metrics, output)
-       viz_dists(results, output=output)
+       viz_dists(results, output=output, tokenizer=tokenizer)
 
 
 if __name__ == "__main__":
-    CLI(main, description=main.__doc__)
\ No newline at end of file
+    CLI(main, description=main.__doc__)
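
Note on the assumed input layout (not part of the commit): the hunks above suggest that `data` is a dict keyed by subset name ("train", "test", ...), selectable through the new `subset` argument, and that `pred["predictions"][i]` is now a list of hypotheses whose first element is the top-ranked prediction, hence the switch to `[i][0]`. The Python sketch below illustrates that inferred layout and a hypothetical invocation of the patched `main`; the file names, the example items, and the `python -m neot.viz.analyze` call are assumptions for illustration, not taken from the repository.

# Minimal sketch (inferred from the diff, not repository code): the input
# layout that the patched analyze.py appears to expect.
import json

# `data` is keyed by subset; each item pairs a French and an English term,
# with morphological labels (MorphLabel names, left empty here).
data = {
    "test": [
        {
            "fr": {"text": "anticorps", "morph_label": []},
            "en": {"text": "antibody", "morph_label": []},
        }
    ]
}

# `pred["predictions"][i]` is a list of hypotheses; the patch reads `[i][0]`,
# i.e. the top-ranked one, when printing and POS-tagging predictions.
pred = {
    "predictions": [["anticorps", "anti-corps"]],
    "metrics": {"f1s": [1.0], "ems": [1]},
}

# Write illustrative inputs to hypothetical file names.
with open("data.json", "wt") as f:
    json.dump(data, f)
with open("pred.json", "wt") as f:
    json.dump(pred, f)

# Hypothetical invocation, assuming `CLI` comes from jsonargparse so that the
# new `subset` parameter is exposed as a command-line flag:
#   python -m neot.viz.analyze --data data.json --pred_path pred.json \
#       --subset test --output figures/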