Skip to content

Commit

Permalink
add some comments & debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Oct 20, 2024
1 parent 69753fe commit 7af9e12
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions runner/eval_thresholds.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,36 @@
assert set(QualityMetrics.THRESHOLD.keys()) == TAGS


def scalars(ea: EventAccumulator, tag: str):
def scalars(ea: EventAccumulator, tag: str) -> list[tuple[float, float]]:
"""[(value, time), ...]"""
steps = [s.step for s in ea.Scalars(tag)]
assert steps == sorted(steps)
return [(scalar.value, scalar.wall_time) for scalar in ea.Scalars(tag)]


def valid(tensorboard_logfile: PurePath) -> bool:
"""False if invalid/empty logfile"""
ea = EventAccumulator(str(tensorboard_logfile), size_guidance={SCALARS: 0})
ea.Reload()
return len({"RMSE_whole_object", "RMSE_background"}.intersection(ea.Tags()['scalars'])) == 2


def pass_time(tensorboard_logfile: PurePath) -> float:
"""time at which thresholds were met (minus any metrics calculation time offset)"""
ea = EventAccumulator(str(tensorboard_logfile), size_guidance={SCALARS: 0})
ea.Reload()

try:
start = ea.Scalars("reset")[0]
except KeyError:
log.error("KeyError: reset: not using accurate relative time")
log.error("KeyError: reset: not using accurate relative time for %s", tensorboard_logfile.relative_to(LOGDIR))
start = None
else:
assert start.value == 0
assert start.step == -1
start = start.wall_time

tag_names = {tag for tag in ea.Tags()['scalars'] if any(tag.startswith(i) for i in TAGS)}
tag_names: set[str] = {tag for tag in ea.Tags()['scalars'] if any(tag.startswith(i) for i in TAGS)}
tags = {tag: scalars(ea, tag) for tag in tag_names}

metrics = [tags.pop("RMSE_whole_object"), tags.pop("RMSE_background")]
Expand All @@ -62,6 +65,7 @@ def pass_time(tensorboard_logfile: PurePath) -> float:
# {"dataset.name": [(time, "algo"), ...], ...}
timings: dict[str, list[tuple[float, str]]] = defaultdict(list)

# LOGDIR / "team" / "algo" / "dataset" / "events.out.tfevents.*"
for team in LOGDIR.glob("*/"):
if team.name == '0_THRESHOLDS':
continue
Expand All @@ -75,11 +79,15 @@ def pass_time(tensorboard_logfile: PurePath) -> float:
pass_time(logfile) for logfile in dataset.glob("events.out.tfevents.*") if valid(logfile)])
timings[dataset.name].append((t, f"{team.name}/{algo.name}"))

algos = {algo for time_algos in timings.values() for _, algo in time_algos}
for time_algos in timings.values():
missing = algos - {algo for _, algo in time_algos}
time_algos.extend((np.inf, algo) for algo in missing)
# insert `time=np.inf` for each team's missing algos
algos = {algo_name for time_algos in timings.values() for _, algo_name in time_algos}
for dataset_name, time_algos in timings.items():
missing = algos - {algo_name for _, algo_name in time_algos}
for algo_name in missing:
log.error("FileNotFoundError: logfile for %s/%s", algo_name, dataset_name)
time_algos.extend((np.inf, algo_name) for algo_name in missing)

# calculate ranks
ranks: dict[str, int] = defaultdict(int)
for dataset_name, time_algos in timings.items():
time_algos.sort()
Expand Down

0 comments on commit 7af9e12

Please sign in to comment.