Skip to content

Commit

Permalink
Merge pull request caikit#314 from gkumbhat/add_granite_modeling_llama_main
Browse files Browse the repository at this point in the history

Add granite modeling llama main
  • Loading branch information
evaline-ju authored Feb 1, 2024
2 parents c20c444 + 90ea356 commit 18c4c55
Show file tree
Hide file tree
Showing 4 changed files with 1,902 additions and 4 deletions.
5 changes: 3 additions & 2 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def train(

# Remove _name_or_path field as a model can be
# saved in different location but still same
del base_model_config["_name_or_path"]
base_model_config.pop("_name_or_path", None)
error.value_check(
"<NLP07232147E>",
"_name_or_path" not in base_model_config,
Expand Down Expand Up @@ -585,7 +585,8 @@ def load(
if peft_config.task_type == "CAUSAL_LM":
# get the transformers Causal LM model
base_model = AutoModelForCausalLM.from_pretrained(
peft_config.base_model_name_or_path
peft_config.base_model_name_or_path,
torch_dtype=torch_dtype,
)
# get the PEFT causal LM model
model = PeftModel.from_pretrained(base_model, model_config)
Expand Down
5 changes: 5 additions & 0 deletions caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@

# Local
from ...data_model import GenerationTrainRecord, PromptOutputModelType

# Note: Below module is imported to allow loading of fm stack sphinx models
from ...toolkit.text_generation import ( # pylint: disable=unused-import
granite_modeling_llama,
)
from ...toolkit.verbalizer_utils import render_verbalizer
from .base import PretrainedModelBase

Expand Down
Loading

0 comments on commit 18c4c55

Please sign in to comment.