[Proposal] Ensure TransformerLens does not load from hugging face when config is passed in #754
Could you share the code you are using to load TransformerLens? You should be able to pass in your local version of the model with the param
I only changed a small amount of code, so I've pasted the relevant parts directly here. For each change I've noted the Python file and line number, and I've kept the original code as a comment directly above the new code so you can compare them easily.
In `transformer_lens/HookedTransformer.py`, line 1257:
```python
cfg = loading.get_pretrained_model_config(
    official_model_name,
    # hf_cfg=hf_cfg
    hf_cfg=hf_model.config,
    checkpoint_index=checkpoint_index,
    checkpoint_value=checkpoint_value,
    fold_ln=fold_ln,
    device=device,
    n_devices=n_devices,
    default_prepend_bos=default_prepend_bos,
    dtype=dtype,
    first_n_layers=first_n_layers,
    **from_pretrained_kwargs,
)
```
In `transformer_lens/loading_from_pretrained.py`, line 1583:
```python
# if hf_cfg is not None:
#     cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
if hf_cfg is not None:
    cfg_dict["load_in_4bit"] = hf_cfg.to_dict().get("quantization_config", {}).get("load_in_4bit", False)
```
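The `.to_dict()` call matters because, with this patch, `hf_cfg` is the config object `hf_model.config` rather than a plain dictionary, so it is converted before the dictionary lookup. As a minimal sketch, a helper that tolerates either form could look like this. `read_load_in_4bit` and `ConfigLike` are hypothetical names; `ConfigLike` only imitates the `to_dict()` method of a Hugging Face config so the example runs without `transformers` installed.

```python
from types import SimpleNamespace
from typing import Any


class ConfigLike(SimpleNamespace):
    """Hypothetical stand-in for a Hugging Face config: only to_dict() matters here."""

    def to_dict(self) -> dict:
        return dict(vars(self))


def read_load_in_4bit(hf_cfg: Any) -> bool:
    """Read quantization_config["load_in_4bit"] from a plain dict or a config object."""
    if hf_cfg is None:
        return False
    # Config objects expose to_dict(); plain dicts are copied as-is.
    cfg = hf_cfg.to_dict() if hasattr(hf_cfg, "to_dict") else dict(hf_cfg)
    # quantization_config may be absent or None, so fall back to an empty dict.
    quant = cfg.get("quantization_config") or {}
    return bool(quant.get("load_in_4bit", False))


print(read_load_in_4bit({"quantization_config": {"load_in_4bit": True}}))  # True
print(read_load_in_4bit(ConfigLike(quantization_config={"load_in_4bit": True})))  # True
print(read_load_in_4bit(ConfigLike()))  # False
```

Accepting both forms keeps callers that still pass a plain dict working alongside callers that pass the config object.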
In `transformer_lens/loading_from_pretrained.py`, line 708:
```python
# def convert_hf_model_config(model_name: str, **kwargs):
def convert_hf_model_config(model_name: str, hf_config=None, **kwargs):
    """
    Returns the model config for a HuggingFace model, converted to a dictionary
    in the HookedTransformerConfig format.

    Takes the official_model_name as an input.
    """
    if (Path(model_name) / "config.json").exists():
        logging.info("Loading model config from local directory")
        official_model_name = model_name
    else:
        official_model_name = get_official_model_name(model_name)

    # Load HuggingFace model config
    if "llama" in official_model_name.lower():
        architecture = "LlamaForCausalLM"
    elif "gemma-2" in official_model_name.lower():
        architecture = "Gemma2ForCausalLM"
    elif "gemma" in official_model_name.lower():
        architecture = "GemmaForCausalLM"
    else:
        # huggingface_token = os.environ.get("HF_TOKEN", None)
        # hf_config = AutoConfig.from_pretrained(
        #     official_model_name,
        #     token=huggingface_token,
        #     **kwargs,
        # )
        if hf_config is None:
            huggingface_token = os.environ.get("HF_TOKEN", None)
            hf_config = AutoConfig.from_pretrained(
                official_model_name,
                token=huggingface_token,
                **kwargs,
            )
        architecture = hf_config.architectures[0]
    ...
```
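A minimal, self-contained sketch of the control flow this change introduces: a caller-supplied config short-circuits the `AutoConfig.from_pretrained` call, which is then only a fallback. `resolve_architecture` and the `fetch_config` callable are hypothetical stand-ins (the fetcher plays the role of `AutoConfig.from_pretrained`), and `SimpleNamespace` objects stand in for real config objects, so no network access or `transformers` install is needed.

```python
from types import SimpleNamespace
from typing import Any, Callable, List, Optional


def resolve_architecture(
    official_model_name: str,
    hf_config: Optional[Any] = None,
    fetch_config: Optional[Callable[[str], Any]] = None,
) -> str:
    """Pick the architecture name, fetching a config only if none was passed in (hypothetical)."""
    name = official_model_name.lower()
    # Hard-coded families, mirroring the branches in convert_hf_model_config.
    if "llama" in name:
        return "LlamaForCausalLM"
    if "gemma-2" in name:
        return "Gemma2ForCausalLM"
    if "gemma" in name:
        return "GemmaForCausalLM"
    if hf_config is None:
        if fetch_config is None:
            raise ValueError(f"No config supplied for {official_model_name!r} and no fetcher given")
        # Stand-in for AutoConfig.from_pretrained, the step that may contact the Hub.
        hf_config = fetch_config(official_model_name)
    return hf_config.architectures[0]


# Usage: a recording fetcher shows whether the fallback was triggered.
fetch_calls: List[str] = []


def recording_fetcher(name: str) -> SimpleNamespace:
    fetch_calls.append(name)
    return SimpleNamespace(architectures=["GPT2LMHeadModel"])


local_cfg = SimpleNamespace(architectures=["GPT2LMHeadModel"])
print(resolve_architecture("gpt2", local_cfg, recording_fetcher))  # GPT2LMHeadModel
print(fetch_calls)  # [] because the supplied config was used and nothing was fetched
```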
In `transformer_lens/loading_from_pretrained.py`, lines 1525 and 1543:
```python
if Path(model_name).exists():
    # If the model_name is a path, it's a local model
    # cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
    cfg_dict = convert_hf_model_config(model_name, hf_cfg, **kwargs)
    official_model_name = model_name
else:
    official_model_name = get_official_model_name(model_name)
    if (
        official_model_name.startswith("NeelNanda")
        or official_model_name.startswith("ArthurConmy")
        or official_model_name.startswith("Baidicoot")
    ):
        cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
    else:
        if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
            "trust_remote_code", False
        ):
            logging.warning(
                f"Loading model {official_model_name} requires setting trust_remote_code=True"
            )
            kwargs["trust_remote_code"] = True
        # cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
        cfg_dict = convert_hf_model_config(official_model_name, hf_cfg, **kwargs)
```
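The `Path(model_name).exists()` branch is what lets a local directory bypass the Hub lookup. A standard-library-only sketch of the same idea, assuming the directory follows the usual Hugging Face layout with a `config.json` at its root: `load_local_config` is a hypothetical helper, not a TransformerLens function, and a real loader would still need to convert the result into the HookedTransformerConfig format.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional


def load_local_config(model_name: str) -> Optional[Dict[str, Any]]:
    """Return the parsed config.json if model_name is a local model directory, else None."""
    config_path = Path(model_name) / "config.json"
    if not config_path.is_file():
        return None  # not a local checkpoint; a real loader would fall back to the Hub
    with config_path.open(encoding="utf-8") as f:
        return json.load(f)


# Usage: simulate a locally downloaded GPT-2 directory.
with tempfile.TemporaryDirectory() as model_dir:
    (Path(model_dir) / "config.json").write_text(
        json.dumps({"architectures": ["GPT2LMHeadModel"], "n_layer": 12}), encoding="utf-8"
    )
    cfg = load_local_config(model_dir)
    print(cfg["architectures"][0])  # GPT2LMHeadModel
```

Checking for `config.json` rather than only `Path(model_name).exists()` avoids treating an unrelated local file or directory that happens to share the model's name as a checkpoint.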
Proposal
Change the loading code so that a model can be loaded entirely from local files.
Motivation
Today I wanted to load a GPT-2 model that I had downloaded from the Hugging Face website, loading it locally the same way as Llama, but TransformerLens kept trying to connect to Hugging Face to download it.
Then I checked the code and found that
Pitch
Provide a way to load models locally, including models downloaded from Hugging Face that are not in the local cache.
Alternatives
Checklist