diff --git a/stable_baselines3/.gitignore b/stable_baselines3/.gitignore new file mode 100644 index 0000000..21d0b89 --- /dev/null +++ b/stable_baselines3/.gitignore @@ -0,0 +1 @@ +.venv/ diff --git a/stable_baselines3/Dockerfile b/stable_baselines3/Dockerfile new file mode 100644 index 0000000..d993a48 --- /dev/null +++ b/stable_baselines3/Dockerfile @@ -0,0 +1,9 @@ +ARG TAG=main +FROM ghcr.io/diambra/arena-stable-baselines3-on3.10-bullseye:$TAG + +RUN pip install huggingface_hub +ENV HF_HOME=/tmp +WORKDIR /app +COPY . . + +ENTRYPOINT [ "python", "./agent.py" ] \ No newline at end of file diff --git a/stable_baselines3/agent.py b/stable_baselines3/agent.py old mode 100644 new mode 100755 index 8445a4a..45b7c56 --- a/stable_baselines3/agent.py +++ b/stable_baselines3/agent.py @@ -5,6 +5,7 @@ from diambra.arena import Roles, SpaceTypes, load_settings_flat_dict from diambra.arena.stable_baselines3.make_sb3_env import make_sb3_env, EnvironmentSettings, WrappersSettings from stable_baselines3 import PPO +from huggingface_hub import hf_hub_download, login as huggingface_hub_login """This is an example agent based on stable baselines 3. @@ -12,16 +13,23 @@ diambra run python stable_baselines3/agent.py --cfgFile $PWD/stable_baselines3/cfg_files/doapp/sr6_128x4_das_nc.yaml --trainedModel "model_name" """ -def main(cfg_file, trained_model, test=False): +def main(repo, cfg_file, trained_model, test=False): + + if os.getenv("HF_TOKEN"): + huggingface_hub_login(os.getenv("HF_TOKEN")) + + config_path = hf_hub_download(repo_id=repo, filename=cfg_file) # Read the cfg file - yaml_file = open(cfg_file) + yaml_file = open(config_path) params = yaml.load(yaml_file, Loader=yaml.FullLoader) print("Config parameters = ", json.dumps(params, sort_keys=True, indent=4)) yaml_file.close() - base_path = os.path.dirname(os.path.abspath(__file__)) - model_folder = os.path.join(base_path, params["folders"]["parent_dir"], params["settings"]["game_id"], - params["folders"]["model_name"], "model") + model_repo_path = os.path.join(params["folders"]["parent_dir"], params["settings"]["game_id"], + params["folders"]["model_name"], "model", trained_model) + + model_path = hf_hub_download(repo_id=repo, filename=model_repo_path) + # Settings params["settings"]["action_space"] = SpaceTypes.DISCRETE if params["settings"]["action_space"] == "discrete" else SpaceTypes.MULTI_DISCRETE @@ -37,7 +45,6 @@ def main(cfg_file, trained_model, test=False): print("Activated {} environment(s)".format(num_envs)) # Load the trained agent - model_path = os.path.join(model_folder, trained_model) agent = PPO.load(model_path) # Print policy network architecture @@ -64,10 +71,11 @@ def main(cfg_file, trained_model, test=False): if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("--repo", type=str, required=True, help="Repository name") parser.add_argument("--cfgFile", type=str, required=True, help="Configuration file") - parser.add_argument("--trainedModel", type=str, default="model", help="Model checkpoint") + parser.add_argument("--trainedModel", type=str, default="model.zip", help="Model checkpoint") parser.add_argument("--test", type=int, default=0, help="Test mode") opt = parser.parse_args() print(opt) - main(opt.cfgFile, opt.trainedModel, bool(opt.test)) + main(opt.repo, opt.cfgFile, opt.trainedModel, bool(opt.test)) diff --git a/stable_baselines3/cfg_files/doapp/sr6_128x4_das_nc.yaml b/stable_baselines3/cfg_files/doapp/sr6_128x4_das_nc.yaml deleted file mode 100644 index ba63068..0000000 --- a/stable_baselines3/cfg_files/doapp/sr6_128x4_das_nc.yaml +++ /dev/null @@ -1,43 +0,0 @@ -folders: - parent_dir: "./results/" - model_name: "sr6_128x4_das_nc" - -settings: - game_id: "doapp" - step_ratio: 6 - frame_shape: !!python/tuple [128, 128, 1] - continue_game: 0.0 - action_space: "multi_discrete" - characters: "Kasumi" - difficulty: 3 - outfits: 2 - -wrappers_settings: - normalize_reward: true - no_attack_buttons_combinations: true - stack_frames: 4 - dilation: 1 - add_last_action: true - stack_actions: 12 - scale: true - exclude_image_scaling: true - role_relative: true - flatten: true - filter_keys: ["action", "own_health", "opp_health", "own_side", "opp_side", "opp_character", "stage", "timer"] - -policy_kwargs: - #net_arch: [{ pi: [64, 64], vf: [32, 32] }] - net_arch: [64, 64] - -ppo_settings: - gamma: 0.94 - model_checkpoint: "0" - learning_rate: [2.5e-4, 2.5e-6] # To start - clip_range: [0.15, 0.025] # To start - #learning_rate: [5.0e-5, 2.5e-6] # Fine Tuning - #clip_range: [0.075, 0.025] # Fine Tuning - batch_size: 256 #8 #nminibatches gave different batch size depending on the number of environments: batch_size = (n_steps * n_envs) // nminibatches - n_epochs: 4 - n_steps: 128 - autosave_freq: 256 - time_steps: 512 diff --git a/stable_baselines3/cfg_files/sfiii3n/sr6_128x4_das_nc.yaml b/stable_baselines3/cfg_files/sfiii3n/sr6_128x4_das_nc.yaml deleted file mode 100644 index 4730350..0000000 --- a/stable_baselines3/cfg_files/sfiii3n/sr6_128x4_das_nc.yaml +++ /dev/null @@ -1,43 +0,0 @@ -folders: - parent_dir: "./results/" - model_name: "sr6_128x4_das_nc" - -settings: - game_id: "sfiii3n" - step_ratio: 6 - frame_shape: !!python/tuple [128, 128, 1] - continue_game: 0.0 - action_space: "discrete" - characters: "Ryu" - difficulty: 6 - outfits: 2 - -wrappers_settings: - normalize_reward: true - no_attack_buttons_combinations: true - stack_frames: 4 - dilation: 1 - add_last_action: true - stack_actions: 12 - scale: true - exclude_image_scaling: true - role_relative: true - flatten: true - filter_keys: ["action", "own_health", "opp_health", "own_side", "opp_side", "opp_character", "stage", "timer"] - -policy_kwargs: - #net_arch: [{ pi: [64, 64], vf: [32, 32] }] - net_arch: [64, 64] - -ppo_settings: - gamma: 0.94 - model_checkpoint: "0" - learning_rate: [2.5e-4, 2.5e-6] # To start - clip_range: [0.15, 0.025] # To start - #learning_rate: [5.0e-5, 2.5e-6] # Fine Tuning - #clip_range: [0.075, 0.025] # Fine Tuning - batch_size: 256 #8 #nminibatches gave different batch size depending on the number of environments: batch_size = (n_steps * n_envs) // nminibatches - n_epochs: 4 - n_steps: 128 - autosave_freq: 256 - time_steps: 512 diff --git a/stable_baselines3/results/doapp/sr6_128x4_das_nc/model/model.zip b/stable_baselines3/results/doapp/sr6_128x4_das_nc/model/model.zip deleted file mode 100644 index f97d581..0000000 --- a/stable_baselines3/results/doapp/sr6_128x4_das_nc/model/model.zip +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0283e655d72fdf4972ddfbdd4de02c7584827c78dd882d2ebcdfa7b163873583 -size 30507273 diff --git a/stable_baselines3/results/sfiii3n/sr6_128x4_das_nc/model/model.zip b/stable_baselines3/results/sfiii3n/sr6_128x4_das_nc/model/model.zip deleted file mode 100644 index eefff7a..0000000 --- a/stable_baselines3/results/sfiii3n/sr6_128x4_das_nc/model/model.zip +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:94c4ebac4cffd58e618c5f41c6375e84382776411e8e470db08c701d0d6c7734 -size 30559434