diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index e0b8c0fec5b..330481394c3 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -179,7 +179,10 @@ def download_weights( import json - config_filename = hf_hub_download(model_id, revision=revision, filename="config.json") + if is_local_model: + config_filename = os.path.join(model_id, "config.json") + else: + config_filename = hf_hub_download(model_id, revision=revision, filename="config.json") with open(config_filename, "r") as f: config = json.load(f) architecture = config["architectures"][0]