Skip to content

Commit

Permalink
feature(wrh): add taxi env latest version and dqn config (#807)
Browse files Browse the repository at this point in the history
* update taxi env
  • Loading branch information
ruiheng123 authored Jun 20, 2024
1 parent 91bc342 commit 73ff16f
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 26 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp) <br> env tutorial <br> 环境指南 |
| 38 | [frozen_lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/frozen_lake/FrozenLake.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/frozen_lake) <br> env tutorial <br> 环境指南 |
| 39 | [ising_model](https://github.com/mlii/mfrl/tree/master/examples/ising_model) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/ising_env/ising_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/ising_env) <br> env tutorial <br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/ising_model_zh.html) |
| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | dizoo link <br> env tutorial <br> 环境指南 |
| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/taxi/envs) <br> [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/taxi.html) <br> [环境指南](https://di-engine-docs.readthedocs.io/zh-cn/latest/13_envs/taxi_zh.html) |



Expand Down
46 changes: 26 additions & 20 deletions dizoo/taxi/config/taxi_dqn_config.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,45 @@
from easydict import EasyDict

taxi_dqn_config = dict(
exp_name='taxi_seed0',
exp_name='taxi_dqn_seed0',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=10,
max_episode_steps=300,
env_id="Taxi-v3"
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=20,
max_episode_steps=60,
env_id="Taxi-v3"
),
policy=dict(
cuda=True,
load_path="./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar",
model=dict(
obs_shape=4,
obs_shape=34,
action_shape=6,
encoder_hidden_size_list=[256, 128, 64]
encoder_hidden_size_list=[128, 128]
),
random_collect_size=5000,
nstep=3,
discount_factor=0.98,
discount_factor=0.99,
learn=dict(
update_per_collect=5,
batch_size=128,
learning_rate=0.001,
update_per_collect=10,
batch_size=64,
learning_rate=0.0001,
learner=dict(
hook=dict(
log_show_after_iter=1000,
)
),
),
collect=dict(n_sample=10),
eval=dict(evaluator=dict(eval_freq=5, )),
collect=dict(n_sample=32),
eval=dict(evaluator=dict(eval_freq=1000, )),
other=dict(
eps=dict(
type="linear",
start=0.8,
end=0.1,
decay=10000
),
replay_buffer=dict(replay_buffer_size=20000,),
start=1,
end=0.05,
decay=3000000
),
replay_buffer=dict(replay_buffer_size=100000,),
),
)
)
Expand All @@ -55,4 +61,4 @@

if __name__ == "__main__":
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), max_env_step=5000, seed=0)
serial_pipeline((main_config, create_config), max_env_step=3000000, seed=0)
10 changes: 7 additions & 3 deletions dizoo/taxi/envs/taxi_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
if replay_path is None:
replay_path = './video'
if not os.path.exists(replay_path):
os.makedirs(replay_path)
if not os.path.exists(replay_path):
os.makedirs(replay_path)
self._replay_path = replay_path
self._save_replay = True
self._save_replay_count = 0
Expand All @@ -118,7 +118,11 @@ def random_action(self) -> np.ndarray:
#todo encode the state into a vector
def _encode_taxi(self, obs: np.ndarray) -> np.ndarray:
taxi_row, taxi_col, passenger_location, destination = self._env.unwrapped.decode(obs)
return to_ndarray([taxi_row, taxi_col, passenger_location, destination])
encoded_obs = np.zeros(34)
encoded_obs[5 * taxi_row + taxi_col] = 1
encoded_obs[25 + passenger_location] = 1
encoded_obs[30 + destination] = 1
return to_ndarray(encoded_obs)

@property
def observation_space(self) -> Space:
Expand Down
4 changes: 2 additions & 2 deletions dizoo/taxi/envs/test_taxi_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_naive(self):
env.seed(314, dynamic_seed=False)
assert env._seed == 314
obs = env.reset()
assert obs.shape == (4, )
assert obs.shape == (34, )
for _ in range(5):
env.reset()
np.random.seed(314)
Expand All @@ -32,7 +32,7 @@ def test_naive(self):
print(f"Your timestep in wrapped mode is: {timestep}")
assert isinstance(timestep.obs, np.ndarray)
assert isinstance(timestep.done, bool)
assert timestep.obs.shape == (4, )
assert timestep.obs.shape == (34, )
assert timestep.reward.shape == (1, )
assert timestep.reward >= env.reward_space.low
assert timestep.reward <= env.reward_space.high
Expand Down

0 comments on commit 73ff16f

Please sign in to comment.