feat: viz.analyze: specify subset + fix predictions and viz_dists
PaulLerner committed Feb 12, 2024
1 parent 4cccbd5 commit 60101e0
Showing 1 changed file with 18 additions and 15 deletions.
neot/viz/analyze.py: 18 additions, 15 deletions
@@ -17,19 +17,18 @@
 from ..morph.labels import MorphLabel
 
 
-
 def viz_f1(data, pred, metrics):
-    for i, item in enumerate(data["train"][:100]):
+    for i, item in enumerate(data[:100]):
         f1 = metrics["f1s"][i]
         if f1 > 0.0 and f1 < 1.0:
-            print(f1, pred['predictions'][i], item["fr"]["text"])
+            print(f1, pred['predictions'][i][0], item["fr"]["text"])
 
 
 def viz_wrong(data, pred, metrics):
-    for i, item in enumerate(data["train"][:100]):
+    for i, item in enumerate(data[:100]):
         f1 = metrics["f1s"][i]
         if f1 <= 0.0:
-            print(f1, pred['predictions'][i], item["fr"]["text"])
+            print(f1, pred['predictions'][i][0], item["fr"]["text"])
 
 
 def dist_f1(metrics, output):
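Note on the "fix predictions" part of this commit: the new [0] index suggests that each entry of pred["predictions"] is now a ranked list of candidate strings (e.g. beam hypotheses) rather than a single string, so the printouts take the top candidate. A minimal sketch of the assumed layout; the file name and the list-of-candidates structure are assumptions inferred from the new indexing, not confirmed by the repo:

    import json

    # Hypothetical predictions file; the list-of-candidates layout is an
    # assumption inferred from the new pred["predictions"][i][0] indexing.
    with open("pred.json", "rt") as file:
        pred = json.load(file)

    for candidates in pred["predictions"][:3]:
        print(candidates[0])  # top-ranked hypothesis for each item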
@@ -43,7 +42,7 @@ def gather_results(data, metrics, tokenizer):
     fr_ova = {c.name: {True: [], False: []} for c in MorphLabel}
     en_ova = {c.name: {True: [], False: []} for c in MorphLabel}
     per_dom = []
-    for i, item in enumerate(data["train"]):
+    for i, item in enumerate(data):
         p_fr = item["fr"]["morph_label"]
         p_en = item["en"]["morph_label"]
         em = metrics["ems"][i]
@@ -101,8 +100,11 @@ def viz_dist(results, x, output):
     fig.savefig(output / f"{x}_wrt_EM_dist.pdf")
 
 
-def viz_dists(results, **kwargs):
-    for x in ["Morph. Diff.", "Edit dist.", "Term fertility", "Word fertility", "# words"]:
+def viz_dists(results, tokenizer=None, **kwargs):
+    distributions = ["Morph. Diff.", "Edit dist.", "# words"]
+    if tokenizer is not None:
+        distributions += ["Term fertility", "Word fertility"]
+    for x in distributions:
         viz_dist(results, x, **kwargs)
 
 
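Note on the viz_dists fix: "Term fertility" and "Word fertility" are tokenizer-dependent statistics, so they are presumably only present in results when a tokenizer was passed to gather_results; the new tokenizer=None parameter lets the caller skip those two plots rather than fail on missing columns. Both call patterns, for illustration:

    # tokenizer available: all five distributions are plotted
    viz_dists(results, output=output, tokenizer=tokenizer)

    # no tokenizer: only "Morph. Diff.", "Edit dist." and "# words"
    viz_dists(results, output=output)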


@@ -113,14 +115,15 @@ def viz_ova(fr_ova, en_ova):
 
 def tag(pred, tagger):
     poses = []
-    stripped_preds = [p.strip() for p in pred["predictions"]]
+    stripped_preds = [p[0].strip() for p in pred["predictions"]]
     for doc in tagger.pipe(stripped_preds, batch_size=2048):
         poses.append([t.pos_ for t in doc])
 
     pred["pos"] = poses
 
 
-def main(data: Path, pred_path: Path, tokenizer: str = None, output: Path = None, tagger: str = None):
+def main(data: Path, pred_path: Path, tokenizer: str = None, output: Path = None, tagger: str = None,
+         subset: str = "test"):
     with open(data, "rt") as file:
         data = json.load(file)
 
@@ -137,15 +140,15 @@ def main(data: Path, pred_path: Path, tokenizer: str = None, output: Path = None
         tokenizer = AutoTokenizer.from_pretrained(tokenizer, add_prefix_space=True)
 
     metrics = pred["metrics"]
-    viz_f1(data, pred, metrics)
-    viz_wrong(data, pred, metrics)
-    results, per_dom, fr_ova, en_ova = gather_results(data, metrics, tokenizer)
+    viz_f1(data[subset], pred, metrics)
+    viz_wrong(data[subset], pred, metrics)
+    results, per_dom, fr_ova, en_ova = gather_results(data[subset], metrics, tokenizer)
     viz_ova(fr_ova, en_ova)
     if output is not None:
         output.mkdir(exist_ok=True)
         dist_f1(metrics, output)
-        viz_dists(results, output=output)
+        viz_dists(results, output=output, tokenizer=tokenizer)
 
 
 if __name__ == "__main__":
-    CLI(main, description=main.__doc__)
\ No newline at end of file
+    CLI(main, description=main.__doc__)
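Since main is wrapped with jsonargparse's CLI, the new subset parameter becomes a command-line option defaulting to "test". A usage sketch; the module path is assumed from the file location, and the file names and the "dev" split name are hypothetical:

    # analyze predictions on the validation split instead of the default test split
    python -m neot.viz.analyze --data dataset.json --pred_path pred.json --subset dev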
