Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Oct 20, 2024
1 parent 776d4ca commit b68ca9b
Showing 1 changed file with 32 additions and 19 deletions.
51 changes: 32 additions & 19 deletions runner/eval_thresholds.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,52 @@
#!docker run --rm --user root -v /opt/runner:/o:ro synerbi/sirf:ci python
import logging
from collections import defaultdict
from pathlib import Path
from pathlib import Path, PurePath

import numpy as np
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator, SCALARS
from tensorboard.backend.event_processing.event_accumulator import SCALARS, EventAccumulator

from petric import QualityMetrics

log = logging.getLogger(Path(__file__).stem)
LOGDIR = Path("/o/logs")
TAGS = {"RMSE_whole_object", "RMSE_background", "AEM_VOI"}
assert set(QualityMetrics.THRESHOLD.keys()) == TAGS


def scalars(ea: EventAccumulator, tag: str):
steps = [s.step for s in ea.Scalars(tag)]
assert steps == sorted(steps)
return [(scalar.value, scalar.wall_time) for scalar in ea.Scalars(tag)]

def valid(tensorboard_logfile: str) -> bool:

def valid(tensorboard_logfile: PurePath) -> bool:
ea = EventAccumulator(str(tensorboard_logfile), size_guidance={SCALARS: 0})
ea.Reload()
return len({"RMSE_whole_object", "RMSE_background"}.intersection(ea.Tags()['scalars'])) == 2

def pass_time(tensorboard_logfile: str) -> float:

def pass_time(tensorboard_logfile: PurePath) -> float:
ea = EventAccumulator(str(tensorboard_logfile), size_guidance={SCALARS: 0})
ea.Reload()

try:
start = ea.Scalars("reset")[0]
except KeyError:
log.error("KeyError: reset: not using accurate relative time")
start = None
else:
assert start.value == 0
assert start.step == -1
start = start.wall_time

tags = {tag for tag in ea.Tags()['scalars'] if any(tag.startswith(i) for i in TAGS)}
tags = {tag: scalars(ea, tag) for tag in tags}
tag_names = {tag for tag in ea.Tags()['scalars'] if any(tag.startswith(i) for i in TAGS)}
tags = {tag: scalars(ea, tag) for tag in tag_names}

metrics = [tags.pop("RMSE_whole_object"), tags.pop("RMSE_background")]
thresholds = [
QualityMetrics.THRESHOLD["RMSE_whole_object"],
QualityMetrics.THRESHOLD["RMSE_background"],
] + [QualityMetrics.THRESHOLD["AEM_VOI"]] * len(tags)
QualityMetrics.THRESHOLD["RMSE_background"],] + [QualityMetrics.THRESHOLD["AEM_VOI"]] * len(tags)
metrics.extend(tags.values())
metrics = np.array(metrics).T # [(value, time), step, (RMSE, RMSE, VOI, ...)]

Expand All @@ -51,40 +56,48 @@ def pass_time(tensorboard_logfile: str) -> float:
return np.inf
return metrics[1][i] - (start or metrics[1][0])


if __name__ == '__main__':
timings = defaultdict(list) # {"dataset.name": [(time: float, "algo"), ...], ...}
logging.basicConfig(level=logging.INFO)
# {"dataset.name": [(time, "algo"), ...], ...}
timings: dict[str, list[tuple[float, str]]] = defaultdict(list)

for team in LOGDIR.glob("*/"):
if team.name == '0_THRESHOLDS':
continue
for algo in team.glob("*/"):
for dataset in algo.glob("*/"):
t = np.median([pass_time(logfile) for logfile in dataset.glob("events.out.tfevents.*") if valid(logfile)])
for logfile in dataset.glob("events.out.tfevents.*"):
if not valid(logfile):
log.warning("rm %s", logfile)
# logfile.unlink()
t = np.median([
pass_time(logfile) for logfile in dataset.glob("events.out.tfevents.*") if valid(logfile)])
timings[dataset.name].append((t, f"{team.name}/{algo.name}"))

algos = {algo for time_algos in timings.values() for _, algo in time_algos}
for time_algos in timings.values():
missing = algos - {algo for _, algo in time_algos}
time_algos.extend((np.inf, algo) for algo in missing)

ranks = defaultdict(int)
for dataset, time_algos in timings.items():
ranks: dict[str, int] = defaultdict(int)
for dataset_name, time_algos in timings.items():
time_algos.sort()
print(dataset)
print("=" * len(dataset))
print(dataset_name)
print("=" * len(dataset_name))
N = len(time_algos)
rank = 1
for t, algo in time_algos:
for t, algo_name in time_algos:
if np.isposinf(t):
_rank = N
else:
_rank = rank
rank += 1
print(f"{_rank}: {algo}")
ranks[algo] += _rank
print(f"{_rank}: {algo_name}")
ranks[algo_name] += _rank
print("\n")

print("Leaderboard")
print("===========")
for algo, _ in sorted(ranks.items(), key=lambda algo_rank: algo_rank[1]):
print(algo)
for algo_name, _ in sorted(ranks.items(), key=lambda algo_rank: algo_rank[1]):
print(algo_name)

0 comments on commit b68ca9b

Please sign in to comment.