Skip to content

Commit

Permalink
Fix formatting and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sjawhar committed Jun 19, 2024
1 parent 18c3290 commit 30380db
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
13 changes: 8 additions & 5 deletions llmfoundry/callbacks/eval_output_logging_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
state.metric_outputs,
)


if state.batch.get('mode') == 'generate':
# Outputs are already detokenized
logging_dict['outputs'] = state.outputs
Expand All @@ -60,7 +59,11 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
assert state.dataloader is not None
dataset = state.dataloader.dataset # pyright: ignore[reportGeneralTypeIssues]
tokenizer = dataset.tokenizer # pyright: ignore[reportGeneralTypeIssues]
pad_token_id = getattr(dataset, 'pad_tok_id', dataset.tokenizer.pad_token_id)
pad_token_id = getattr(
dataset,
'pad_tok_id',
dataset.tokenizer.pad_token_id,
)

# Depad and decode input_ids
for input_list in input_ids.tolist():
Expand Down Expand Up @@ -91,9 +94,9 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
logging_dict[key] = [tokenizer.decode(t) for t in value]
elif isinstance(value[0], list):
if isinstance(value[0][0], torch.Tensor):
logging_dict[key] = [
[tokenizer.decode(choice) for choice in t] for t in value
]
logging_dict[key] = [[
tokenizer.decode(choice) for choice in t
] for t in value]

# Convert logging_dict from kv pairs of column name and column values to a list of rows
# Example:
Expand Down
45 changes: 41 additions & 4 deletions tests/callbacks/test_eval_output_logging_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import json
import re

import torch
import transformers
Expand Down Expand Up @@ -50,6 +51,17 @@ def update_curr_eval(self, dataloader: DataLoader, dataloader_label: str):
self._dataloader_label = dataloader_label


class RegexMatcher:
    """Test helper that compares equal to any string matching *pattern*.

    Intended for use inside expected-value lists so that
    ``assert actual == expected`` tolerates variable substrings
    (e.g. runs of ``[PAD]`` tokens whose count depends on batch padding).
    """

    def __init__(self, pattern: str):
        # Compile once; ``match()`` anchors at the start of the candidate
        # string (the tail after the pattern may differ freely).
        self.pattern = re.compile(pattern)

    def __eq__(self, other: object) -> bool:
        # Fix: annotate ``object`` (the old ``str`` annotation contradicted
        # the isinstance guard below). Non-strings are simply unequal, so
        # the matcher composes safely inside heterogeneous row lists.
        if not isinstance(other, str):
            return False
        return bool(self.pattern.match(other))

    def __repr__(self) -> str:
        # Shown in pytest assertion diffs when a row comparison fails;
        # without this, failures print an unhelpful object address.
        return f'RegexMatcher({self.pattern.pattern!r})'


def mock_lm_computation(
metric: Metric,
tokenizer: transformers.AutoTokenizer,
Expand Down Expand Up @@ -193,23 +205,27 @@ def test_eval_output_logging_lm(
assert f'lm_acc_step_0' in in_memory_logger.tables
# Only want one table - we log once to a single step value during eval_end()
assert len(in_memory_logger.tables) == 1
assert json.loads(in_memory_logger.tables[f'lm_acc_step_0'])['columns'] == [
logged_data = json.loads(in_memory_logger.tables[f'lm_acc_step_0'])
assert logged_data['columns'] == [
'context',
'label',
'output',
'result',
'metric_name',
'outputs',
'input',
'run_name',
]

# We use the same data in each batch
assert json.loads(in_memory_logger.tables[f'lm_acc_step_0'])['data'] == [
assert logged_data['data'] == [
[
'The dog is',
' furry',
' furry',
1,
'InContextLearningLMAccuracy',
RegexMatcher(r' dog is furry(\[PAD\])+I'),
'The dog is furry',
'mock_name',
],
Expand All @@ -219,6 +235,7 @@ def test_eval_output_logging_lm(
'[PAD]',
0,
'InContextLearningLMAccuracy',
RegexMatcher(r' love to eat(\[PAD\])+I'),
'I love to eat pie',
'mock_name',
],
Expand All @@ -228,6 +245,7 @@ def test_eval_output_logging_lm(
' long lines',
1,
'InContextLearningLMAccuracy',
RegexMatcher(r' hate long lines(\[PAD\])+The'),
'I hate long lines',
'mock_name',
],
Expand All @@ -237,6 +255,7 @@ def test_eval_output_logging_lm(
' snowy',
1,
'InContextLearningLMAccuracy',
RegexMatcher(r' weather is snowy(\[PAD\])+The'),
'The weather is snowy',
'mock_name',
],
Expand All @@ -246,6 +265,7 @@ def test_eval_output_logging_lm(
' furry',
1,
'InContextLearningLMAccuracy',
RegexMatcher(r' dog is furry(\[PAD\])+I'),
'The dog is furry',
'mock_name',
],
Expand All @@ -255,6 +275,7 @@ def test_eval_output_logging_lm(
'[PAD]',
0,
'InContextLearningLMAccuracy',
RegexMatcher(r' love to eat(\[PAD\])+I'),
'I love to eat pie',
'mock_name',
],
Expand All @@ -264,6 +285,7 @@ def test_eval_output_logging_lm(
' long lines',
1,
'InContextLearningLMAccuracy',
RegexMatcher(r' hate long lines(\[PAD\])+The'),
'I hate long lines',
'mock_name',
],
Expand All @@ -273,6 +295,7 @@ def test_eval_output_logging_lm(
' snowy',
1,
'InContextLearningLMAccuracy',
RegexMatcher(r' weather is snowy(\[PAD\])+The'),
'The weather is snowy',
'mock_name',
],
Expand Down Expand Up @@ -314,7 +337,8 @@ def test_eval_output_logging_mc(
assert f'mc_acc_step_0' in in_memory_logger.tables
# Only want one table - we log once to a single step value during eval_end()
assert len(in_memory_logger.tables) == 1
assert json.loads(in_memory_logger.tables[f'mc_acc_step_0'])['columns'] == [
logged_data = json.loads(in_memory_logger.tables[f'mc_acc_step_0'])
assert logged_data['columns'] == [
'context',
'correct_choice',
'correct_choice_idx',
Expand All @@ -323,11 +347,12 @@ def test_eval_output_logging_mc(
'all_choices',
'result',
'metric_name',
'outputs',
'input',
'run_name',
]
# We use the same data for each batch
assert json.loads(in_memory_logger.tables[f'mc_acc_step_0'])['data'] == [
assert logged_data['data'] == [
[
'Q: How do you cook a cake?',
' A: turn on the oven',
Expand All @@ -340,6 +365,9 @@ def test_eval_output_logging_mc(
],
1,
'InContextLearningMultipleChoiceAccuracy',
RegexMatcher(
r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q',
),
'Q: How do you cook a cake? A: turn on the oven',
'mock_name',
],
Expand All @@ -355,6 +383,9 @@ def test_eval_output_logging_mc(
],
0,
'InContextLearningMultipleChoiceAccuracy',
RegexMatcher(
r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q',
),
'Q: How do you cook a cake? A: do a backflip',
'mock_name',
],
Expand All @@ -370,6 +401,9 @@ def test_eval_output_logging_mc(
],
1,
'InContextLearningMultipleChoiceAccuracy',
RegexMatcher(
r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q',
),
'Q: How do you cook a cake? A: turn on the oven',
'mock_name',
],
Expand All @@ -385,6 +419,9 @@ def test_eval_output_logging_mc(
],
0,
'InContextLearningMultipleChoiceAccuracy',
RegexMatcher(
r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q',
),
'Q: How do you cook a cake? A: do a backflip',
'mock_name',
],
Expand Down

0 comments on commit 30380db

Please sign in to comment.