From edc95a0e7dc4c8e13a20fbf2c8e2b99a8c3e4549 Mon Sep 17 00:00:00 2001 From: zhangsibo1129 <134488188+zhangsibo1129@users.noreply.github.com> Date: Tue, 26 Sep 2023 20:57:53 +0800 Subject: [PATCH] support local model config file (#1058) # What does this PR do? Support local config file to avoid unexpected `discard_names`, which causes #1057. In the case of launching local mode without `model.safetensors` file, the original code will result `discard_names = []` when `hf_hub_download` throws an connection error. ```python # server/text_generation_server/cli.py try: import transformers import json config_filename = hf_hub_download(model_id, revision=revision, filename="config.json") with open(config_filename, "r") as f: config = json.load(f) architecture = config["architectures"][0] class_ = getattr(transformers, architecture) # Name for this varible depends on transformers version. discard_names = getattr(class_, "_tied_weights_keys", []) discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", [])) except Exception as e: discard_names = [] ``` The expected `_tied_weights_keys` of OPT-1.3b is `["lm_head.weight"]`, and its tied weight `"model.decoder.embed_tokens.weight"` will be kept in the safetensors conversion. But the above empty `discard_names` will lead to `"lm_head.weight"` being kept and `"model.decoder.embed_tokens.weight"` being discard in the subsequent method `_remove_duplicate_names`, which causes error #1057. So add a local mode branch to get the expected `discard_names` like follows. This modification also applies to other models ```python # server/text_generation_server/cli.py if is_local_model: config_filename = os.path.join(model_id, "config.json") else: config_filename = hf_hub_download(model_id, revision=revision, filename="config.json") ``` In addition, when `_tied_weights_keys` or `_keys_to_ignore_on_load_missing` is `None`, the above code will also throw an error unexpectedly. This is fixed in PR #1052 ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [x] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). N/A - [ ] Did you write any new necessary tests? N/A ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @Narsil --- server/text_generation_server/cli.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index e0b8c0fec5b..330481394c3 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -179,7 +179,10 @@ def download_weights( import json - config_filename = hf_hub_download(model_id, revision=revision, filename="config.json") + if is_local_model: + config_filename = os.path.join(model_id, "config.json") + else: + config_filename = hf_hub_download(model_id, revision=revision, filename="config.json") with open(config_filename, "r") as f: config = json.load(f) architecture = config["architectures"][0]