diff --git a/metaworld/__init__.py b/metaworld/__init__.py
index d95bc5ef3..5e3a499cb 100644
--- a/metaworld/__init__.py
+++ b/metaworld/__init__.py
@@ -401,7 +401,7 @@ def make_mt_envs(
             max_episode_steps=max_episode_steps,
             use_one_hot=use_one_hot,
             env_id=env_id,
-            num_tasks=num_tasks,
+            num_tasks=num_tasks or 1,
             terminate_on_success=terminate_on_success,
         )
     elif name == "MT10" or name == "MT50":
@@ -409,6 +409,7 @@
         vectorizer: type[gym.vector.VectorEnv] = getattr(
             gym.vector, f"{vector_strategy.capitalize()}VectorEnv"
         )
+        default_num_tasks = 10 if name == "MT10" else 50
         return vectorizer(  # type: ignore
             [
                 partial(
@@ -421,7 +422,7 @@
                     max_episode_steps=max_episode_steps,
                     use_one_hot=use_one_hot,
                     env_id=env_id,
-                    num_tasks=num_tasks,
+                    num_tasks=num_tasks or default_num_tasks,
                     terminate_on_success=terminate_on_success,
                     task_select=task_select,
                 )
@@ -457,17 +458,16 @@
     tasks_per_env = meta_batch_size // len(all_classes)

     env_tuples = []
-    # TODO figure out how to expose task names for eval
-    # task_names = []
     for env_name, env_cls in all_classes.items():
         tasks = [task for task in all_tasks if task.env_name == env_name]
         if total_tasks_per_cls is not None:
             tasks = tasks[:total_tasks_per_cls]
         subenv_tasks = [tasks[i::tasks_per_env] for i in range(0, tasks_per_env)]
         for tasks_for_subenv in subenv_tasks:
-            assert len(tasks_for_subenv) == len(tasks) // tasks_per_env
+            assert (
+                len(tasks_for_subenv) == len(tasks) // tasks_per_env
+            ), f"Invalid division of subtasks, expected {len(tasks) // tasks_per_env} got {len(tasks_for_subenv)}"
             env_tuples.append((env_cls, tasks_for_subenv))
-            # task_names.append(env_name)

     vectorizer: type[gym.vector.VectorEnv] = getattr(
         gym.vector, f"{vector_strategy.capitalize()}VectorEnv"
diff --git a/metaworld/evaluation.py b/metaworld/evaluation.py
index fcb3392d2..21b97270a 100644
--- a/metaworld/evaluation.py
+++ b/metaworld/evaluation.py
@@ -10,7 +10,7 @@


 class Agent(Protocol):
-    def get_action_eval(
+    def eval_action(
         self, obs: npt.NDArray[np.float64]
     ) -> tuple[npt.NDArray[np.float64], dict[str, npt.NDArray]]:
         ...
@@ -51,7 +51,7 @@ def eval_done(returns):
         return all(len(r) >= num_episodes for _, r in returns.items())

     while not eval_done(episodic_returns):
-        actions, _ = agent.get_action_eval(obs)
+        actions, _ = agent.eval_action(obs)
         obs, _, terminations, truncations, infos = eval_envs.step(actions)
         for i, env_ended in enumerate(np.logical_or(terminations, truncations)):
             if env_ended:
@@ -108,7 +108,7 @@ def metalearning_evaluation(

     for _ in range(adaptation_steps):
         while not eval_buffer.ready:
-            actions, aux_policy_outs = agent.get_action_eval(obs)
+            actions, aux_policy_outs = agent.eval_action(obs)
             next_obs: npt.NDArray[np.float64]
             rewards: npt.NDArray[np.float64]
             next_obs, rewards, terminations, truncations, _ = eval_envs.step(
@@ -157,14 +157,14 @@ class Rollout(NamedTuple):
     rewards: npt.NDArray
     dones: npt.NDArray

-    # Auxilary polcy outputs
+    # Auxiliary policy outputs
     log_probs: npt.NDArray | None = None
     means: npt.NDArray | None = None
     stds: npt.NDArray | None = None


 class _MultiTaskRolloutBuffer:
-    """A buffer to accumulate rollouts for multple tasks.
+    """A buffer to accumulate rollouts for multiple tasks.
     Useful for ML1, ML10, ML45, or on-policy MTRL algorithms.

     In Metaworld, all episodes are as long as the time limit (typically 500), thus in this buffer we assume
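Two notes on the hunks above. First, the `num_tasks or 1` / `num_tasks or default_num_tasks` fallback treats every falsy value alike, so an explicit `num_tasks=0` falls back to the default exactly as `None` does. Second, to make the `get_action_eval` to `eval_action` rename concrete, here is a minimal sketch of an object satisfying the renamed `Agent` protocol; only the method name and signature come from this patch, while `RandomAgent` and its uniform-random body are illustrative stand-ins for a trained policy.

```python
from __future__ import annotations

import numpy as np
import numpy.typing as npt


class RandomAgent:
    """Illustrative stand-in for a trained policy."""

    def __init__(self, num_envs: int, action_dim: int = 4) -> None:
        self.num_envs = num_envs
        self.action_dim = action_dim  # Sawyer envs use 4-dim actions

    def eval_action(
        self, obs: npt.NDArray[np.float64]
    ) -> tuple[npt.NDArray[np.float64], dict[str, npt.NDArray]]:
        # One action per vectorized sub-env; the dict carries auxiliary
        # policy outputs (log-probs, means, stds) and may be empty.
        actions = np.random.uniform(-1.0, 1.0, size=(self.num_envs, self.action_dim))
        return actions, {}
```

Because `Agent` is a `typing.Protocol`, conformance is structural: any object exposing a matching `eval_action` qualifies, which is why the rename only had to touch the protocol and its call sites.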
diff --git a/metaworld/policies/__init__.py b/metaworld/policies/__init__.py
index adf09a36f..37a2a1c6c 100644
--- a/metaworld/policies/__init__.py
+++ b/metaworld/policies/__init__.py
@@ -75,56 +75,56 @@
 ENV_POLICY_MAP = dict(
     {
-        "assembly-V3": SawyerAssemblyV3Policy,
-        "basketball-V3": SawyerBasketballV3Policy,
-        "bin-picking-V3": SawyerBinPickingV3Policy,
-        "box-close-V3": SawyerBoxCloseV3Policy,
-        "button-press-topdown-V3": SawyerButtonPressTopdownV3Policy,
-        "button-press-topdown-wall-V3": SawyerButtonPressTopdownWallV3Policy,
-        "button-press-V3": SawyerButtonPressV3Policy,
-        "button-press-wall-V3": SawyerButtonPressWallV3Policy,
-        "coffee-button-V3": SawyerCoffeeButtonV3Policy,
-        "coffee-pull-V3": SawyerCoffeePullV3Policy,
-        "coffee-push-V3": SawyerCoffeePushV3Policy,
-        "dial-turn-V3": SawyerDialTurnV3Policy,
-        "disassemble-V3": SawyerDisassembleV3Policy,
-        "door-close-V3": SawyerDoorCloseV3Policy,
-        "door-lock-V3": SawyerDoorLockV3Policy,
-        "door-open-V3": SawyerDoorOpenV3Policy,
-        "door-unlock-V3": SawyerDoorUnlockV3Policy,
-        "drawer-close-V3": SawyerDrawerCloseV3Policy,
-        "drawer-open-V3": SawyerDrawerOpenV3Policy,
-        "faucet-close-V3": SawyerFaucetCloseV3Policy,
-        "faucet-open-V3": SawyerFaucetOpenV3Policy,
-        "hammer-V3": SawyerHammerV3Policy,
-        "hand-insert-V3": SawyerHandInsertV3Policy,
-        "handle-press-side-V3": SawyerHandlePressSideV3Policy,
-        "handle-press-V3": SawyerHandlePressV3Policy,
-        "handle-pull-V3": SawyerHandlePullV3Policy,
-        "handle-pull-side-V3": SawyerHandlePullSideV3Policy,
-        "peg-insert-side-V3": SawyerPegInsertionSideV3Policy,
-        "lever-pull-V3": SawyerLeverPullV3Policy,
-        "peg-unplug-side-V3": SawyerPegUnplugSideV3Policy,
-        "pick-out-of-hole-V3": SawyerPickOutOfHoleV3Policy,
-        "pick-place-V3": SawyerPickPlaceV3Policy,
-        "pick-place-wall-V3": SawyerPickPlaceWallV3Policy,
-        "plate-slide-back-side-V3": SawyerPlateSlideBackSideV3Policy,
-        "plate-slide-back-V3": SawyerPlateSlideBackV3Policy,
-        "plate-slide-side-V3": SawyerPlateSlideSideV3Policy,
-        "plate-slide-V3": SawyerPlateSlideV3Policy,
-        "reach-V3": SawyerReachV3Policy,
-        "reach-wall-V3": SawyerReachWallV3Policy,
-        "push-back-V3": SawyerPushBackV3Policy,
-        "push-V3": SawyerPushV3Policy,
-        "push-wall-V3": SawyerPushWallV3Policy,
-        "shelf-place-V3": SawyerShelfPlaceV3Policy,
-        "soccer-V3": SawyerSoccerV3Policy,
-        "stick-pull-V3": SawyerStickPullV3Policy,
-        "stick-push-V3": SawyerStickPushV3Policy,
-        "sweep-into-V3": SawyerSweepIntoV3Policy,
-        "sweep-V3": SawyerSweepV3Policy,
-        "window-close-V3": SawyerWindowCloseV3Policy,
-        "window-open-V3": SawyerWindowOpenV3Policy,
+        "assembly-v3": SawyerAssemblyV3Policy,
+        "basketball-v3": SawyerBasketballV3Policy,
+        "bin-picking-v3": SawyerBinPickingV3Policy,
+        "box-close-v3": SawyerBoxCloseV3Policy,
+        "button-press-topdown-v3": SawyerButtonPressTopdownV3Policy,
+        "button-press-topdown-wall-v3": SawyerButtonPressTopdownWallV3Policy,
+        "button-press-v3": SawyerButtonPressV3Policy,
+        "button-press-wall-v3": SawyerButtonPressWallV3Policy,
+        "coffee-button-v3": SawyerCoffeeButtonV3Policy,
+        "coffee-pull-v3": SawyerCoffeePullV3Policy,
+        "coffee-push-v3": SawyerCoffeePushV3Policy,
+        "dial-turn-v3": SawyerDialTurnV3Policy,
+        "disassemble-v3": SawyerDisassembleV3Policy,
+        "door-close-v3": SawyerDoorCloseV3Policy,
+        "door-lock-v3": SawyerDoorLockV3Policy,
+        "door-open-v3": SawyerDoorOpenV3Policy,
+        "door-unlock-v3": SawyerDoorUnlockV3Policy,
+        "drawer-close-v3": SawyerDrawerCloseV3Policy,
+        "drawer-open-v3": SawyerDrawerOpenV3Policy,
+        "faucet-close-v3": SawyerFaucetCloseV3Policy,
+        "faucet-open-v3": SawyerFaucetOpenV3Policy,
+        "hammer-v3": SawyerHammerV3Policy,
+        "hand-insert-v3": SawyerHandInsertV3Policy,
+        "handle-press-side-v3": SawyerHandlePressSideV3Policy,
+        "handle-press-v3": SawyerHandlePressV3Policy,
+        "handle-pull-v3": SawyerHandlePullV3Policy,
+        "handle-pull-side-v3": SawyerHandlePullSideV3Policy,
+        "peg-insert-side-v3": SawyerPegInsertionSideV3Policy,
+        "lever-pull-v3": SawyerLeverPullV3Policy,
+        "peg-unplug-side-v3": SawyerPegUnplugSideV3Policy,
+        "pick-out-of-hole-v3": SawyerPickOutOfHoleV3Policy,
+        "pick-place-v3": SawyerPickPlaceV3Policy,
+        "pick-place-wall-v3": SawyerPickPlaceWallV3Policy,
+        "plate-slide-back-side-v3": SawyerPlateSlideBackSideV3Policy,
+        "plate-slide-back-v3": SawyerPlateSlideBackV3Policy,
+        "plate-slide-side-v3": SawyerPlateSlideSideV3Policy,
+        "plate-slide-v3": SawyerPlateSlideV3Policy,
+        "reach-v3": SawyerReachV3Policy,
+        "reach-wall-v3": SawyerReachWallV3Policy,
+        "push-back-v3": SawyerPushBackV3Policy,
+        "push-v3": SawyerPushV3Policy,
+        "push-wall-v3": SawyerPushWallV3Policy,
+        "shelf-place-v3": SawyerShelfPlaceV3Policy,
+        "soccer-v3": SawyerSoccerV3Policy,
+        "stick-pull-v3": SawyerStickPullV3Policy,
+        "stick-push-v3": SawyerStickPushV3Policy,
+        "sweep-into-v3": SawyerSweepIntoV3Policy,
+        "sweep-v3": SawyerSweepV3Policy,
+        "window-close-v3": SawyerWindowCloseV3Policy,
+        "window-open-v3": SawyerWindowOpenV3Policy,
     }
 )
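The key case change matters because environment names elsewhere in the package (the benchmark class dictionaries and `Task.env_name`) are lowercase, so a name coming from a benchmark previously missed the map. A minimal sketch of the lookup this enables, assuming only that the keys now match those names:

```python
from metaworld.policies import ENV_POLICY_MAP

# An env name as it appears elsewhere in the package now indexes the map
# directly, with no case fix-up:
policy = ENV_POLICY_MAP["reach-v3"]()

# Sanity check: no mixed-case keys remain after this change.
assert all(key == key.lower() for key in ENV_POLICY_MAP)
```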
diff --git a/metaworld/wrappers.py b/metaworld/wrappers.py
index 99cee7545..2128b9914 100644
--- a/metaworld/wrappers.py
+++ b/metaworld/wrappers.py
@@ -1,7 +1,5 @@
 from __future__ import annotations

-from typing import Any
-
 import gymnasium as gym
 import numpy as np
 from gymnasium import Env
@@ -35,7 +33,7 @@
 class RandomTaskSelectWrapper(gym.Wrapper):
     """A Gymnasium Wrapper to automatically set / reset the environment to a random task."""

-    tasks: List[Task]
+    tasks: list[Task]
     sample_tasks_on_reset: bool = True

     def _set_random_task(self):
@@ -45,7 +43,7 @@ def _set_random_task(self):
     def __init__(
         self,
         env: Env,
-        tasks: List[Task],
+        tasks: list[Task],
         sample_tasks_on_reset: bool = True,
     ):
         super().__init__(env)
@@ -55,16 +53,12 @@
     def toggle_sample_tasks_on_reset(self, on: bool):
         self.sample_tasks_on_reset = on

-    def reset(
-        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
-    ):
+    def reset(self, *, seed: int | None = None, options: dict | None = None):
         if self.sample_tasks_on_reset:
             self._set_random_task()
         return self.env.reset(seed=seed, options=options)

-    def sample_tasks(
-        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
-    ):
+    def sample_tasks(self, *, seed: int | None = None, options: dict | None = None):
        self._set_random_task()
         return self.env.reset(seed=seed, options=options)
@@ -102,16 +96,12 @@ def __init__(
         self.tasks = tasks
         self.current_task_idx = -1

-    def reset(
-        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
-    ):
+    def reset(self, *, seed: int | None = None, options: dict | None = None):
         if self.sample_tasks_on_reset:
             self._set_pseudo_random_task()
         return self.env.reset(seed=seed, options=options)

-    def sample_tasks(
-        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
-    ):
+    def sample_tasks(self, *, seed: int | None = None, options: dict | None = None):
         self._set_pseudo_random_task()
         return self.env.reset(seed=seed, options=options)
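Since `metaworld/wrappers.py` already starts with `from __future__ import annotations` (visible in the first hunk), the PEP 585/604 annotations (`list[Task]`, `int | None`, `dict | None`) are evaluated lazily and the `typing` imports become dead weight. A usage sketch of the cleaned-up keyword-only `reset` signature; `env` and `tasks` are placeholders that would normally come from a benchmark:

```python
import gymnasium as gym

from metaworld.wrappers import RandomTaskSelectWrapper


def make_task_sampling_env(env: gym.Env, tasks) -> gym.Env:
    """Wrap `env` so every reset first jumps to a randomly sampled task.

    `tasks` is the list[Task] for this env, normally taken from a benchmark.
    """
    wrapped = RandomTaskSelectWrapper(env, tasks, sample_tasks_on_reset=True)
    # The signature now mirrors gymnasium's Env.reset:
    obs, info = wrapped.reset(seed=0, options=None)
    return wrapped
```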
diff --git a/tests/metaworld/test_evaluation.py b/tests/metaworld/test_evaluation.py
index 52c15d99f..9ef630f45 100644
--- a/tests/metaworld/test_evaluation.py
+++ b/tests/metaworld/test_evaluation.py
@@ -12,7 +12,7 @@
 from metaworld.policies import ENV_POLICY_MAP


-class ScriptedPolicyAgent:
+class ScriptedPolicyAgent(evaluation.MetaLearningAgent):
     def __init__(
         self,
         envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
@@ -25,7 +25,7 @@ def __init__(
         self.max_episode_steps = max_episode_steps
         self.adapt_calls = 0

-    def get_action_eval(
+    def eval_action(
         self, obs: npt.NDArray[np.float64]
     ) -> tuple[npt.NDArray[np.float64], dict[str, npt.NDArray]]:
         actions: list[npt.NDArray[np.float32]] = []
diff --git a/tests/metaworld/test_gym_make.py b/tests/metaworld/test_gym_make.py
index 6fe0e2c9a..33ee92119 100644
--- a/tests/metaworld/test_gym_make.py
+++ b/tests/metaworld/test_gym_make.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import random
 from typing import Literal

@@ -159,7 +161,11 @@ def test_ml_benchmarks(
     vector_strategy: str,
 ):
     meta_batch_size = 20 if benchmark != "ML45" else 45
-    total_tasks_per_cls = _N_GOALS if benchmark != "ML45" else 45
+    total_tasks_per_cls = _N_GOALS
+    if benchmark == "ML45":
+        total_tasks_per_cls = 45
+    elif benchmark == "ML10" and split == "test":
+        total_tasks_per_cls = 40
     max_episode_steps = 10
     envs = gym.make_vec(
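The new `ML10`/`"test"` case appears to be forced by the divisibility check strengthened in `_make_ml_envs_inner` above: the ML10 test split has 5 environment classes, so `tasks_per_env = 20 // 5 = 4`, and the default `_N_GOALS` (50) tasks per class do not stride evenly across 4 sub-envs, while 40 do. A self-contained sketch of that striding logic; the class and task counts are illustrative, not imports from the package:

```python
meta_batch_size = 20
num_classes = 5                                 # ML10 test split class count
tasks_per_env = meta_batch_size // num_classes  # -> 4 sub-envs per class

for total_tasks_per_cls in (50, 40):
    tasks = list(range(total_tasks_per_cls))
    # Same striding split as _make_ml_envs_inner:
    subenv_tasks = [tasks[i::tasks_per_env] for i in range(tasks_per_env)]
    shares = [len(t) for t in subenv_tasks]
    # 50 tasks stride into [13, 13, 12, 12], tripping the assert
    # (13 != 50 // 4); 40 tasks divide evenly into [10, 10, 10, 10].
    print(total_tasks_per_cls, shares)
```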