diff --git a/diambra/arena/sheeprl/make_sheeprl_env.py b/diambra/arena/sheeprl/make_sheeprl_env.py index e7da196..d567f9e 100644 --- a/diambra/arena/sheeprl/make_sheeprl_env.py +++ b/diambra/arena/sheeprl/make_sheeprl_env.py @@ -61,18 +61,36 @@ def thunk() -> gym.Env: instantiate_kwargs["rank"] = rank + vector_env_idx env = hydra.utils.instantiate(cfg.env.wrapper, **instantiate_kwargs) + if not ( + isinstance(cfg.algo.mlp_keys.encoder, list) + and isinstance(cfg.algo.cnn_keys.encoder, list) + and len(cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder) > 0 + ): + raise ValueError( + "`algo.cnn_keys.encoder` and `algo.mlp_keys.encoder` must be lists of strings, got: " + f"cnn encoder keys `{cfg.algo.cnn_keys.encoder}` of type `{type(cfg.algo.cnn_keys.encoder)}` " + f"and mlp encoder keys `{cfg.algo.mlp_keys.encoder}` of type `{type(cfg.algo.mlp_keys.encoder)}`. " + "Both must be non-empty lists." + ) + + if ( + len( + set(k for k in env.observation_space.keys()).intersection( + set(cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder) + ) + ) + == 0 + ): + raise ValueError( + f"The user specified keys `{cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder}` " + "are not a subset of the " + f"environment `{env.observation_space.keys()}` observation keys. Please check your config file." + ) + env_cnn_keys = set( - [ - k - for k in env.observation_space.spaces.keys() - if len(env.observation_space[k].shape) in {2, 3} - ] + [k for k in env.observation_space.spaces.keys() if len(env.observation_space[k].shape) in {2, 3}] ) - if cfg.cnn_keys.encoder is None: - user_cnn_keys = set() - else: - user_cnn_keys = set(cfg.cnn_keys.encoder) - cnn_keys = env_cnn_keys.intersection(user_cnn_keys) + cnn_keys = env_cnn_keys.intersection(set(cfg.algo.cnn_keys.encoder)) def transform_obs(obs: Dict[str, Any]): for k in cnn_keys: @@ -93,9 +111,7 @@ def transform_obs(obs: Dict[str, Any]): # resize if current_obs.shape[:-1] != (cfg.env.screen_size, cfg.env.screen_size): current_obs = cv2.resize( - current_obs, - (cfg.env.screen_size, cfg.env.screen_size), - interpolation=cv2.INTER_AREA, + current_obs, (cfg.env.screen_size, cfg.env.screen_size), interpolation=cv2.INTER_AREA ) # to grayscale @@ -116,14 +132,7 @@ def transform_obs(obs: Dict[str, Any]): env = gym.wrappers.TransformObservation(env, transform_obs) for k in cnn_keys: env.observation_space[k] = gym.spaces.Box( - 0, - 255, - ( - 1 if cfg.env.grayscale else 3, - cfg.env.screen_size, - cfg.env.screen_size, - ), - np.uint8, + 0, 255, (1 if cfg.env.grayscale else 3, cfg.env.screen_size, cfg.env.screen_size), np.uint8 ) if cnn_keys is not None and len(cnn_keys) > 0 and cfg.env.frame_stack > 1: @@ -131,9 +140,7 @@ def transform_obs(obs: Dict[str, Any]): raise ValueError( f"The frame stack dilation argument must be greater than zero, got: {cfg.env.frame_stack_dilation}" ) - env = FrameStack( - env, cfg.env.frame_stack, cnn_keys, cfg.env.frame_stack_dilation - ) + env = FrameStack(env, cfg.env.frame_stack, cnn_keys, cfg.env.frame_stack_dilation) if cfg.env.reward_as_observation: env = RewardAsObservationWrapper(env) @@ -141,24 +148,15 @@ def transform_obs(obs: Dict[str, Any]): env.action_space.seed(seed) env.observation_space.seed(seed) if cfg.env.max_episode_steps and cfg.env.max_episode_steps > 0: - env = gym.wrappers.TimeLimit( - env, max_episode_steps=cfg.env.max_episode_steps - ) + env = gym.wrappers.TimeLimit(env, max_episode_steps=cfg.env.max_episode_steps) env = gym.wrappers.RecordEpisodeStatistics(env) - if ( - cfg.env.capture_video - and rank == 0 - and vector_env_idx == 0 - and run_name is not None - ): + if cfg.env.capture_video and rank == 0 and vector_env_idx == 0 and run_name is not None: if cfg.env.grayscale: env = GrayscaleRenderWrapper(env) env = gym.experimental.wrappers.RecordVideoV0( - env, - os.path.join(run_name, prefix + "_videos" if prefix else "videos"), - disable_logger=True, + env, os.path.join(run_name, prefix + "_videos" if prefix else "videos"), disable_logger=True ) env.metadata["render_fps"] = env.frames_per_sec return env - return thunk + return thunk \ No newline at end of file diff --git a/setup.py b/setup.py index ed50d8e..8dc8107 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ "stable-baselines3": ["stable-baselines3[extra]~=2.1.0", "pyyaml"], "ray-rllib": ["ray[rllib]~=2.7.0", "tensorflow", "torch", "pyyaml"], "sheeprl": [ - "sheeprl==0.4.7", + "sheeprl==0.4.8", "importlib-resources==6.1.0", ], }