From c71d09e4dadb5fc8576ba1dddbee38c2258250f3 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Thu, 10 Oct 2024 10:45:48 +0000 Subject: [PATCH] add some comments & debugging --- runner/eval_thresholds.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/runner/eval_thresholds.py b/runner/eval_thresholds.py index ee01e92..449b673 100644 --- a/runner/eval_thresholds.py +++ b/runner/eval_thresholds.py @@ -14,33 +14,36 @@ assert set(QualityMetrics.THRESHOLD.keys()) == TAGS -def scalars(ea: EventAccumulator, tag: str): +def scalars(ea: EventAccumulator, tag: str) -> list[tuple[float, float]]: + """[(value, time), ...]""" steps = [s.step for s in ea.Scalars(tag)] assert steps == sorted(steps) return [(scalar.value, scalar.wall_time) for scalar in ea.Scalars(tag)] def valid(tensorboard_logfile: PurePath) -> bool: + """False if invalid/empty logfile""" ea = EventAccumulator(str(tensorboard_logfile), size_guidance={SCALARS: 0}) ea.Reload() return len({"RMSE_whole_object", "RMSE_background"}.intersection(ea.Tags()['scalars'])) == 2 def pass_time(tensorboard_logfile: PurePath) -> float: + """time at which thresholds were met (minus any metrics calculation time offset)""" ea = EventAccumulator(str(tensorboard_logfile), size_guidance={SCALARS: 0}) ea.Reload() try: start = ea.Scalars("reset")[0] except KeyError: - log.error("KeyError: reset: not using accurate relative time") + log.error("KeyError: reset: not using accurate relative time for %s", tensorboard_logfile.relative_to(LOGDIR)) start = None else: assert start.value == 0 assert start.step == -1 start = start.wall_time - tag_names = {tag for tag in ea.Tags()['scalars'] if any(tag.startswith(i) for i in TAGS)} + tag_names: set[str] = {tag for tag in ea.Tags()['scalars'] if any(tag.startswith(i) for i in TAGS)} tags = {tag: scalars(ea, tag) for tag in tag_names} metrics = [tags.pop("RMSE_whole_object"), tags.pop("RMSE_background")] @@ -62,6 +65,7 @@ def pass_time(tensorboard_logfile: PurePath) -> float: # {"dataset.name": [(time, "algo"), ...], ...} timings: dict[str, list[tuple[float, str]]] = defaultdict(list) + # LOGDIR / "team" / "algo" / "dataset" / "events.out.tfevents.*" for team in LOGDIR.glob("*/"): if team.name == '0_THRESHOLDS': continue @@ -75,11 +79,15 @@ def pass_time(tensorboard_logfile: PurePath) -> float: pass_time(logfile) for logfile in dataset.glob("events.out.tfevents.*") if valid(logfile)]) timings[dataset.name].append((t, f"{team.name}/{algo.name}")) - algos = {algo for time_algos in timings.values() for _, algo in time_algos} - for time_algos in timings.values(): - missing = algos - {algo for _, algo in time_algos} - time_algos.extend((np.inf, algo) for algo in missing) + # insert `time=np.inf` for each team's missing algos + algos = {algo_name for time_algos in timings.values() for _, algo_name in time_algos} + for dataset_name, time_algos in timings.items(): + missing = algos - {algo_name for _, algo_name in time_algos} + for algo_name in missing: + log.error("FileNotFoundError: logfile for %s/%s", algo_name, dataset_name) + time_algos.extend((np.inf, algo_name) for algo_name in missing) + # calculate ranks ranks: dict[str, int] = defaultdict(int) for dataset_name, time_algos in timings.items(): time_algos.sort()