Skip to content

Commit

Permalink
Add outputs
Browse files (browse the repository at this point in the history)
  • Loading branch information
sjawhar authored Jun 14, 2024
1 parent d7afd8a commit a1b3f78
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions llmfoundry/callbacks/eval_output_logging_callback.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -47,22 +47,25 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
state.metric_outputs,
)

# If batch mode is not generate, outputs will be logits

if state.batch.get('mode') == 'generate':
# Outputs are already detokenized
logging_dict['outputs'] = state.outputs
elif isinstance(state.outputs, torch.Tensor):
# If batch mode is not generate, outputs will be logits
logging_dict['outputs'] = state.outputs.argmax(dim=-1)

input_ids = state.batch['input_ids']
logged_input = []
assert state.dataloader is not None
dataset = state.dataloader.dataset # pyright: ignore[reportGeneralTypeIssues]
tokenizer = dataset.tokenizer # pyright: ignore[reportGeneralTypeIssues]
pad_token_id = getattr(dataset, 'pad_tok_id', dataset.tokenizer.pad_token_id)

# Depad and decode input_ids
for input_list in input_ids.tolist():
dataset = state.dataloader.dataset # pyright: ignore[reportGeneralTypeIssues]
depadded_input = [
tok for tok in input_list if tok != dataset.tokenizer.pad_token_id
]
logged_input.append(dataset.tokenizer.decode(depadded_input))
depadded_input = [tok for tok in input_list if tok != pad_token_id]
logged_input.append(tokenizer.decode(depadded_input))
logging_dict['input'] = logged_input

# Log token indices if toggled
Expand All @@ -85,18 +88,12 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
for key, value in logging_dict.items():
# All types in list are the same
if isinstance(value[0], torch.Tensor):
logging_dict[key] = [
state.dataloader.dataset. # pyright: ignore[reportGeneralTypeIssues]
tokenizer.decode( # pyright: ignore[reportGeneralTypeIssues]
t,
) for t in value
]
logging_dict[key] = [tokenizer.decode(t) for t in value]
elif isinstance(value[0], list):
if isinstance(value[0][0], torch.Tensor):
tokenizer = state.dataloader.dataset.tokenizer # pyright: ignore[reportGeneralTypeIssues]
logging_dict[key] = [[
tokenizer.decode(choice) for choice in t
] for t in value]
logging_dict[key] = [
[tokenizer.decode(choice) for choice in t] for t in value
]

# Convert logging_dict from kv pairs of column name and column values to a list of rows
# Example:
Expand Down

0 comments on commit a1b3f78

Please sign in to comment.