
Commit

feature(pu): add resume_training option to allow the envstep and train_iter resume seamlessly (#835)

* feature(pu): add load pretrained ckpt in serial_entry_onpolicy and serial_entry

* polish(pu): add resume_training option to allow the envstep and train_iter resume seamlessly

* polish(pu): polish comments and resume_training option

* style(pu): bash format.sh

* fix(pu): fix test_serial_entry** by adding resume_training=False in related configs

* polish(pu): polish resume_training in entry

* style(pu): bash format.sh
puyuan1996 authored Nov 5, 2024
1 parent 3898386 commit 1f198e9
Showing 38 changed files with 234 additions and 176 deletions.
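Taken together, the change threads a single flag through the serial entry points: when policy.learn.resume_training is True, compile_config keeps the existing experiment directory (renew_dir=False), the learner's load-checkpoint hook restores train_iter and the saved envstep, and the entry then copies that envstep back onto the collector. A minimal sketch of the two config keys a user would set, following the comments added to cartpole_ppo_config.py below (the checkpoint path is a placeholder):

# Sketch only: opt in to resuming from an existing checkpoint.
policy = dict(
    learn=dict(
        # Tell the learner's before_run hook which checkpoint to load.
        learner=dict(hook=dict(load_ckpt_before_run='/path/to/your/ckpt/iteration_100.pth.tar')),
        # Keep the old experiment dir and restore train_iter / collector.envstep from that ckpt.
        resume_training=True,
    ),
)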
12 changes: 11 additions & 1 deletion ding/entry/serial_entry.py
@@ -47,7 +47,15 @@ def serial_pipeline(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.get('resume_training', False)
)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -86,6 +94,8 @@ def serial_pipeline(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
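For readers following the hunk above: the before_run hook (see learner_hook.py further down) fills learner.collector_envstep from the checkpoint's last_step entry, and the entry then pushes that value into the collector through the new envstep setter. A self-contained sketch of this handshake with stand-in classes, illustrative only (not DI-engine code):

# Stand-ins for BaseLearner and the serial collectors touched by this commit.
class _Learner:

    def __init__(self):
        self._collector_envstep = 0  # in the real learner, filled by the load-ckpt hook

    @property
    def collector_envstep(self):
        return self._collector_envstep


class _Collector:

    def __init__(self):
        self._total_envstep_count = 0

    @property
    def envstep(self):
        return self._total_envstep_count

    @envstep.setter
    def envstep(self, value):
        self._total_envstep_count = value


learner, collector = _Learner(), _Collector()
learner._collector_envstep = 123456  # pretend 'last_step' was read from the checkpoint
resume_training = True  # stands in for cfg.policy.learn.get('resume_training', False)
if resume_training:
    collector.envstep = learner.collector_envstep
assert collector.envstep == 123456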
27 changes: 14 additions & 13 deletions ding/entry/serial_entry_mbrl.py
@@ -30,7 +30,15 @@ def mbrl_entry_setup(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.get('resume_training', False)
)

if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -70,18 +78,7 @@ def mbrl_entry_setup(
cfg.policy.other.commander, learner, collector, evaluator, env_buffer, policy.command_mode
)

return (
cfg,
policy,
world_model,
env_buffer,
learner,
collector,
collector_env,
evaluator,
commander,
tb_logger,
)
return (cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger)


def create_img_buffer(
@@ -131,6 +128,8 @@ def serial_pipeline_dyna(
img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger)

learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
@@ -202,6 +201,8 @@ def serial_pipeline_dream(
mbrl_entry_setup(input_cfg, seed, env_setting, model)

learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
12 changes: 11 additions & 1 deletion ding/entry/serial_entry_ngu.py
@@ -47,7 +47,15 @@ def serial_pipeline_ngu(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.get('resume_training', False)
)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -89,6 +97,8 @@ def serial_pipeline_ngu(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
13 changes: 12 additions & 1 deletion ding/entry/serial_entry_onpolicy.py
@@ -45,7 +45,16 @@ def serial_pipeline_onpolicy(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.get('resume_training', False)
)

# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -80,6 +89,8 @@ def serial_pipeline_onpolicy(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

while True:
collect_kwargs = commander.step()
12 changes: 11 additions & 1 deletion ding/entry/serial_entry_onpolicy_ppg.py
@@ -45,7 +45,15 @@ def serial_pipeline_onpolicy_ppg(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.get('resume_training', False)
)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -80,6 +88,8 @@ def serial_pipeline_onpolicy_ppg(
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
collector.envstep = learner.collector_envstep

while True:
collect_kwargs = commander.step()
4 changes: 4 additions & 0 deletions ding/policy/base_policy.py
@@ -93,6 +93,10 @@ def default_config(cls: type) -> EasyDict:
traj_len_inf=False,
# neural network model config
model=dict(),
# If resume_training is True, the environment step count (collector.envstep) and training iteration (train_iter)
# will be loaded from the pretrained checkpoint, allowing training to resume seamlessly
# from where the ckpt left off.
learn=dict(resume_training=False),
)

def __init__(
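The new default is read with .get(...) in the entries above, so configs written before this commit (which have no resume_training key) keep behaving as fresh runs. A small sketch of that lookup, assuming only that EasyDict behaves like a plain dict here:

from easydict import EasyDict

# Old-style config without the key: entries fall back to False and start a fresh run.
cfg = EasyDict(dict(policy=dict(learn=dict())))
print(cfg.policy.learn.get('resume_training', False))  # False

# Explicit opt-in to resuming.
cfg.policy.learn.resume_training = True
print(cfg.policy.learn.get('resume_training', False))  # True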
12 changes: 11 additions & 1 deletion ding/worker/collector/battle_episode_serial_collector.py
@@ -162,10 +162,20 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
12 changes: 11 additions & 1 deletion ding/worker/collector/battle_sample_serial_collector.py
@@ -175,10 +175,20 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
12 changes: 11 additions & 1 deletion ding/worker/collector/episode_serial_collector.py
@@ -157,10 +157,20 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
12 changes: 11 additions & 1 deletion ding/worker/collector/sample_serial_collector.py
@@ -185,10 +185,20 @@ def envstep(self) -> int:
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
- envstep (:obj:`int`): The total envstep count.
"""
return self._total_envstep_count

@envstep.setter
def envstep(self, value: int) -> None:
"""
Overview:
Set the total envstep count.
Arguments:
- value (:obj:`int`): The total envstep count.
"""
self._total_envstep_count = value

def close(self) -> None:
"""
Overview:
22 changes: 22 additions & 0 deletions ding/worker/learner/base_learner.py
@@ -122,6 +122,8 @@ def __init__(
self._hooks = {'before_run': [], 'before_iter': [], 'after_iter': [], 'after_run': []}
# Last iteration. Used to record current iter.
self._last_iter = CountVar(init_val=0)
# Collector envstep. Used to record current envstep.
self._collector_envstep = 0

# Setup time wrapper and hook.
self._setup_wrapper()
@@ -177,6 +179,26 @@ def register_hook(self, hook: LearnerHook) -> None:
"""
add_learner_hook(self._hooks, hook)

@property
def collector_envstep(self) -> int:
"""
Overview:
Get current collector envstep.
Returns:
- collector_envstep (:obj:`int`): Current collector envstep.
"""
return self._collector_envstep

@collector_envstep.setter
def collector_envstep(self, value: int) -> None:
"""
Overview:
Set current collector envstep.
Arguments:
- value (:obj:`int`): Current collector envstep.
"""
self._collector_envstep = value

def train(self, data: dict, envstep: int = -1, policy_kwargs: Optional[dict] = None) -> None:
"""
Overview:
4 changes: 4 additions & 0 deletions ding/worker/learner/learner_hook.py
@@ -117,6 +117,9 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa
if 'last_iter' in state_dict:
last_iter = state_dict.pop('last_iter')
engine.last_iter.update(last_iter)
if 'last_step' in state_dict:
last_step = state_dict.pop('last_step')
engine._collector_envstep = last_step
engine.policy.load_state_dict(state_dict)
engine.info('{} load ckpt in {}'.format(engine.instance_name, path))

@@ -166,6 +169,7 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa
path = os.path.join(dirname, ckpt_name)
state_dict = engine.policy.state_dict()
state_dict.update({'last_iter': engine.last_iter.val})
state_dict.update({'last_step': engine.collector_envstep})
save_file(path, state_dict)
engine.info('{} save ckpt in {}'.format(engine.instance_name, path))
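With these two hook changes, a saved checkpoint carries both counters. A rough sketch of the round-trip, using torch.save/torch.load in place of ding's save_file helper and a dummy payload (illustrative only):

import torch

# Roughly what SaveCkptHook now writes next to the policy weights (dummy payload).
state_dict = {'model': {}, 'last_iter': 100, 'last_step': 123456}
torch.save(state_dict, 'iteration_100.pth.tar')

# Roughly what LoadCkptHook now reads back in before_run.
loaded = torch.load('iteration_100.pth.tar')
last_iter = loaded.pop('last_iter')  # -> engine.last_iter, so train_iter resumes at 100
last_step = loaded.pop('last_step')  # -> engine._collector_envstep, so envstep resumes at 123456
# The remaining keys would go to engine.policy.load_state_dict(loaded).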

10 changes: 9 additions & 1 deletion dizoo/classic_control/cartpole/config/cartpole_ppo_config.py
@@ -26,7 +26,15 @@
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
learner=dict(hook=dict(save_ckpt_after_iter=100)),
# Path to the pretrained checkpoint (ckpt).
# If set to an empty string (''), no pretrained model will be loaded.
# To load a pretrained ckpt, specify the path like this:
# learner=dict(hook=dict(load_ckpt_before_run='/path/to/your/ckpt/iteration_100.pth.tar')),

# If True, the environment step count (collector.envstep) and training iteration (train_iter)
# will be loaded from the pretrained checkpoint, allowing training to resume seamlessly
# from where the ckpt left off.
resume_training=False,
),
collect=dict(
n_sample=256,
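Putting the comments above into practice, a possible end-to-end usage sketch; the checkpoint path is a placeholder and the main_config/create_config aliases are assumed to follow the usual dizoo convention:

from ding.entry import serial_pipeline_onpolicy
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import main_config, create_config

# Point the learner hook at an existing checkpoint and opt in to seamless resume.
main_config.policy.learn.learner = dict(hook=dict(load_ckpt_before_run='/path/to/your/ckpt/iteration_100.pth.tar'))
main_config.policy.learn.resume_training = True

# train_iter and collector.envstep continue from the values stored in that checkpoint.
serial_pipeline_onpolicy((main_config, create_config), seed=0)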
4 changes: 2 additions & 2 deletions dizoo/cliffwalking/envs/cliffwalking_env.py
@@ -24,8 +24,8 @@ def __init__(self, cfg: dict) -> None:
self._replay_path = None
self._observation_space = gym.spaces.Box(low=0, high=1, shape=(48, ), dtype=np.float32)
self._env = gym.make(
"CliffWalking", render_mode=self._cfg.render_mode, max_episode_steps=self._cfg.max_episode_steps
)
"CliffWalking", render_mode=self._cfg.render_mode, max_episode_steps=self._cfg.max_episode_steps
)
self._action_space = self._env.action_space
self._reward_space = gym.spaces.Box(
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
8 changes: 3 additions & 5 deletions dizoo/d4rl/config/antmaze_umaze_pd_config.py
@@ -45,7 +45,7 @@
value_model_cfg=dict(
model='TemporalValue',
model_cfg=dict(
horizon = 256,
horizon=256,
transition_dim=37,
dim=32,
dim_mults=[1, 2, 4, 8],
@@ -92,10 +92,8 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
create_config = create_config
6 changes: 2 additions & 4 deletions dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py
@@ -45,7 +45,7 @@
value_model_cfg=dict(
model='TemporalValue',
model_cfg=dict(
horizon = 4,
horizon=4,
transition_dim=23,
dim=32,
dim_mults=[1, 4, 8],
@@ -92,9 +92,7 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)