diff --git a/neural_speed/convert/convert_llama.py b/neural_speed/convert/convert_llama.py
index f3e1d43bf..9dae31bd8 100644
--- a/neural_speed/convert/convert_llama.py
+++ b/neural_speed/convert/convert_llama.py
@@ -155,6 +155,7 @@ class Params:
     rope_scale: float
     bos_token_id: int
     eos_token_id: int
+    pad_token_id: int

     @staticmethod
     def guessed(model: 'LazyModel') -> 'Params':
@@ -188,6 +189,7 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: Path) -> 'Params':
         rope_scale = config["rope_scaling"]["factor"] if "factor" in config["rope_scaling"] else 1
         bos_token_id = config["bos_token_id"]
         eos_token_id = config["eos_token_id"]
+        pad_token_id = config["pad_token_id"] if "pad_token_id" in config else -1

         return Params(
             n_vocab=n_vocab,
@@ -202,6 +204,7 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: Path) -> 'Params':
             rope_scale=rope_scale,
             bos_token_id = bos_token_id,
             eos_token_id = eos_token_id,
+            pad_token_id = pad_token_id,
         )

     # LLaMA v2 70B params.json
@@ -219,6 +222,7 @@ def loadOriginalParamsJson(model: 'LazyModel', config_path: Path) -> 'Params':
         ffn_hidden_size = config["intermediate_size"]
         bos_token_id = config["bos_token_id"]
         eos_token_id = config["eos_token_id"]
+        pad_token_id = config["pad_token_id"] if "pad_token_id" in config else -1

         # hack to determine LLaMA v1 vs v2 vs CodeLlama
         if n_vocab == -1:
@@ -234,6 +238,7 @@ def loadOriginalParamsJson(model: 'LazyModel', config_path: Path) -> 'Params':
             ffn_hidden_size=ffn_hidden_size,
             bos_token_id = bos_token_id,
             eos_token_id = eos_token_id,
+            pad_token_id = pad_token_id,
         )

     @staticmethod
@@ -1092,7 +1097,7 @@ def write_file_header(self, params: Params, file_type: NEFileType) -> None:
         # but bos_token_id = 1 in llama.cpp
         self.fout.write(struct.pack("i", params.bos_token_id))
         self.fout.write(struct.pack("i", params.eos_token_id))
-        self.fout.write(struct.pack("i", -1))
+        self.fout.write(struct.pack("i", params.pad_token_id))
         self.fout.write(struct.pack("i", -1))

     def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
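
For clarity, a minimal standalone sketch of the behavior this patch introduces, not part of the diff itself: the converter now reads pad_token_id from the HF config with a -1 fallback and writes it into the header slot that was previously hard-coded to -1. The config dict, its token id values, and the "header.bin" output path below are illustrative stand-ins, not code from the repository.

    import struct

    # Hypothetical HF-style config; real conversions read this from config.json.
    config = {"bos_token_id": 1, "eos_token_id": 2}  # "pad_token_id" deliberately absent

    # Same fallback as the patch: default to -1 when the config defines no pad token.
    pad_token_id = config["pad_token_id"] if "pad_token_id" in config else -1

    with open("header.bin", "wb") as fout:
        fout.write(struct.pack("i", config["bos_token_id"]))
        fout.write(struct.pack("i", config["eos_token_id"]))
        fout.write(struct.pack("i", pad_token_id))  # was hard-coded to -1 before the patch
        fout.write(struct.pack("i", -1))            # the following field stays -1, as in the patch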