plot_summary_comparison.py
"""
To run:

    python plot_summary_comparison.py --paths scripts/{method_name}/results/{model_name}

Multiple paths can be provided. The resulting plots are written to scripts/images/v_{id}/{dataset}.png.
See https://github.com/huggingface/setfit/pull/268#issuecomment-1434549208 for an example of the plots
produced by this script.
"""

import argparse
import json
import os
import string
import sys
from collections import defaultdict
from glob import glob
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import pandas as pd


def get_sample_sizes(path: str) -> List[int]:
    """Collect the distinct training sample sizes from run directories named `train-{sample_size}-{seed}`."""
    return sorted({int(name.split("-")[-2]) for name in glob(f"{path}/*/train-*-0")})


def get_formatted_ds_metrics(path: str, dataset: str, sample_sizes: List[int]) -> Tuple[str, Dict[int, List[float]]]:
    """Gather the metric name and the per-sample-size scores for a single dataset."""
    metric_name = "N/A"
    split_metrics = defaultdict(list)
    for sample_size in sample_sizes:
        result_jsons = sorted(glob(os.path.join(path, dataset, f"train-{sample_size}-*", "results.json")))
        for result_json in result_jsons:
            with open(result_json) as f:
                result_dict = json.load(f)
            metric_name = result_dict.get("measure", "N/A")
            split_metrics[sample_size].append(result_dict["score"])
    return metric_name, split_metrics
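
# For example, with sample sizes [8, 64] and three seeds per size, the function above
# might return ("accuracy", {8: [0.81, 0.79, 0.84], 64: [0.90, 0.91, 0.89]})
# (hypothetical values, one score per seed).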


def plot_summary_comparison(paths: List[str]) -> None:
    """Given a list of paths to output directories produced by e.g. `scripts/setfit/run_fewshot.py`,
    produce and save boxplots that compare the various results.

    The plots are saved to scripts/images/v_{id}/{dataset}.png, i.e. one plot per dataset.

    Args:
        paths (List[str]): List of paths to output directories, generally
            `scripts/{method_name}/results/{model_name}`.
    """
    # Parse the result paths
    dataset_to_df = defaultdict(pd.DataFrame)
    dataset_to_metric = {}
    for path_index, path in enumerate(paths):
        ds_to_metric, this_dataset_to_df = get_summary_df(path)
        for dataset, df in this_dataset_to_df.items():
            df["path_index"] = path_index
            dataset_to_df[dataset] = pd.concat((dataset_to_df[dataset], df))
        dataset_to_metric.update(ds_to_metric)

    # Prepare a folder for storing the figures
    image_dir = Path("scripts") / "images"
    image_dir.mkdir(parents=True, exist_ok=True)
    new_version = (
        max([int(path.name[2:]) for path in image_dir.glob("v_*/") if path.name[2:].isdigit()], default=0) + 1
    )
    output_dir = image_dir / f"v_{new_version}"
    output_dir.mkdir()
    # Save a copy of the executed command in the output directory
    (output_dir / "command.txt").write_text("python " + " ".join(sys.argv))

    # Create one plot per dataset
    for dataset, df in dataset_to_df.items():
        columns = [column for column in df.columns if not column.startswith("path")]
        fig, axes = plt.subplots(ncols=len(columns), sharey=True)
        for column_index, column in enumerate(columns):
            ax = axes[column_index] if len(columns) > 1 else axes
            # Set the y label only for the first column
            if column_index == 0:
                ax.set_ylabel(dataset_to_metric[dataset])
            # Set positions to 0, 0.2, 0.4, ..., one position per boxplot.
            # This places the boxplots closer together than the pandas defaults.
            n_boxplots = len(df["path_index"].unique())
            allotted_box_width = 0.2
            positions = [allotted_box_width * i for i in range(n_boxplots)]
            ax.set_xlim(-allotted_box_width * 0.75, allotted_box_width * (n_boxplots - 0.25))
            df[[column, "path_index"]].groupby("path_index", sort=True).boxplot(
                subplots=False, ax=ax, column=column, positions=positions
            )
            k_shot = column.split("-")[-1]
            ax.set_xlabel(f"{k_shot}-shot")
            if n_boxplots > 1:
                # If there are multiple boxplots, override the labels at the bottom generated by pandas
                if n_boxplots <= 26:
                    ax.set_xticklabels(string.ascii_uppercase[:n_boxplots])
                else:
                    ax.set_xticklabels(range(n_boxplots))
            else:
                # Otherwise, just remove the xtick labels
                ax.tick_params(labelbottom=False)
        if n_boxplots > 1:
            fig.suptitle(
                f"Comparison between various baselines on the {dataset}\ndataset under various $K$-shot conditions"
            )
        else:
            fig.suptitle(f"Results on the {dataset} dataset under various $K$-shot conditions")
        fig.tight_layout()
        fig.savefig(output_dir / f"{dataset}.png")
        plt.close(fig)
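
# A hypothetical invocation comparing two runs (the model names are placeholders):
#
#   python plot_summary_comparison.py \
#       --paths scripts/setfit/results/model-a scripts/baseline/results/model-b
#
# This draws, per dataset, one subplot per sample size, with one box per path
# (labeled A, B, ... along the x-axis).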


def get_summary_df(path: str) -> Tuple[Dict[str, str], Dict[str, pd.DataFrame]]:
    """Given per-split results, return a mapping from dataset to metric name (e.g. "accuracy")
    and a mapping from dataset to a pandas DataFrame that stores the results.

    Args:
        path: Path to per-split results, generally `scripts/{method_name}/results/{model_name}`.
    """
    sample_sizes = get_sample_sizes(path)
    dataset_to_metric = {}
    dataset_to_df = {}
    for dataset in next(os.walk(path))[1]:
        metric_name, split_metrics = get_formatted_ds_metrics(path, dataset, sample_sizes)
        dataset_df = pd.DataFrame(split_metrics.values(), index=[f"{dataset}-{key}" for key in split_metrics]).T
        dataset_to_metric[dataset] = metric_name
        dataset_to_df[dataset] = dataset_df
    return dataset_to_metric, dataset_to_df
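
# Continuing the hypothetical example: for a dataset named "sst2" with sample sizes
# [8, 64], the returned DataFrame has columns "sst2-8" and "sst2-64" and one row per seed.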


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--paths", nargs="+", type=str, help="One or more result directories to compare.")
    args = parser.parse_args()
    if args.paths:
        plot_summary_comparison(args.paths)
    else:
        parser.error("Please provide at least one path via the `--paths` argument.")


if __name__ == "__main__":
    main()