From f99eb1cab26633d73356852e0b448cdcf61629ca Mon Sep 17 00:00:00 2001 From: Amit Parekh <7276308+amitkparekh@users.noreply.github.com> Date: Tue, 5 Dec 2023 10:16:25 +0800 Subject: [PATCH] fix types --- src/simbot_offline_inference/metrics/wandb.py | 26 +++---------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/src/simbot_offline_inference/metrics/wandb.py b/src/simbot_offline_inference/metrics/wandb.py index 37b0f34..58a1db4 100644 --- a/src/simbot_offline_inference/metrics/wandb.py +++ b/src/simbot_offline_inference/metrics/wandb.py @@ -5,13 +5,11 @@ import torch import wandb -import yaml from loguru import logger from arena_missions.structures import CDF, MissionTrajectory from emma_experience_hub._version import __version__ as experience_hub_version # noqa: WPS436 from emma_experience_hub.constants import constants_absolute_path -from emma_experience_hub.datamodels.registry import ServiceRegistry from simbot_offline_inference._version import ( # noqa: WPS436 __version__ as offline_inference_version, ) @@ -78,20 +76,6 @@ def finish_trajectory( """Finish running a trajectory.""" raise NotImplementedError - def extract_service_versions_from_registry(self) -> dict[str, str]: - """Get service and model versions from the service registry.""" - service_registry = ServiceRegistry.parse_obj( - yaml.safe_load(SERVICE_REGISTRY_PATH.read_bytes()) - ) - - output_dict = {} - - for service in service_registry.services: - output_dict[f"version/{service.name}"] = service.image_version - output_dict[f"model/{service.name}"] = service.model_url - - return output_dict - class WandBTrajectoryGenerationCallback(WandBCallback): """Track each trajectory as a new run in WandB.""" @@ -129,7 +113,6 @@ def start_trajectory(self, trajectory: MissionTrajectory, preparation_session_id config={ "version/experience_hub": experience_hub_version, "version/offline_inference": offline_inference_version, - **self.extract_service_versions_from_registry(), "session_id": trajectory.session_id, "preparation_session_id": preparation_session_id, # CDF @@ -163,12 +146,12 @@ def start_trajectory(self, trajectory: MissionTrajectory, preparation_session_id # Upload the trajectory results on run completion # According to wandb docs, this command is correct - wandb.save( # type: ignore[call-arg] + wandb.save( str(self.mission_trajectory_outputs_dir.joinpath(f"{trajectory.session_id}.json")), policy="end", ) # Also upload the unity logs - wandb.save(str(self._unity_logs), policy="end") # type: ignore[call-arg] + wandb.save(str(self._unity_logs), policy="end") def finish_trajectory( self, @@ -216,14 +199,11 @@ def start_evaluation(self, *, resume: bool = False) -> None: config={ "version/experience_hub": experience_hub_version, "version/offline_inference": offline_inference_version, - **self.extract_service_versions_from_registry(), }, ) # Upload the trajectory results on run completion - wandb.save( # type: ignore[call-arg] - str(self.mission_trajectory_outputs_dir), policy="end" - ) + wandb.save(str(self.mission_trajectory_outputs_dir), policy="end") # Also upload the unity logs wandb.save(str(self._unity_logs))