Skip to content

Commit

Permalink
MAJOR: Added render() to RLToyEnv; increase compatibility with Gymnasium v1.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
RaghuSpaceRajan committed Dec 20, 2024
1 parent 071dffe commit 445da74
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 63 deletions.
4 changes: 2 additions & 2 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
from mdp_playground.envs import RLToyEnv
import numpy as np

display_images = True

def display_image(obs, mode="RGB"):
# Display the image observation associated with the next state
from PIL import Image
Expand Down Expand Up @@ -411,6 +409,8 @@ def atari_wrapper_example():

from mdp_playground.envs import GymEnvWrapper
import gymnasium as gym
import ale_py
gym.register_envs(ale_py) # optional, helpful for IDEs or pre-commit

ae = gym.make("QbertNoFrameskip-v4")
env = GymEnvWrapper(ae, **config)
Expand Down
5 changes: 5 additions & 0 deletions mdp_playground/envs/gym_env_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from PIL.Image import FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM
import logging

# Needed from Gymnasium v1.0.0 onwards
import ale_py
gym.register_envs(ale_py) # optional, helpful for IDEs or pre-commit


# def get_gym_wrapper(base_class):


Expand Down
196 changes: 145 additions & 51 deletions mdp_playground/envs/rl_toy_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@


class RLToyEnv(gym.Env):
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

"""
The base toy environment in MDP Playground. It is parameterised by a config dict and can be instantiated to be an MDP with any of the possible dimensions from the accompanying research paper. The class extends OpenAI Gym's environment gym.Env.
Expand Down Expand Up @@ -428,61 +430,65 @@ def __init__(self, **config):
self.image_representations = False
else:
self.image_representations = config["image_representations"]
if "image_transforms" in config:
assert config["state_space_type"] == "discrete", (
"Image " "transforms are only applicable to discrete envs."
)
self.image_transforms = config["image_transforms"]
else:
self.image_transforms = "none"

if "image_width" in config:
self.image_width = config["image_width"]
else:
self.image_width = 100
# Moved these out of the image_representations block when adding render()
# because they are needed for the render() method even if image_representations
# is False.
if "image_transforms" in config:
assert config["state_space_type"] == "discrete", (
"Image " "transforms are only applicable to discrete envs."
)
self.image_transforms = config["image_transforms"]
else:
self.image_transforms = "none"

if "image_height" in config:
self.image_height = config["image_height"]
else:
self.image_height = 100
if "image_width" in config:
self.image_width = config["image_width"]
else:
self.image_width = 100

# The following transforms are only applicable in discrete envs:
if config["state_space_type"] == "discrete":
if "image_sh_quant" not in config:
if "shift" in self.image_transforms:
warnings.warn(
"Setting image shift quantisation to the \
default of 1, since no config value was provided for it."
)
self.image_sh_quant = 1
else:
self.image_sh_quant = None
if "image_height" in config:
self.image_height = config["image_height"]
else:
self.image_height = 100

# The following transforms are only applicable in discrete envs:
if config["state_space_type"] == "discrete":
if "image_sh_quant" not in config:
if "shift" in self.image_transforms:
warnings.warn(
"Setting image shift quantisation to the \
default of 1, since no config value was provided for it."
)
self.image_sh_quant = 1
else:
self.image_sh_quant = config["image_sh_quant"]
self.image_sh_quant = None
else:
self.image_sh_quant = config["image_sh_quant"]

if "image_ro_quant" not in config:
if "rotate" in self.image_transforms:
warnings.warn(
"Setting image rotate quantisation to the \
default of 1, since no config value was provided for it."
)
self.image_ro_quant = 1
else:
self.image_ro_quant = None
if "image_ro_quant" not in config:
if "rotate" in self.image_transforms:
warnings.warn(
"Setting image rotate quantisation to the \
default of 1, since no config value was provided for it."
)
self.image_ro_quant = 1
else:
self.image_ro_quant = config["image_ro_quant"]
self.image_ro_quant = None
else:
self.image_ro_quant = config["image_ro_quant"]

if "image_scale_range" not in config:
if "scale" in self.image_transforms:
warnings.warn(
"Setting image scale range to the default \
of (0.5, 1.5), since no config value was provided for it."
)
self.image_scale_range = (0.5, 1.5)
else:
self.image_scale_range = None
if "image_scale_range" not in config:
if "scale" in self.image_transforms:
warnings.warn(
"Setting image scale range to the default \
of (0.5, 1.5), since no config value was provided for it."
)
self.image_scale_range = (0.5, 1.5)
else:
self.image_scale_range = config["image_scale_range"]
self.image_scale_range = None
else:
self.image_scale_range = config["image_scale_range"]

# Defaults for the individual environment types:
if config["state_space_type"] == "discrete":
Expand Down Expand Up @@ -827,6 +833,15 @@ def __init__(self, **config):
+ ", "
+ str(len(self.augmented_state))
)

# Needed for rendering with pygame for use with Gymnasium.Env's render() method:
render_mode = config.get("render_mode", None)
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode

self.window = None
self.clock = None

self.logger.debug(
"MDP Playground toy env instantiated with config: " + str(self.config)
)
Expand Down Expand Up @@ -1639,7 +1654,8 @@ def transition_function(self, state, action):
/ factorial_array[j]
)
# print('self.state_derivatives:', self.state_derivatives)
next_state = self.state_derivatives[0]
# copy to avoid modifying the original state which may be used by external code, e.g. to print the state
next_state = self.state_derivatives[0].copy()

else: # if action is from outside allowed action_space
next_state = state
Expand Down Expand Up @@ -1684,7 +1700,8 @@ def transition_function(self, state, action):
self.state_derivatives = [
zero_state.copy() for i in range(self.dynamics_order + 1)
]
self.state_derivatives[0] = next_state
# copy to avoid modifying the original state which may be used by external code, e.g. to print the state
self.state_derivatives[0] = next_state.copy()

if self.config["reward_function"] == "move_to_a_point":
next_state_rel = np.array(next_state, dtype=self.dtype_s)[
Expand Down Expand Up @@ -2126,7 +2143,7 @@ def get_augmented_state(self):

return augmented_state_dict

def reset(self, seed=None):
def reset(self, seed=None, options=None):
"""Resets the environment for the beginning of an episode and samples a start state from rho_0. For discrete environments uses the defined rho_0 directly. For continuous environments, samples a state and resamples until a non-terminal state is sampled.
Returns
Expand Down Expand Up @@ -2225,7 +2242,8 @@ def reset(self, seed=None):
zero_state.copy() for i in range(self.dynamics_order + 1)
] # #####IMP to have copy()
# otherwise it's the same array (in memory) at every position in the list
self.state_derivatives[0] = self.curr_state
# copy to avoid modifying the original state which may be used by external code, e.g. to print the state
self.state_derivatives[0] = self.curr_state.copy()

self.augmented_state = [
[np.nan] * self.state_space_dim
Expand Down Expand Up @@ -2316,6 +2334,82 @@ def seed(self, seed=None):
)
return self.seed_

def render(self):
    """
    Render the current state of the environment.

    Based on https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/

    The behaviour depends on ``self.render_mode``:

    * ``"human"``: draws the image for ``self.curr_state`` in a pygame window of
      size ``(self.image_width, self.image_height)`` and throttles to
      ``self.metadata["render_fps"]``. Returns None.
    * ``"rgb_array"``: returns the image produced by
      ``self.render_space.get_concatenated_image(self.curr_state)``.
    * any other value (e.g. None): does nothing and returns None.

    The image space used for rendering (``self.render_space``) is created on
    the first call that actually renders and is reused afterwards.
    """
    if self.render_mode not in ("human", "rgb_array"):
        # No render mode was requested, so there is nothing to draw or return.
        return None

    # Build the render space once. For envs that are not based on
    # image_representations it makes sense to instantiate it here and not in
    # __init__ because it is only needed if render() is called.
    # The guard is on render_space itself (and not on self.window) because
    # self.window is only ever set in "human" mode; guarding on it would
    # re-create (and re-seed) the space on every "rgb_array" call.
    if getattr(self, "render_space", None) is None:
        if self.image_representations:
            # Observations are already images: reuse the observation space.
            self.render_space = self.observation_space
        else:
            state_space_type = self.config["state_space_type"]
            if state_space_type == "discrete":
                self.render_space = ImageMultiDiscrete(
                    self.state_space_size,
                    width=self.image_width,
                    height=self.image_height,
                    transforms=self.image_transforms,
                    sh_quant=self.image_sh_quant,
                    scale_range=self.image_scale_range,
                    ro_quant=self.image_ro_quant,
                    circle_radius=20,
                    seed=self.seed_dict["image_representations"],
                )  # #seed
            elif state_space_type == "continuous":
                self.render_space = ImageContinuous(
                    self.feature_space,
                    width=self.image_width,
                    height=self.image_height,
                    term_spaces=self.term_spaces,
                    target_point=self.target_point,
                    circle_radius=5,
                    seed=self.seed_dict["image_representations"],
                )  # #seed
            elif state_space_type == "grid":
                target_pt = list_to_float_np_array(self.target_point)
                self.render_space = ImageContinuous(
                    self.feature_space,
                    width=self.image_width,
                    height=self.image_height,
                    term_spaces=self.term_spaces,
                    target_point=target_pt,
                    circle_radius=5,
                    grid_shape=self.grid_shape,
                    seed=self.seed_dict["image_representations"],
                )  # #seed

    # ##TODO There are repeated calculations here in calling get_concatenated_image
    # that can be taken from storing variables in step() or reset().
    if self.render_mode == "rgb_array":
        return self.render_space.get_concatenated_image(self.curr_state)

    # "human" mode from here on. pygame is imported lazily so that
    # "rgb_array" rendering does not require pygame at all.
    import pygame

    if self.window is None:
        pygame.init()
        pygame.display.init()
        self.window = pygame.display.set_mode(
            (self.image_width, self.image_height)
        )
    if self.clock is None:
        self.clock = pygame.time.Clock()

    rgb_array = self.render_space.get_concatenated_image(self.curr_state)
    # NOTE(review): pygame.surfarray.make_surface treats the first array axis
    # as x (width). Confirm the axis order returned by
    # get_concatenated_image(); a (height, width, channels) image would be
    # shown transposed.
    pygame_surface = pygame.surfarray.make_surface(rgb_array)
    self.window.blit(pygame_surface, pygame_surface.get_rect())
    pygame.event.pump()
    pygame.display.update()

    # We need to ensure that human-rendering occurs at the predefined framerate.
    # The following line will automatically add a delay to keep the framerate stable.
    self.clock.tick(self.metadata["render_fps"])
    return None

def dist_of_pt_from_line(pt, ptA, ptB):
"""Returns shortest distance of a point from a line defined by 2 points - ptA and ptB.
Expand Down
4 changes: 3 additions & 1 deletion mdp_playground/spaces/image_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ def get_concatenated_image(self, obs):
# image to have >=3 dims

def convert_to_pixel(self, position):
""" """
"""
Convert a continuous position to a pixel position in the image
"""
# It's implicit that both relevant and irrelevant sub-spaces have the
# same max and min here:
max = self.feature_space.high[self.relevant_indices]
Expand Down
18 changes: 9 additions & 9 deletions tests/test_gym_env_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_r_delay(self):

ae = gym.make("BeamRiderNoFrameskip-v4")
aew = GymEnvWrapper(ae, **config)
ob = aew.reset()
ob, _ = aew.reset()
print("observation_space.shape:", ob.shape)
# print(ob)
total_reward = 0.0
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_r_shift(self):

ae = gym.make("BeamRiderNoFrameskip-v4")
aew = GymEnvWrapper(ae, **config)
ob = aew.reset()
ob, _ = aew.reset()
print("observation_space.shape:", ob.shape)
# print(ob)
total_reward = 0.0
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_r_scale(self):

ae = gym.make("BeamRiderNoFrameskip-v4")
aew = GymEnvWrapper(ae, **config)
ob = aew.reset()
ob, _ = aew.reset()
print("observation_space.shape:", ob.shape)
# print(ob)
total_reward = 0.0
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_r_scale(self):

# ae = gym.make("BeamRiderNoFrameskip-v4")
# aew = GymEnvWrapper(ae, **config)
# ob = aew.reset()
# ob, _ = aew.reset()
# print("observation_space.shape:", ob.shape)
# # print(ob)
# total_reward = 0.0
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_r_scale(self):
# game = "".join([g.capitalize() for g in game.split("_")])
# ae = gym.make("{}NoFrameskip-v4".format(game))
# aew = GymEnvWrapper(ae, **config)
# ob = aew.reset()
# ob, _ = aew.reset()
# print("observation_space.shape:", ob.shape)
# # print(ob)
# total_reward = 0.0
Expand Down Expand Up @@ -253,7 +253,7 @@ def test_r_delay_p_noise_r_noise(self):

ae = gym.make("BeamRiderNoFrameskip-v4")
aew = GymEnvWrapper(ae, **config)
ob = aew.reset()
ob, _ = aew.reset()
print("observation_space.shape:", ob.shape)
# print(ob)
total_reward = 0.0
Expand Down Expand Up @@ -316,7 +316,7 @@ def test_discrete_irr_features(self):

ae = gym.make("BeamRiderNoFrameskip-v4")
aew = GymEnvWrapper(ae, **config)
ob = aew.reset()
ob, _ = aew.reset()
print("type(observation_space):", type(ob))
# print(ob)
total_reward = 0.0
Expand Down Expand Up @@ -364,7 +364,7 @@ def test_image_transforms(self):

ae = gym.make("BeamRiderNoFrameskip-v4")
aew = GymEnvWrapper(ae, **config)
ob = aew.reset()
ob, _ = aew.reset()
print("observation_space.shape:", ob.shape)
assert ob.shape == (100, 100, 3), "Observation shape of the env was unexpected."
# print(ob)
Expand Down Expand Up @@ -420,7 +420,7 @@ def test_cont_irr_features(self):
# register_env("HalfCheetahWrapper-v3", lambda config: HalfCheetahWrapperV3(**config))

hc3w = GymEnvWrapper(hc3, **config)
ob = hc3w.reset()
ob, _ = hc3w.reset()
print("obs shape, type(observation_space):", ob.shape, type(ob))
print("initial obs: ", ob)
assert (
Expand Down

0 comments on commit 445da74

Please sign in to comment.