forked from facebookresearch/demucs
-
Notifications
You must be signed in to change notification settings - Fork 1
/
valid_table.py
63 lines (53 loc) · 1.9 KB
/
valid_table.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
from collections import defaultdict
from pathlib import Path
import numpy as np
import treetable as tt
# Directory scanned for per-experiment ``<name>.json`` logs; replaced by
# "results/logs" when --paper is given (see below).
LOGS = Path("logs")
# Key whose ``key=value`` token is removed from experiment names before
# grouping, so runs differing only in this key are aggregated into one row
# (unless --individual is given).
STD_KEY = "seed"
# Entry read from the last element of each JSON log.
METRIC = "best"

# No explicit ``prog``: argparse takes it from sys.argv[0], so usage/help
# messages name the script actually being run instead of a hard-coded,
# copy-pasted file name.
parser = argparse.ArgumentParser()
parser.add_argument("-p",
                    "--paper",
                    action="store_true",
                    help="show results from the paper experiment")
parser.add_argument("-i", "--individual", action="store_true", help="no aggregation by seed")
args = parser.parse_args()
if args.paper:
    LOGS = Path("results/logs")
# Collect the final METRIC value of each experiment log, grouped under a
# display name of the form "<Model> key=value ...".
all_stats = defaultdict(list)
for path in LOGS.iterdir():
    if path.suffix != ".json":
        continue
    # Outside --paper mode, only experiments with a ``<name>.done`` marker
    # next to their log are counted.
    if not (args.paper or path.with_suffix(".done").exists()):
        continue
    # ``with`` closes the handle deterministically; JSON text is UTF-8.
    with open(path, encoding="utf-8") as f:
        metric = json.load(f)[-1][METRIC]
    stem = path.stem
    model = "Tasnet" if "tasnet" in stem else "Demucs"
    # Space-separated tokens of the file name, minus any token naming the
    # model. Tokens are kept verbatim instead of being unpacked from
    # ``split("=")``, so a token without "=" (or a value containing "=")
    # does not raise ValueError.
    if stem == "default":
        tokens = []
    else:
        tokens = [t for t in stem.split() if "tasnet" not in t]
    if not args.individual:
        # Drop the STD_KEY token so repeated runs collapse onto one row.
        tokens = [t for t in tokens if t.partition("=")[0] != STD_KEY]
    # Joining with the model avoids a trailing space when there are no tokens.
    name = " ".join([model, *tokens])
    all_stats[name].append(metric)
# Table layout: the experiment name, then a "valid" group holding the mean
# score, its spread, and how many runs were aggregated.
metrics = [tt.leaf("score", ".4f"), tt.leaf("std", ".3f"), tt.leaf("count", ".2f")]
mytable = tt.table([tt.leaf("name"), tt.group("valid", metrics)])


def _summary_row(label, values):
    """Build one table row for the runs grouped under *label*.

    The "std" entry is numpy's default (ddof=0) standard deviation divided by
    sqrt(count), i.e. a standard error of the mean rather than a raw std.
    """
    scores = np.array(values)
    n_runs = scores.shape[0]
    return {
        "name": label,
        "valid": {
            "score": scores.mean(),
            "std": scores.std() / n_runs**0.5,
            "count": n_runs,
        },
    }


# Rows are listed by ascending mean score (stable sort keeps ties in
# insertion order).
lines = sorted(
    (_summary_row(label, values) for label, values in all_stats.items()),
    key=lambda row: row["valid"]["score"],
)
print(tt.treetable(lines, mytable, colors=['33', '0']))