From a1b3f786a2224fb9f0b21373e68527a6e00e1eef Mon Sep 17 00:00:00 2001 From: Sami Jawhar Date: Thu, 13 Jun 2024 20:41:04 -0700 Subject: [PATCH] Add outputs --- .../callbacks/eval_output_logging_callback.py | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/llmfoundry/callbacks/eval_output_logging_callback.py b/llmfoundry/callbacks/eval_output_logging_callback.py index 0ccc919c2a..563e71731c 100644 --- a/llmfoundry/callbacks/eval_output_logging_callback.py +++ b/llmfoundry/callbacks/eval_output_logging_callback.py @@ -47,22 +47,25 @@ def eval_batch_end(self, state: State, logger: Logger) -> None: state.metric_outputs, ) - # If batch mode is not generate, outputs will be logits + if state.batch.get('mode') == 'generate': # Outputs are already detokenized logging_dict['outputs'] = state.outputs + elif isinstance(state.outputs, torch.Tensor): + # If batch mode is not generate, outputs will be logits + logging_dict['outputs'] = state.outputs.argmax(dim=-1) input_ids = state.batch['input_ids'] logged_input = [] assert state.dataloader is not None + dataset = state.dataloader.dataset # pyright: ignore[reportGeneralTypeIssues] + tokenizer = dataset.tokenizer # pyright: ignore[reportGeneralTypeIssues] + pad_token_id = getattr(dataset, 'pad_tok_id', dataset.tokenizer.pad_token_id) # Depad and decode input_ids for input_list in input_ids.tolist(): - dataset = state.dataloader.dataset # pyright: ignore[reportGeneralTypeIssues] - depadded_input = [ - tok for tok in input_list if tok != dataset.tokenizer.pad_token_id - ] - logged_input.append(dataset.tokenizer.decode(depadded_input)) + depadded_input = [tok for tok in input_list if tok != pad_token_id] + logged_input.append(tokenizer.decode(depadded_input)) logging_dict['input'] = logged_input # Log token indices if toggled @@ -85,18 +88,12 @@ def eval_batch_end(self, state: State, logger: Logger) -> None: for key, value in logging_dict.items(): # All types in list are the same if isinstance(value[0], torch.Tensor): - logging_dict[key] = [ - state.dataloader.dataset. # pyright: ignore[reportGeneralTypeIssues] - tokenizer.decode( # pyright: ignore[reportGeneralTypeIssues] - t, - ) for t in value - ] + logging_dict[key] = [tokenizer.decode(t) for t in value] elif isinstance(value[0], list): if isinstance(value[0][0], torch.Tensor): - tokenizer = state.dataloader.dataset.tokenizer # pyright: ignore[reportGeneralTypeIssues] - logging_dict[key] = [[ - tokenizer.decode(choice) for choice in t - ] for t in value] + logging_dict[key] = [ + [tokenizer.decode(choice) for choice in t] for t in value + ] # Convert logging_dict from kv pairs of column name and column values to a list of rows # Example: