Skip to content

Commit

Permalink
sb3: Download model and config from huggingface
Browse files Browse the repository at this point in the history
  • Loading branch information
discordianfish committed Nov 10, 2023
1 parent b7c6aec commit c7f2245
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 99 deletions.
1 change: 1 addition & 0 deletions stable_baselines3/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.venv/
8 changes: 8 additions & 0 deletions stable_baselines3/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
ARG TAG=main
FROM ghcr.io/diambra/arena-stable-baselines3-on3.10-bullseye:$TAG

RUN pip install huggingface_hub
WORKDIR /app
COPY . .

ENTRYPOINT [ "python", "./agent.py" ]
22 changes: 15 additions & 7 deletions stable_baselines3/agent.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,31 @@
from diambra.arena import Roles, SpaceTypes, load_settings_flat_dict
from diambra.arena.stable_baselines3.make_sb3_env import make_sb3_env, EnvironmentSettings, WrappersSettings
from stable_baselines3 import PPO
from huggingface_hub import hf_hub_download, login as huggingface_hub_login

"""This is an example agent based on stable baselines 3.
Usage:
diambra run python stable_baselines3/agent.py --cfgFile $PWD/stable_baselines3/cfg_files/doapp/sr6_128x4_das_nc.yaml --trainedModel "model_name"
"""

def main(cfg_file, trained_model, test=False):
def main(repo, cfg_file, trained_model, test=False):

if os.getenv("HF_TOKEN"):
huggingface_hub_login(os.getenv("HF_TOKEN"))

config_path = hf_hub_download(repo_id=repo, filename=cfg_file)
# Read the cfg file
yaml_file = open(cfg_file)
yaml_file = open(config_path)
params = yaml.load(yaml_file, Loader=yaml.FullLoader)
print("Config parameters = ", json.dumps(params, sort_keys=True, indent=4))
yaml_file.close()

base_path = os.path.dirname(os.path.abspath(__file__))
model_folder = os.path.join(base_path, params["folders"]["parent_dir"], params["settings"]["game_id"],
params["folders"]["model_name"], "model")
model_repo_path = os.path.join(params["folders"]["parent_dir"], params["settings"]["game_id"],
params["folders"]["model_name"], "model", trained_model)

model_path = hf_hub_download(repo_id=repo, filename=model_repo_path)


# Settings
params["settings"]["action_space"] = SpaceTypes.DISCRETE if params["settings"]["action_space"] == "discrete" else SpaceTypes.MULTI_DISCRETE
Expand All @@ -37,7 +45,6 @@ def main(cfg_file, trained_model, test=False):
print("Activated {} environment(s)".format(num_envs))

# Load the trained agent
model_path = os.path.join(model_folder, trained_model)
agent = PPO.load(model_path)

# Print policy network architecture
Expand All @@ -64,10 +71,11 @@ def main(cfg_file, trained_model, test=False):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--repo", type=str, required=True, help="Repository name")
parser.add_argument("--cfgFile", type=str, required=True, help="Configuration file")
parser.add_argument("--trainedModel", type=str, default="model", help="Model checkpoint")
parser.add_argument("--test", type=int, default=0, help="Test mode")
opt = parser.parse_args()
print(opt)

main(opt.cfgFile, opt.trainedModel, bool(opt.test))
main(opt.repo, opt.cfgFile, opt.trainedModel, bool(opt.test))
43 changes: 0 additions & 43 deletions stable_baselines3/cfg_files/doapp/sr6_128x4_das_nc.yaml

This file was deleted.

43 changes: 0 additions & 43 deletions stable_baselines3/cfg_files/sfiii3n/sr6_128x4_das_nc.yaml

This file was deleted.

This file was deleted.

This file was deleted.

0 comments on commit c7f2245

Please sign in to comment.