feature(zjow): add envpool new pipeline #753

Open
wants to merge 3 commits into base: main
8 changes: 5 additions & 3 deletions ding/entry/utils.py
@@ -1,4 +1,4 @@
from typing import Optional, Callable, List, Any
from typing import Optional, Callable, List, Any, Dict

from ding.policy import PolicyFactory
from ding.worker import IMetric, MetricSerialEvaluator
@@ -46,7 +46,8 @@ def random_collect(
collector_env: 'BaseEnvManager', # noqa
commander: 'BaseSerialCommander', # noqa
replay_buffer: 'IBuffer', # noqa
postprocess_data_fn: Optional[Callable] = None
postprocess_data_fn: Optional[Callable] = None,
collect_kwargs: Optional[Dict] = None,
) -> None: # noqa
assert policy_cfg.random_collect_size > 0
if policy_cfg.get('transition_with_policy_data', False):
@@ -55,7 +56,8 @@ def random_collect(
action_space = collector_env.action_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
if collect_kwargs is None:
collect_kwargs = commander.step()
if policy_cfg.collect.collector.type == 'episode':
new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
else:
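Note: the new optional `collect_kwargs` argument lets the caller pass collection-time kwargs explicitly instead of always taking them from `commander.step()`. A minimal usage sketch, assuming the usual serial-pipeline objects (`cfg`, `policy`, `collector`, `collector_env`, `commander`, `replay_buffer`) already exist and that the policy reads an `eps` key — both the names and the key are illustrative:

```python
from ding.entry.utils import random_collect

# cfg, policy, collector, collector_env, commander and replay_buffer are
# assumed to come from the usual serial-pipeline setup (illustrative names).
random_collect(
    cfg.policy,
    policy,
    collector,
    collector_env,
    commander,
    replay_buffer,
    collect_kwargs={'eps': 1.0},  # used instead of commander.step() when provided
)
```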
232 changes: 227 additions & 5 deletions ding/envs/env_manager/envpool_env_manager.py
@@ -2,7 +2,11 @@
from easydict import EasyDict
from copy import deepcopy
import numpy as np
import torch
import treetensor.torch as ttorch
import treetensor.numpy as tnp
from collections import namedtuple
import enum
from typing import Any, Union, List, Tuple, Dict, Callable, Optional
from ditk import logging
try:
@@ -13,21 +17,33 @@
envpool = None

from ding.envs import BaseEnvTimestep
from ding.envs.env_manager import BaseEnvManagerV2
from ding.utils import ENV_MANAGER_REGISTRY, deep_merge_dicts
from ding.torch_utils import to_ndarray


@ENV_MANAGER_REGISTRY.register('env_pool')
class EnvState(enum.IntEnum):
VOID = 0
INIT = 1
RUN = 2
RESET = 3
DONE = 4
ERROR = 5
NEED_RESET = 6


@ENV_MANAGER_REGISTRY.register('envpool')
class PoolEnvManager:
'''
"""
Overview:
PoolEnvManager supports the old pipeline of DI-engine.
Envpool currently supports Atari, Classic Control, Toy Text and ViZDoom environments.
Some commonly used env_ids are listed below.
For more examples, you can refer to <https://envpool.readthedocs.io/en/latest/api/atari.html>.

- Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5"
- Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1"
'''
"""

@classmethod
def default_config(cls) -> EasyDict:
@@ -39,10 +55,17 @@ def default_config(cls) -> EasyDict:
# Async mode: batch_size < env_num
env_num=8,
batch_size=8,
image_observation=True,
episodic_life=False,
reward_clip=False,
gray_scale=True,
stack_num=4,
frame_skip=4,
)

def __init__(self, cfg: EasyDict) -> None:
self._cfg = cfg
self._cfg = self.default_config()
self._cfg.update(cfg)
self._env_num = cfg.env_num
self._batch_size = cfg.batch_size
self._ready_obs = {}
@@ -55,6 +78,7 @@ def launch(self) -> None:
seed = 0
else:
seed = self._seed

self._envs = envpool.make(
task_id=self._cfg.env_id,
env_type="gym",
@@ -65,8 +89,10 @@ def launch(self) -> None:
reward_clip=self._cfg.reward_clip,
stack_num=self._cfg.stack_num,
gray_scale=self._cfg.gray_scale,
frame_skip=self._cfg.frame_skip
frame_skip=self._cfg.frame_skip,
)
self._action_space = self._envs.action_space
self._observation_space = self._envs.observation_space
self._closed = False
self.reset()

@@ -77,6 +103,8 @@ def reset(self) -> None:
obs, _, _, info = self._envs.recv()
env_id = info['env_id']
obs = obs.astype(np.float32)
if self._cfg.image_observation:
obs /= 255.0
self._ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, self._ready_obs)
if len(self._ready_obs) == self._env_num:
break
@@ -91,6 +119,8 @@ def step(self, action: dict) -> Dict[int, namedtuple]:

obs, rew, done, info = self._envs.recv()
obs = obs.astype(np.float32)
if self._cfg.image_observation:
obs /= 255.0
rew = rew.astype(np.float32)
env_id = info['env_id']
timesteps = {}
@@ -117,10 +147,202 @@ def seed(self, seed: int, dynamic_seed=False) -> None:
self._seed = seed
logging.warning("envpool doesn't support dynamic_seed in different episode")

@property
def closed(self) -> bool:
return self._closed

@property
def env_num(self) -> int:
return self._env_num

@property
def ready_obs(self) -> Dict[int, Any]:
return self._ready_obs

@property
def observation_space(self) -> 'gym.spaces.Space': # noqa
try:
return self._observation_space
except AttributeError:
self.launch()
self.close()
return self._observation_space

@property
def action_space(self) -> 'gym.spaces.Space': # noqa
try:
return self._action_space
except AttributeError:
self.launch()
self.close()
return self._action_space


@ENV_MANAGER_REGISTRY.register('envpool_v2')
class PoolEnvManagerV2:
"""
Overview:
PoolEnvManagerV2 supports the new pipeline of DI-engine.
Envpool currently supports Atari, Classic Control, Toy Text and ViZDoom environments.
Some commonly used env_ids are listed below.
For more examples, you can refer to <https://envpool.readthedocs.io/en/latest/api/atari.html>.

- Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5"
- Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1"
"""

@classmethod
def default_config(cls) -> EasyDict:
return EasyDict(deepcopy(cls.config))

config = dict(
type='envpool_v2',
# Sync mode: batch_size == env_num
# Async mode: batch_size < env_num
env_num=8,
batch_size=8,
image_observation=True,
episodic_life=False,
reward_clip=False,
gray_scale=True,
stack_num=4,
frame_skip=4,
)

def __init__(self, cfg: EasyDict) -> None:
self._cfg = self.default_config()
self._cfg.update(cfg)
self._env_num = cfg.env_num
self._batch_size = cfg.batch_size
self._ready_obs = {}
self._closed = True
self._seed = None

def launch(self) -> None:
assert self._closed, "Please first close the env manager"
if self._seed is None:
seed = 0
else:
seed = self._seed

self._envs = envpool.make(
task_id=self._cfg.env_id,
env_type="gym",
num_envs=self._env_num,
batch_size=self._batch_size,
seed=seed,
episodic_life=self._cfg.episodic_life,
reward_clip=self._cfg.reward_clip,
stack_num=self._cfg.stack_num,
gray_scale=self._cfg.gray_scale,
frame_skip=self._cfg.frame_skip,
)
self._action_space = self._envs.action_space
self._observation_space = self._envs.observation_space
self._closed = False
self.reset()

def reset(self) -> None:
self._ready_obs = {}
self._envs.async_reset()
while True:
obs, _, _, info = self._envs.recv()
env_id = info['env_id']
obs = obs.astype(np.float32)
if self._cfg.image_observation:
obs /= 255.0
self._ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, self._ready_obs)
if len(self._ready_obs) == self._env_num:
break
self._eval_episode_return = [0. for _ in range(self._env_num)]

def step(self, action: tnp.array) -> List[tnp.ndarray]:
env_id = np.array(self.ready_obs_id)
action = np.array(action)
if len(action.shape) == 2:
action = action.squeeze(1)
self._envs.send(action, env_id)

obs, rew, done, info = self._envs.recv()
obs = obs.astype(np.float32)
if self._cfg.image_observation:
obs /= 255.0
rew = rew.astype(np.float32)
env_id = info['env_id']
new_data = []

self._ready_obs = {}
for i in range(len(env_id)):
d = bool(done[i])
r = to_ndarray([rew[i]])
self._eval_episode_return[env_id[i]] += r

if d:
new_data.append(
tnp.array(
{
'obs': obs[i],
'reward': r,
'done': d,
'info': {
'env_id': i,
'eval_episode_return': self._eval_episode_return[env_id[i]]
},
'env_id': i
}
)
)
self._eval_episode_return[env_id[i]] = 0.
else:
new_data.append(tnp.array({'obs': obs[i], 'reward': r, 'done': d, 'info': {'env_id': i}, 'env_id': i}))

self._ready_obs[env_id[i]] = obs[i]

return new_data

@property
def ready_obs_id(self) -> List[int]:
# In BaseEnvManager, if env_episode_count equals episode_num, this env is done.
return list(self._ready_obs.keys())

@property
def ready_obs(self) -> tnp.array:
obs = list(self._ready_obs.values())
return tnp.stack(obs)

def close(self) -> None:
if self._closed:
return
# Envpool has no `close` API
self._closed = True

def seed(self, seed: int, dynamic_seed=False) -> None:
# Envpool seeds the i-th environment with seed + i, so no extra transformation is needed here
self._seed = seed
logging.warning("envpool doesn't support dynamic_seed in different episode")

@property
def closed(self) -> bool:
return self._closed

@property
def env_num(self) -> int:
return self._env_num

@property
def observation_space(self) -> 'gym.spaces.Space': # noqa
try:
return self._observation_space
except AttributeError:
self.launch()
self.close()
return self._observation_space

@property
def action_space(self) -> 'gym.spaces.Space': # noqa
try:
return self._action_space
except AttributeError:
self.launch()
self.close()
return self._action_space
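Note: `PoolEnvManagerV2` stacks ready observations into a `treetensor.numpy` array and returns a list of `tnp.array` transitions from `step`, which is what the new (V2) pipeline consumes. A rough usage sketch — `env_id` and the sizes are placeholders, the remaining keys fall back to `default_config()`, and a discrete (Atari-style) action space is assumed:

```python
import numpy as np
from easydict import EasyDict
from ding.envs.env_manager.envpool_env_manager import PoolEnvManagerV2

cfg = EasyDict(env_id='Pong-v5', env_num=8, batch_size=8)  # placeholder values
env_manager = PoolEnvManagerV2(cfg)
env_manager.seed(0)
env_manager.launch()

for _ in range(5):
    env_ids = env_manager.ready_obs_id    # env ids matching the ready_obs order
    obs = env_manager.ready_obs           # stacked treetensor.numpy observations
    # One random discrete action per ready environment.
    action = np.random.randint(0, env_manager.action_space.n, size=(len(env_ids), ))
    transitions = env_manager.step(action)   # list of tnp.array transitions
    assert len(transitions) == cfg.batch_size

env_manager.close()
```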
11 changes: 2 additions & 9 deletions ding/envs/env_manager/tests/test_envpool_env_manager.py
@@ -3,7 +3,7 @@
import numpy as np
from easydict import EasyDict

from ..envpool_env_manager import PoolEnvManager
from ding.envs.env_manager.envpool_env_manager import PoolEnvManager

env_num_args = [[16, 8], [8, 8]]

@@ -30,17 +30,10 @@ def test_naive(self, env_num, batch_size):
env_manager = PoolEnvManager(env_manager_cfg)
assert env_manager._closed
env_manager.launch()
# Test step
start_time = time.time()
for count in range(20):
for count in range(5):
env_id = env_manager.ready_obs.keys()
action = {i: np.random.randint(4) for i in env_id}
timestep = env_manager.step(action)
assert len(timestep) == env_manager_cfg.batch_size
print('Count {}'.format(count))
print([v.info for v in timestep.values()])
end_time = time.time()
print('total step time: {}'.format(end_time - start_time))
# Test close
env_manager.close()
assert env_manager._closed
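The existing test only exercises `PoolEnvManager`; a companion smoke test for `PoolEnvManagerV2` could follow the same pattern. A hypothetical sketch (not part of this PR; the marker and config keys mirror the old test and are assumptions):

```python
import numpy as np
import pytest
from easydict import EasyDict

from ding.envs.env_manager.envpool_env_manager import PoolEnvManagerV2


@pytest.mark.envtest  # assumed to match the marker used by the existing env tests
def test_envpool_v2_naive():
    cfg = EasyDict(env_id='Pong-v5', env_num=8, batch_size=8)
    env_manager = PoolEnvManagerV2(cfg)
    assert env_manager._closed
    env_manager.seed(0)
    env_manager.launch()
    for _ in range(5):
        n_ready = len(env_manager.ready_obs_id)
        action = np.random.randint(0, env_manager.action_space.n, size=(n_ready, ))
        transitions = env_manager.step(action)
        assert len(transitions) == cfg.batch_size
    env_manager.close()
    assert env_manager._closed
```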