Skip to content

Commit

Permalink
gpt2 : Add gpt2 architecture integration (ggerganov#4555)
Browse files Browse the repository at this point in the history
  • Loading branch information
manikbhandari authored Dec 28, 2023
1 parent f679349 commit ea5497d
Show file tree
Hide file tree
Showing 7 changed files with 281 additions and 14 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ as the main playground for developing new features for the [ggml](https://github
- [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
- [x] [GPT-2](https://huggingface.co/gpt2)

**Multimodal models:**

Expand Down
66 changes: 66 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def from_model_architecture(model_architecture):
return QwenModel
if model_architecture == "MixtralForCausalLM":
return MixtralModel
if model_architecture == "GPT2LMHeadModel":
return GPT2Model
if model_architecture == "PhiForCausalLM":
return Phi2Model
if model_architecture == "PlamoForCausalLM":
Expand Down Expand Up @@ -225,6 +227,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.QWEN
if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA
if arch == "GPT2LMHeadModel":
return gguf.MODEL_ARCH.GPT2
if arch == "PhiForCausalLM":
return gguf.MODEL_ARCH.PHI2
if arch == "PlamoForCausalLM":
Expand Down Expand Up @@ -993,6 +997,68 @@ def write_tensors(self):
self.gguf_writer.add_tensor(new_name, data)


class GPT2Model(Model):
def set_gguf_parameters(self):
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_context_length(self.hparams["n_ctx"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)

def write_tensors(self):
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

for name, data_torch in self.get_tensors():
# we don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".attn.bias")):
continue

if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight")):
data_torch = data_torch.transpose(1, 0)

old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

data = data_torch.squeeze().numpy()

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

self.gguf_writer.add_tensor(new_name, data)

# note: GPT2 output is tied to (same as) wte in original model
if new_name == "token_embd.weight":
print(f"output.weight, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor("output.weight", data)


class Phi2Model(Model):
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]
Expand Down
11 changes: 10 additions & 1 deletion gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,16 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GPT2: [
# TODO
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.PHI2: [
MODEL_TENSOR.TOKEN_EMBD,
Expand Down
10 changes: 9 additions & 1 deletion gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TensorNameMap:
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert
"language_model.embedding.word_embeddings", # persimmon
"wte", # gpt2
"transformer.embd.wte", # phi2
),

Expand All @@ -34,6 +35,7 @@ class TensorNameMap:
MODEL_TENSOR.POS_EMBD: (
"transformer.wpe", # gpt2
"embeddings.position_embeddings", # bert
"wpe", # gpt2
),

# Output
Expand All @@ -53,7 +55,7 @@ class TensorNameMap:
"norm", # llama-pth
"embeddings.LayerNorm", # bert
"transformer.norm_f", # mpt
"ln_f", # refact bloom qwen
"ln_f", # refact bloom qwen gpt2
"language_model.encoder.final_layernorm", # persimmon
"lm_head.ln", # phi2
),
Expand All @@ -78,6 +80,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
"h.{bid}.ln_1", # gpt2
"transformer.h.{bid}.ln", # phi2
"model.layers.layers.{bid}.norm", # plamo
),
Expand All @@ -95,6 +98,7 @@ class TensorNameMap:
"transformer.h.{bid}.self_attention.query_key_value", # falcon
"h.{bid}.self_attention.query_key_value", # bloom
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
"h.{bid}.attn.c_attn", # gpt2
"transformer.h.{bid}.mixer.Wqkv", # phi2
),

Expand Down Expand Up @@ -137,6 +141,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
),
Expand All @@ -159,6 +164,7 @@ class TensorNameMap:
"encoder.layer.{bid}.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi
"h.{bid}.ln_2", # gpt2
),

MODEL_TENSOR.FFN_GATE_INP: (
Expand All @@ -179,6 +185,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.fc_in", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
"transformer.h.{bid}.mlp.w1", # qwen
"h.{bid}.mlp.c_fc", # gpt2
"transformer.h.{bid}.mlp.fc1", # phi2
"model.layers.layers.{bid}.mlp.up_proj", # plamo
),
Expand Down Expand Up @@ -218,6 +225,7 @@ class TensorNameMap:
"encoder.layer.{bid}.output.dense", # bert
"transformer.h.{bid}.mlp.fc_out", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
"h.{bid}.mlp.c_proj", # gpt2
"transformer.h.{bid}.mlp.fc2", # phi2
"model.layers.layers.{bid}.mlp.down_proj", # plamo
),
Expand Down
Loading

0 comments on commit ea5497d

Please sign in to comment.