Skip to content

Commit

Permalink
fix types
Browse files Browse the repository at this point in the history
  • Loading branch information
amitkparekh committed Dec 5, 2023
1 parent 7470524 commit f99eb1c
Showing 1 changed file with 3 additions and 23 deletions.
26 changes: 3 additions & 23 deletions src/simbot_offline_inference/metrics/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@

import torch
import wandb
import yaml
from loguru import logger

from arena_missions.structures import CDF, MissionTrajectory
from emma_experience_hub._version import __version__ as experience_hub_version # noqa: WPS436
from emma_experience_hub.constants import constants_absolute_path
from emma_experience_hub.datamodels.registry import ServiceRegistry
from simbot_offline_inference._version import ( # noqa: WPS436
__version__ as offline_inference_version,
)
Expand Down Expand Up @@ -78,20 +76,6 @@ def finish_trajectory(
"""Finish running a trajectory."""
raise NotImplementedError

def extract_service_versions_from_registry(self) -> dict[str, str]:
"""Get service and model versions from the service registry."""
service_registry = ServiceRegistry.parse_obj(
yaml.safe_load(SERVICE_REGISTRY_PATH.read_bytes())
)

output_dict = {}

for service in service_registry.services:
output_dict[f"version/{service.name}"] = service.image_version
output_dict[f"model/{service.name}"] = service.model_url

return output_dict


class WandBTrajectoryGenerationCallback(WandBCallback):
"""Track each trajectory as a new run in WandB."""
Expand Down Expand Up @@ -129,7 +113,6 @@ def start_trajectory(self, trajectory: MissionTrajectory, preparation_session_id
config={
"version/experience_hub": experience_hub_version,
"version/offline_inference": offline_inference_version,
**self.extract_service_versions_from_registry(),
"session_id": trajectory.session_id,
"preparation_session_id": preparation_session_id,
# CDF
Expand Down Expand Up @@ -163,12 +146,12 @@ def start_trajectory(self, trajectory: MissionTrajectory, preparation_session_id

# Upload the trajectory results on run completion
# According to wandb docs, this command is correct
wandb.save( # type: ignore[call-arg]
wandb.save(
str(self.mission_trajectory_outputs_dir.joinpath(f"{trajectory.session_id}.json")),
policy="end",
)
# Also upload the unity logs
wandb.save(str(self._unity_logs), policy="end") # type: ignore[call-arg]
wandb.save(str(self._unity_logs), policy="end")

def finish_trajectory(
self,
Expand Down Expand Up @@ -216,14 +199,11 @@ def start_evaluation(self, *, resume: bool = False) -> None:
config={
"version/experience_hub": experience_hub_version,
"version/offline_inference": offline_inference_version,
**self.extract_service_versions_from_registry(),
},
)

# Upload the trajectory results on run completion
wandb.save( # type: ignore[call-arg]
str(self.mission_trajectory_outputs_dir), policy="end"
)
wandb.save(str(self.mission_trajectory_outputs_dir), policy="end")

# Also upload the unity logs
wandb.save(str(self._unity_logs))
Expand Down

0 comments on commit f99eb1c

Please sign in to comment.