From 30380dbffb09857012e8ca39edbcef51fddc11fd Mon Sep 17 00:00:00 2001 From: Sami Jawhar Date: Wed, 19 Jun 2024 06:13:15 -0700 Subject: [PATCH] Fix formatting and tests --- .../callbacks/eval_output_logging_callback.py | 13 +++--- .../test_eval_output_logging_callback.py | 45 +++++++++++++++++-- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/llmfoundry/callbacks/eval_output_logging_callback.py b/llmfoundry/callbacks/eval_output_logging_callback.py index 563e71731c..c333746c81 100644 --- a/llmfoundry/callbacks/eval_output_logging_callback.py +++ b/llmfoundry/callbacks/eval_output_logging_callback.py @@ -47,7 +47,6 @@ def eval_batch_end(self, state: State, logger: Logger) -> None: state.metric_outputs, ) - if state.batch.get('mode') == 'generate': # Outputs are already detokenized logging_dict['outputs'] = state.outputs @@ -60,7 +59,11 @@ def eval_batch_end(self, state: State, logger: Logger) -> None: assert state.dataloader is not None dataset = state.dataloader.dataset # pyright: ignore[reportGeneralTypeIssues] tokenizer = dataset.tokenizer # pyright: ignore[reportGeneralTypeIssues] - pad_token_id = getattr(dataset, 'pad_tok_id', dataset.tokenizer.pad_token_id) + pad_token_id = getattr( + dataset, + 'pad_tok_id', + dataset.tokenizer.pad_token_id, + ) # Depad and decode input_ids for input_list in input_ids.tolist(): @@ -91,9 +94,9 @@ def eval_batch_end(self, state: State, logger: Logger) -> None: logging_dict[key] = [tokenizer.decode(t) for t in value] elif isinstance(value[0], list): if isinstance(value[0][0], torch.Tensor): - logging_dict[key] = [ - [tokenizer.decode(choice) for choice in t] for t in value - ] + logging_dict[key] = [[ + tokenizer.decode(choice) for choice in t + ] for t in value] # Convert logging_dict from kv pairs of column name and column values to a list of rows # Example: diff --git a/tests/callbacks/test_eval_output_logging_callback.py b/tests/callbacks/test_eval_output_logging_callback.py index 7778e39fe3..3ad89c54cc 100644 --- a/tests/callbacks/test_eval_output_logging_callback.py +++ b/tests/callbacks/test_eval_output_logging_callback.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import re import torch import transformers @@ -50,6 +51,17 @@ def update_curr_eval(self, dataloader: DataLoader, dataloader_label: str): self._dataloader_label = dataloader_label +class RegexMatcher: + + def __init__(self, pattern: str): + self.pattern = re.compile(pattern) + + def __eq__(self, other: str): + if not isinstance(other, str): + return False + return bool(self.pattern.match(other)) + + def mock_lm_computation( metric: Metric, tokenizer: transformers.AutoTokenizer, @@ -193,23 +205,27 @@ def test_eval_output_logging_lm( assert f'lm_acc_step_0' in in_memory_logger.tables # Only want one table - we log once to a single step value during eval_end() assert len(in_memory_logger.tables) == 1 - assert json.loads(in_memory_logger.tables[f'lm_acc_step_0'])['columns'] == [ + logged_data = json.loads(in_memory_logger.tables[f'lm_acc_step_0']) + assert logged_data['columns'] == [ 'context', 'label', 'output', 'result', 'metric_name', + 'outputs', 'input', 'run_name', ] + # We use the same data in each batch - assert json.loads(in_memory_logger.tables[f'lm_acc_step_0'])['data'] == [ + assert logged_data['data'] == [ [ 'The dog is', ' furry', ' furry', 1, 'InContextLearningLMAccuracy', + RegexMatcher(r' dog is furry(\[PAD\])+I'), 'The dog is furry', 'mock_name', ], @@ -219,6 +235,7 @@ def test_eval_output_logging_lm( '[PAD]', 0, 'InContextLearningLMAccuracy', + RegexMatcher(r' love to eat(\[PAD\])+I'), 'I love to eat pie', 'mock_name', ], @@ -228,6 +245,7 @@ def test_eval_output_logging_lm( ' long lines', 1, 'InContextLearningLMAccuracy', + RegexMatcher(r' hate long lines(\[PAD\])+The'), 'I hate long lines', 'mock_name', ], @@ -237,6 +255,7 @@ def test_eval_output_logging_lm( ' snowy', 1, 'InContextLearningLMAccuracy', + RegexMatcher(r' weather is snowy(\[PAD\])+The'), 'The weather is snowy', 'mock_name', ], @@ -246,6 +265,7 @@ def test_eval_output_logging_lm( ' furry', 1, 'InContextLearningLMAccuracy', + RegexMatcher(r' dog is furry(\[PAD\])+I'), 'The dog is furry', 'mock_name', ], @@ -255,6 +275,7 @@ def test_eval_output_logging_lm( '[PAD]', 0, 'InContextLearningLMAccuracy', + RegexMatcher(r' love to eat(\[PAD\])+I'), 'I love to eat pie', 'mock_name', ], @@ -264,6 +285,7 @@ def test_eval_output_logging_lm( ' long lines', 1, 'InContextLearningLMAccuracy', + RegexMatcher(r' hate long lines(\[PAD\])+The'), 'I hate long lines', 'mock_name', ], @@ -273,6 +295,7 @@ def test_eval_output_logging_lm( ' snowy', 1, 'InContextLearningLMAccuracy', + RegexMatcher(r' weather is snowy(\[PAD\])+The'), 'The weather is snowy', 'mock_name', ], @@ -314,7 +337,8 @@ def test_eval_output_logging_mc( assert f'mc_acc_step_0' in in_memory_logger.tables # Only want one table - we log once to a single step value during eval_end() assert len(in_memory_logger.tables) == 1 - assert json.loads(in_memory_logger.tables[f'mc_acc_step_0'])['columns'] == [ + logged_data = json.loads(in_memory_logger.tables[f'mc_acc_step_0']) + assert logged_data['columns'] == [ 'context', 'correct_choice', 'correct_choice_idx', @@ -323,11 +347,12 @@ def test_eval_output_logging_mc( 'all_choices', 'result', 'metric_name', + 'outputs', 'input', 'run_name', ] # We use the same data for each batch - assert json.loads(in_memory_logger.tables[f'mc_acc_step_0'])['data'] == [ + assert logged_data['data'] == [ [ 'Q: How do you cook a cake?', ' A: turn on the oven', @@ -340,6 +365,9 @@ def test_eval_output_logging_mc( ], 1, 'InContextLearningMultipleChoiceAccuracy', + RegexMatcher( + r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q', + ), 'Q: How do you cook a cake? A: turn on the oven', 'mock_name', ], @@ -355,6 +383,9 @@ def test_eval_output_logging_mc( ], 0, 'InContextLearningMultipleChoiceAccuracy', + RegexMatcher( + r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q', + ), 'Q: How do you cook a cake? A: do a backflip', 'mock_name', ], @@ -370,6 +401,9 @@ def test_eval_output_logging_mc( ], 1, 'InContextLearningMultipleChoiceAccuracy', + RegexMatcher( + r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q', + ), 'Q: How do you cook a cake? A: turn on the oven', 'mock_name', ], @@ -385,6 +419,9 @@ def test_eval_output_logging_mc( ], 0, 'InContextLearningMultipleChoiceAccuracy', + RegexMatcher( + r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q', + ), 'Q: How do you cook a cake? A: do a backflip', 'mock_name', ],