Commit 7b4e303

Minor indentation refactoring
alexpalms committed Mar 10, 2024
1 parent 294d12f commit 7b4e303
Showing 2 changed files with 17 additions and 17 deletions.
stable_baselines/training.py: 31 changes (16 additions, 15 deletions)
@@ -2,12 +2,15 @@
 import yaml
 import json
 import argparse
-from custom_wrappers import RamStatesToChannel, SplitActionsInMoveAndAttack
 from diambra.arena import SpaceTypes, load_settings_flat_dict
 from diambra.arena.stable_baselines.make_sb_env import make_sb_env, EnvironmentSettings, WrappersSettings
 from diambra.arena.stable_baselines.sb_utils import linear_schedule, AutoSave
-from custom_policies.custom_cnn_policy import CustCnnPolicy, local_nature_cnn_small
 from stable_baselines import PPO2
+from custom_wrappers import RamStatesToChannel, SplitActionsInMoveAndAttack
+from custom_policies.custom_cnn_policy import CustCnnPolicy, local_nature_cnn_small
 
+
+# diambra run -s 8 python stable_baselines/training.py --cfgFile $PWD/stable_baselines/cfg_files/sfiii3n/sr6_128x4_das_nc.yaml
+
 def main(cfg_file):
     # Read the cfg file
@@ -35,7 +38,9 @@ def main(cfg_file):
     wrappers_settings.wrappers = [[SplitActionsInMoveAndAttack, {}],
                                   [RamStatesToChannel, {"ram_states": params["ram_states"]}]]
 
+    # Create environment
     env, num_envs = make_sb_env(settings.game_id, settings, wrappers_settings, use_subprocess=True)
+    print("Activated {} environment(s)".format(num_envs))
 
     # Policy param
     policy_kwargs = params["policy_kwargs"]
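
Note: wrappers_settings.wrappers takes [WrapperClass, kwargs] pairs that make_sb_env applies on top of the base DIAMBRA environment. Below is a minimal sketch of that interface, assuming the old-style gym.Wrapper API used by stable-baselines 2; the class name and body are illustrative, not the repo's actual RamStatesToChannel implementation.

import gym

class IllustrativeRamWrapper(gym.Wrapper):
    # Hypothetical wrapper shaped like RamStatesToChannel: it receives the
    # env plus the kwargs dict given in wrappers_settings.wrappers.
    def __init__(self, env, ram_states=None):
        super().__init__(env)
        self.ram_states = ram_states or []

    def reset(self, **kwargs):
        return self._process(self.env.reset(**kwargs))

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self._process(obs), reward, done, info

    def _process(self, obs):
        # Placeholder: a real implementation would fold the selected RAM
        # states into an extra observation channel here.
        return obs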
@@ -58,21 +63,17 @@ def main(cfg_file):
     if model_checkpoint == "0":
         # Initialize the agent
         agent = PPO2(CustCnnPolicy, env, verbose=1,
-                    gamma=gamma, nminibatches=nminibatches,
-                    noptepochs=noptepochs, n_steps=n_steps,
-                    learning_rate=learning_rate, cliprange=cliprange,
-                    cliprange_vf=cliprange_vf, policy_kwargs=policy_kwargs,
-                    tensorboard_log=tensor_board_folder)
+                     gamma=gamma, nminibatches=nminibatches,
+                     noptepochs=noptepochs, n_steps=n_steps,
+                     learning_rate=learning_rate, cliprange=cliprange,
+                     cliprange_vf=cliprange_vf, policy_kwargs=policy_kwargs,
+                     tensorboard_log=tensor_board_folder)
     else:
-
         # Load the trained agent
         agent = PPO2.load(os.path.join(model_folder, model_checkpoint), env=env,
-                         policy_kwargs=policy_kwargs, gamma=gamma,
-                         learning_rate=learning_rate,
-                         cliprange=cliprange, cliprange_vf=cliprange_vf,
-                         tensorboard_log=tensor_board_folder)
-
-        print("Model discount factor = ", agent.gamma)
+                          gamma=gamma, learning_rate=learning_rate, cliprange=cliprange,
+                          cliprange_vf=cliprange_vf, policy_kwargs=policy_kwargs,
+                          tensorboard_log=tensor_board_folder)
 
     # Create the callback: autosave every USER DEF steps
     autosave_freq = ppo_settings["autosave_freq"]
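
Both branches feed learning_rate and cliprange into PPO2; in stable-baselines 2 these hyperparameters may be callables evaluated on the remaining training progress (1.0 at the start, 0.0 at the end), which is what the linear_schedule import above provides. A sketch of that mechanism, assuming initial/final value pairs; the concrete numbers are placeholders, not values taken from the cfg file.

def illustrative_linear_schedule(initial_value, final_value=0.0):
    # SB2 calls the returned function with the fraction of training left,
    # so the value anneals linearly from initial_value to final_value.
    def scheduler(progress_remaining):
        return final_value + progress_remaining * (initial_value - final_value)
    return scheduler

learning_rate = illustrative_linear_schedule(2.5e-4, 2.5e-6)  # placeholder values
cliprange = illustrative_linear_schedule(0.15, 0.025)         # placeholder values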
@@ -96,7 +97,7 @@ def main(cfg_file):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--cfgFile', type=str, required=True, help='Configuration file')
+    parser.add_argument("--cfgFile", type=str, required=True, help="Configuration file")
     opt = parser.parse_args()
     print(opt)
 
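Between the two hunks above, the script builds the AutoSave callback ("autosave every USER DEF steps") and runs training. A hedged sketch of how that wiring usually looks in SB2-style scripts; the AutoSave keyword names and the time_steps/model_folder variables are assumptions, since they are not visible in this diff.

auto_save_callback = AutoSave(check_freq=autosave_freq, num_envs=num_envs,
                              save_path=model_folder)  # assumed kwargs
agent.learn(total_timesteps=time_steps, callback=auto_save_callback)
agent.save(os.path.join(model_folder, "final_model"))  # assumed save path
env.close()
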
stable_baselines3/training.py: 3 changes (1 addition, 2 deletions)
@@ -91,10 +91,9 @@ def main(cfg_file):
     return 0
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument("--cfgFile", type=str, required=True, help="Configuration file")
     opt = parser.parse_args()
     print(opt)
 
-    main(opt.cfgFile)
\ No newline at end of file
+    main(opt.cfgFile)