From b246edd2182af5cb9c1ca819484755b35ae9c396 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 25 Dec 2024 21:41:06 -0500 Subject: [PATCH] var naming and add todo --- src/axolotl/prompt_strategies/chat_template.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 223bf14b4..ba524eb48 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -494,7 +494,7 @@ def transform_logprobs(self, sample): logprobs = sample.pop(self.logprobs_field) target_seq_len = len(logprobs) input_seq_len = len(sample["input_ids"]) - padding_len = input_seq_len - target_seq_len + input_padding_len = input_seq_len - target_seq_len top_k = len(logprobs[0]) target_logprobs = [] target_token_ids = [] @@ -502,12 +502,13 @@ def transform_logprobs(self, sample): # fill with -inf for padding_len tokens for top_k tokens # extend target_logprobs with a padding_len x top_k 2D list filled with -inf - for _ in range(padding_len): + for _ in range(input_padding_len): target_logprobs.append([-float("inf")] * top_k) target_token_ids.append(list(range(top_k))) target_mask.append([0] * top_k) for _ in range(target_seq_len): + # TODO also check against sample["labels"] target_mask.append([1] * top_k) for _, token_pos_logprobs in enumerate(logprobs):