From 1a2fac0c25be354c3e1531301ed69202af66c085 Mon Sep 17 00:00:00 2001 From: sanjari-orb <137819448+sanjari-orb@users.noreply.github.com> Date: Fri, 14 Jun 2024 10:43:14 -0700 Subject: [PATCH] Add registry for ICL datasets (#1252) --- llmfoundry/eval/datasets/__init__.py | 12 + .../in_context_learning_evaluation.py | 519 ++++++++++-------- llmfoundry/registry.py | 17 + llmfoundry/utils/builders.py | 53 +- .../eval/test_in_context_learning_datasets.py | 363 ++++++------ tests/test_registry.py | 1 + 6 files changed, 537 insertions(+), 428 deletions(-) diff --git a/llmfoundry/eval/datasets/__init__.py b/llmfoundry/eval/datasets/__init__.py index 02a2b88b21..a3a36053da 100644 --- a/llmfoundry/eval/datasets/__init__.py +++ b/llmfoundry/eval/datasets/__init__.py @@ -22,6 +22,18 @@ tokenizer_needs_prefix_space, trim_context, ) +from llmfoundry.registry import icl_datasets + +icl_datasets.register( + 'multiple_choice', + func=InContextLearningMultipleChoiceTaskDataset, +) +icl_datasets.register('schema', func=InContextLearningSchemaTaskDataset) +icl_datasets.register('language_modeling', func=InContextLearningLMTaskDataset) +icl_datasets.register( + 'generation_task_with_answers', + func=InContextLearningGenerationTaskWithAnswersDataset, +) __all__ = [ 'InContextLearningDataset', diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py index debb0dbc6f..c87b38b09a 100644 --- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py +++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py @@ -19,6 +19,7 @@ from datasets import IterableDataset, load_dataset from torch.utils.data import DataLoader, Dataset +from llmfoundry import registry from llmfoundry.eval.datasets.utils import ( convert_tokens_to_tensors, get_continuation_span, @@ -29,6 +30,7 @@ tokenizer_needs_prefix_space, trim_context, ) +from llmfoundry.utils.registry_utils import construct_from_registry log = logging.getLogger(__name__) @@ -114,11 +116,11 @@ def __init__( max_seq_len: int, pad_tok_id: int, num_fewshot: int, - fewshot_random_seed: int, - prompt_string: str, - example_delimiter: str, - continuation_delimiter: str, destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', prelimiter: str = '', context_key: str = 'context', answer_key: str = 'answer', @@ -189,6 +191,20 @@ def __len__(self) -> int: def get_num_samples_in_batch(self, batch: Dict) -> int: return batch['input_ids'].shape[0] + def get_effective_batch_size(self, batch_size: int) -> int: + r"""Returns the effective batch size for the given ICL task. + + The effective batch size may differ from the configured evaluation + batch size because certain ICL tasks create more than one prompt per + input query (one per choice/continuation). + In those cases the effective batch size must be reduced to prevent larger-than-expected batches during eval; see + InContextLearningMultipleChoiceTaskDataset for an example. + + Args: + batch_size (int): Original batch size configured for ICL evaluation. + """ + return batch_size + def update_generation_kwargs(self, generation_kwargs: Dict) -> None: r"""Updates self.base_batch with the passed-in generation_kwargs. 
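To make the new `get_effective_batch_size` contract concrete, here is a minimal sketch (an editorial illustration, not part of the patch; the class name and numbers are hypothetical) of how a choice-expanding task overrides the identity default, mirroring the multiple-choice override added later in this diff:

```python
# Illustrative sketch only: a task that builds one prompt per answer choice
# must shrink the dataloader batch size so the expanded batch stays within
# the configured eval batch size.
class FourChoiceTask:
    num_choices = 4  # hypothetical 4-way multiple-choice benchmark

    def get_effective_batch_size(self, batch_size: int) -> int:
        # Keep at least one query's worth of prompts, then divide out choices.
        batch_size = max(self.num_choices, batch_size)
        return batch_size // self.num_choices

# 16 configured prompts per batch -> 4 queries, each expanded into 4 prompts.
assert FourChoiceTask().get_effective_batch_size(16) == 4
```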
@@ -519,46 +535,12 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) return batch - def split_batch(self, batch: Any, - microbatch_size: Union[int, float]) -> Sequence[Any]: - """Handling for certain specialty columns that must be split into. - - batches in different formats. - - Args: - batch (Dict): Batch of data - microbatch_size (int | float): Size of microbatches - - Returns: - List: List of chunked batches - """ - # Don't split kwargs that don't change - # Normally split torch tensors - # List split lists of strings - if isinstance(microbatch_size, float): - raise ValueError( - 'split_batch does not support floating point microbatch_size.', - ) - chunked = {} - for k, v in batch.items(): - if k in self.static_keys: - # Defer broadcasting until we know num_chunks - pass - elif k in self.list_keys: - chunked[k] = _split_list(v, microbatch_size) - elif k in self.tensor_keys: - chunked[k] = _default_split_batch(v, microbatch_size) - else: - raise ValueError(f'Unexpected key {k} in batch splitting') - num_chunks = len(chunked['input_ids']) - for k, v in batch.items(): - if k in self.static_keys: - chunked[k] = [v] * num_chunks - - batched_list = [{k: v[idx] - for k, v in chunked.items()} - for idx in range(num_chunks)] - return batched_list + def split_batch( + self, + batch: Any, + microbatch_size: Union[int, float], + ) -> Sequence[Any]: + return _default_split_batch(batch, microbatch_size) class InContextLearningGenerationTaskWithAnswersDataset( @@ -584,13 +566,31 @@ class InContextLearningGenerationTaskWithAnswersDataset( def __init__( self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + context_key: str = 'context', + answer_key: str = 'answer', + strip_dataset: bool = True, + padding_size: Optional[int] = None, + base_batch: Optional[Dict] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, cot_delimiter: str = '', early_stopping_criteria: Optional[List[str]] = None, do_normalization: bool = True, - *args: Any, - **kwargs: Any, ): - if kwargs['tokenizer'].eos_token_id is None: + if tokenizer.eos_token_id is None: raise ValueError( '`InContextLearningGenerationTaskWithAnswersDataset` tokenizer must have non-null `eos_token_id`', ) @@ -607,13 +607,32 @@ def __init__( tensor_keys = ['input_ids', 'attention_mask'] list_keys = ['labels'] super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + context_key=context_key, + answer_key=answer_key, + strip_dataset=strip_dataset, + padding_size=padding_size, + base_batch=base_batch, + batch_mapping=batch_mapping, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + # specific to ICL dataset padding_side='left', tokenize_labels=False, static_keys=static_keys, list_keys=list_keys, tensor_keys=tensor_keys, - *args, 
- **kwargs, ) # NOTE: set these after init call because they take class vars self.early_stopping_criteria = early_stopping_criteria @@ -635,8 +654,8 @@ def __init__( 'input_ids': self.context_key, 'labels': 'aliases', } - if 'generation_kwargs' in kwargs: - self.update_generation_kwargs(kwargs['generation_kwargs']) + if generation_kwargs: + self.update_generation_kwargs(generation_kwargs) def read_dataset( self, @@ -765,6 +784,45 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: batch['generation_kwargs']['stopping_criteria'] = stopping_criteria return batch + def split_batch(self, batch: Any, + microbatch_size: Union[int, float]) -> Sequence[Any]: + """Split batch handling for special columns. + + Args: + batch (Dict): Batch of data + microbatch_size (int | float): Size of microbatches + + Returns: + List: List of chunked batches + """ + # Don't split kwargs that don't change + # Normally split torch tensors + # List split lists of strings + if isinstance(microbatch_size, float): + raise ValueError( + 'split_batch does not support floating point microbatch_size.', + ) + chunked = {} + for k, v in batch.items(): + if k in self.static_keys: + # Defer broadcasting until we know num_chunks + pass + elif k in self.list_keys: + chunked[k] = _split_list(v, microbatch_size) + elif k in self.tensor_keys: + chunked[k] = _default_split_batch(v, microbatch_size) + else: + raise ValueError(f'Unexpected key {k} in batch splitting') + num_chunks = len(chunked['input_ids']) + for k, v in batch.items(): + if k in self.static_keys: + chunked[k] = [v] * num_chunks + + batched_list = [{k: v[idx] + for k, v in chunked.items()} + for idx in range(num_chunks)] + return batched_list + class InContextLearningLMTaskDataset(InContextLearningDataset): """A dataset that constructs batches for in-context learning language. @@ -779,8 +837,50 @@ class InContextLearningLMTaskDataset(InContextLearningDataset): See InContextLearningDataset for more details. 
""" - def __init__(self, *args: Any, **kwargs: Any): + def __init__( + self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + context_key: str = 'context', + strip_dataset: bool = True, + tokenize_labels: bool = True, + padding_size: Optional[int] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + static_keys: Optional[List] = None, + list_keys: Optional[List] = None, + ): super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + context_key=context_key, + strip_dataset=strip_dataset, + tokenize_labels=tokenize_labels, + padding_size=padding_size, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + list_keys=list_keys, + # specific to ICL dataset answer_key='continuation', static_keys=['mode'], tensor_keys=[ @@ -800,8 +900,6 @@ def __init__(self, *args: Any, **kwargs: Any): 'labels': 'context', }, padding_side='right', - *args, - **kwargs, ) @@ -833,13 +931,33 @@ class InContextLearningMultipleChoiceTaskDataset(InContextLearningDataset): def __init__( self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + context_key: str = 'query', + tensor_keys: Optional[List] = None, + answer_key: str = 'answer', + strip_dataset: bool = True, + tokenize_labels: bool = True, + padding_size: Optional[int] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + list_keys: Optional[List] = None, choices_key: str = 'choices', static_keys: Optional[List] = None, list_of_tensors_keys: Optional[List] = None, list_of_tuples_keys: Optional[List] = None, list_of_primitives: Optional[List] = None, - *args: Any, - **kwargs: Any, ): self.choices_key = choices_key base_batch = { @@ -850,25 +968,42 @@ def __init__( 'gold_indices': [], 'choice_groupings': [], } - context_key = kwargs.pop('context_key', 'query') - static_keys = kwargs.pop('static_keys', ['mode', 'generation_kwargs']) - tensor_keys = kwargs.pop( - 'tensor_keys', - ['input_ids', 'labels', 'attention_mask'], - ) + if not static_keys: + static_keys = ['mode', 'generation_kwargs'] + if not tensor_keys: + tensor_keys = ['input_ids', 'labels', 'attention_mask'] self.list_of_tensors_keys = list_of_tensors_keys or [ 'continuation_indices', ] self.list_of_tuples_keys = list_of_tuples_keys or ['choice_groupings'] self.list_of_primitives = list_of_primitives or ['gold_indices'] super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + 
prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + answer_key=answer_key, + strip_dataset=strip_dataset, + tokenize_labels=tokenize_labels, + padding_size=padding_size, + batch_mapping=batch_mapping, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + list_keys=list_keys, + # specific to ICL dataset context_key=context_key, base_batch=base_batch, static_keys=static_keys, tensor_keys=tensor_keys, padding_side='right', - *args, - **kwargs, ) self.num_choices = len(self.dataset[0][self.choices_key]) self.batch_mapping_per_choice = { @@ -877,6 +1012,11 @@ def __init__( } self.batch_map_per_example = {'gold_indices': 'gold'} + def get_effective_batch_size(self, batch_size: int) -> int: + batch_size = max(self.num_choices, batch_size) + effective_batchsize = batch_size // self.num_choices + return effective_batchsize + def get_answer_from_example( self, example: Dict, @@ -1095,21 +1235,58 @@ class InContextLearningSchemaTaskDataset( def __init__( self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + answer_key: str = 'answer', + strip_dataset: bool = True, + tokenize_labels: bool = True, + padding_size: Optional[int] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + list_keys: Optional[List] = None, choices_key: str = 'context_options', - *args: Any, - **kwargs: Any, ): static_keys = ['mode'] tensor_keys = ['input_ids', 'labels', 'attention_mask'] list_of_tensors_keys = ['continuation_indices'] super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + answer_key=answer_key, + strip_dataset=strip_dataset, + tokenize_labels=tokenize_labels, + padding_size=padding_size, + batch_mapping=batch_mapping, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + list_keys=list_keys, + # specific to ICL dataset choices_key=choices_key, context_key=choices_key, static_keys=static_keys, tensor_keys=tensor_keys, list_of_tensors_keys=list_of_tensors_keys, - *args, - **kwargs, ) self.base_batch = { 'input_ids': [], @@ -1120,6 +1297,11 @@ def __init__( 'choice_groupings': [], } + def get_effective_batch_size(self, batch_size: int) -> int: + batch_size = max(self.num_choices, batch_size) + effective_batchsize = batch_size // self.num_choices + return effective_batchsize + def construct_context( self, example: Dict[str, Any], @@ -1294,23 +1476,10 @@ def build_icl_dataloader( dataset_uri: str, tokenizer: transformers.PreTrainedTokenizerBase, batch_size: int, - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, # e.g. 'translate english to french:' - example_delimiter: str, # e.g. '\n' - continuation_delimiter: str, # e.g. 
'' hf_loading_vars: Dict, hf_parsing_map: Dict, - destination_path: str, - prelimiter: str, # e.g. 'Question: ' - cot_delimiter: str, # e.g. ' ### ' - fewshot_random_seed: int, - pass_at_k: int, - generations_per_sample: int, - generation_kwargs: Dict, - early_stopping_criteria: Optional[List[str]] = None, - do_normalization: bool = True, + destination_path: str = '', + kwargs: Optional[Dict[str, Any]] = None, ) -> DataSpec: """Factory method that builds the specific dataset for the specified. @@ -1323,108 +1492,36 @@ def build_icl_dataloader( this might be different) 3. set the `split_batch` function if necessary """ - if icl_task_type == 'multiple_choice': - dataset = InContextLearningMultipleChoiceTaskDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - ) - batch_size = max(dataset.num_choices, batch_size) - effective_batchsize = batch_size // dataset.num_choices - elif icl_task_type == 'schema': - dataset = InContextLearningSchemaTaskDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - ) - batch_size = max(dataset.num_choices, batch_size) - effective_batchsize = batch_size // dataset.num_choices - elif icl_task_type == 'language_modeling': - dataset = InContextLearningLMTaskDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - ) - effective_batchsize = batch_size - elif icl_task_type == 'generation_task_with_answers': - dataset = InContextLearningGenerationTaskWithAnswersDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - cot_delimiter=cot_delimiter, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization, - generation_kwargs=generation_kwargs, - ) - effective_batchsize = batch_size - else: - raise Exception(f'Unrecognized ICL task type: {icl_task_type}') - + # Add named parameters to kwargs + if kwargs is None: + kwargs = {} + kwargs.update({ + 'dataset_uri': dataset_uri, + 'tokenizer': tokenizer, + 'hf_loading_vars': hf_loading_vars, + 
'hf_parsing_map': hf_parsing_map, + 'destination_path': destination_path, + }) + dataset = construct_from_registry( + name=icl_task_type, + registry=registry.icl_datasets, + partial_function=False, + pre_validation_function=None, + post_validation_function=None, + kwargs=kwargs, + ) sampler = dist.get_sampler(dataset, drop_last=False, shuffle=False) - split_batch = None - if isinstance( - dataset, - ( - InContextLearningMultipleChoiceTaskDataset, - InContextLearningGenerationTaskWithAnswersDataset, - ), - ): - split_batch = dataset.split_batch - return DataSpec( DataLoader( dataset, - batch_size=effective_batchsize, + batch_size=dataset.get_effective_batch_size(batch_size), sampler=sampler, collate_fn=dataset.collate_fn, ), device_transforms=None, get_num_samples_in_batch=dataset.get_num_samples_in_batch, - split_batch=split_batch, + split_batch=dataset.split_batch, ) @@ -1514,24 +1611,11 @@ def get_icl_task_dataloader( tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], batch_size: int, - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, # e.g. 'translate english to french:' - example_delimiter: str, # e.g. '\n' - continuation_delimiter: str = '', - destination_path: str = '', - question_prelimiter: str = '', # e.g. 'Question: ' - fewshot_random_seed: int = 1234, - pass_at_k: int = 1, - generations_per_sample: int = 1, - cot_delimiter: str = '', has_categories: bool = False, hf_loading_vars: Optional[Dict] = None, hf_parsing_map: Optional[Dict] = None, - generation_kwargs: Optional[Dict] = None, - early_stopping_criteria: Optional[List[str]] = None, - do_normalization: bool = True, + destination_path: str = '', + kwargs: Optional[Dict[str, Any]] = None, ) -> Union[DataSpec, Dict[str, DataSpec]]: r"""Constructs a dataloader (or dataloaders if has_categories is True) @@ -1588,28 +1672,12 @@ def get_icl_task_dataloader( The default keys expected are "context" and "answer". tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to map between strings and token ids. batch_size (int): Size of a batch used for eval - max_seq_len (int): The maximum sequence length supported by the model. - pad_tok_id (int): The special token used for padding batches. - num_fewshot (int): The number of complete fewshot examples to prepend before each test example. These are not identical across examples. - prompt_string (str, default = ''): Prompt string to put once before all fewshot examples/test examples (e.g. 'Translate english to french.'). - example_delimiter (str, default = '\\n'): Separator inserted before (context, answer) pairs (e.g. '\\n') for fewshot sampling and prompting. - continuation_delimiter: (str, default = ' '): Separator inserted between context and answer in each example (e.g. '\\nA: '). - destination_path: (str, default = ''): This is the local file where remote datasets will be saved. - question_prelimiter: (str, default = ''): Text to be prepended before each context, including few shot examples (e.g. "Question: "). - fewshot_random_seed (int, default = 1234): Random seed to use for fewshot sampling - pass_at_k (int): k for how many chances the model gets to write passing code. - generations_per_sample (int): How many outputs to generate per prompt. Passed in generation_kwargs under "num_return_sequences" and overwritten by generation_kwargs dict. - cot_delimiter (str): Delimiter to place between chain of thoughts and continuations. 
has_categories: (bool): If ``True``, we will search the dataset file for a category key, and partition the dataset into a separate dataloader for each category occurring in the data. hf_loading_vars (Dict, default = None): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. hf_parsing_map (Dict, default = None): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key: [hf_key1, hf_key2]}. Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset. - generation_kwargs (Dict, default = None): A dictionary containing keyword arguments to be passed along to the model's generate function. Overwrites any previously specified generation - keyword args in this function (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig - for more details) - early_stopping (List, default = None): A list of strings that, when found in a model's output, will be treated as a stopping criteria at metric computation time. - Used in generation tasks with CoT - do_normalization (bool, default = True): Whether or not to normalize the outputs and labels in InContextLearningGenerationTaskWithAnswersDataset. Only used in generation tasks. + kwargs (Dict[str, Any], default=None): Dictionary containing a mapping + from the ICL dataset constructor's parameter names to their desired values. Returns: DataLoader: A dataloader used for performing in-context learning evaluation on the dataset provided. @@ -1618,11 +1686,6 @@ hf_loading_vars = {} if hf_parsing_map is None: hf_parsing_map = {} - if generation_kwargs is None: - generation_kwargs = {} - if early_stopping_criteria is None: - early_stopping_criteria = [] - if has_categories: result_dls = {} output_files = partition_dataset_by_category( @@ -1639,23 +1702,10 @@ dataset_uri=partition_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, destination_path=partition_uri + '_tmp', - prelimiter=question_prelimiter, - cot_delimiter=cot_delimiter, - fewshot_random_seed=fewshot_random_seed, - pass_at_k=pass_at_k, - generations_per_sample=generations_per_sample, hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization, + kwargs=kwargs, ) return result_dls else: @@ -1664,21 +1714,8 @@ dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - continuation_delimiter=continuation_delimiter, destination_path=destination_path, - prelimiter=question_prelimiter, - cot_delimiter=cot_delimiter, - fewshot_random_seed=fewshot_random_seed, - pass_at_k=pass_at_k, - generations_per_sample=generations_per_sample, - generation_kwargs=generation_kwargs, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization, + kwargs=kwargs, ) diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 
0c8e64b759..f36f53fffa 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -8,6 +8,7 @@ from composer.optim import ComposerScheduler from torch.optim import Optimizer from torch.utils.data import DataLoader as TorchDataloader +from torch.utils.data import Dataset from torchmetrics import Metric from transformers import PreTrainedTokenizerBase @@ -206,6 +207,21 @@ description=_metrics_description, ) +_icl_datasets_description = ( + 'The ICL datasets registry is used to register a torch.utils.data.Dataset class that can be used for ICL tasks.' +) +icl_datasets = create_registry( + 'llmfoundry', + 'icl_datasets', + # TODO: Change type from Dataset to + # llmfoundry.eval.InContextLearningDataset. + # Using ICL dataset here introduces a circular import dependency between + # the registry and eval packages right now, so it needs some refactoring. + generic_type=Type[Dataset], + entry_points=True, + description=_icl_datasets_description, +) + __all__ = [ 'loggers', 'callbacks', @@ -228,4 +244,5 @@ 'attention_classes', 'attention_implementations', 'fcs', + 'icl_datasets', ] diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 73eb026d98..f9e84aab45 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +import copy import functools import logging import os @@ -545,22 +546,10 @@ def _validate_cfg(icl_cfg: Dict[str, Any]): f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg["icl_task_type"]}.', ) - if 'prompt_string' not in icl_cfg: - icl_cfg['prompt_string'] = '' - if 'example_delimiter' not in icl_cfg: - icl_cfg['example_delimiter'] = '\n' - if 'continuation_delimiter' not in icl_cfg: - icl_cfg['continuation_delimiter'] = ' ' if 'max_seq_len' not in icl_cfg: icl_cfg['max_seq_len'] = default_max_seq_len if 'batch_size' not in icl_cfg: icl_cfg['batch_size'] = default_batch_size - if 'pass_at_k' not in icl_cfg: - icl_cfg['pass_at_k'] = 1 - if 'fewshot_random_seed' not in icl_cfg: - icl_cfg['fewshot_random_seed'] = 1234 - if 'generations_per_sample' not in icl_cfg: - icl_cfg['generations_per_sample'] = 1 if 'num_beams' in icl_cfg: raise ValueError( @@ -579,6 +568,7 @@ def _validate_cfg(icl_cfg: Dict[str, Any]): pad_tok_id = tokenizer.eos_token_id else: pad_tok_id = tokenizer.pad_token_id + label = f'{icl_cfg["label"]}/{num_fewshot}-shot' metric_names = list(icl_cfg['metric_names']) # TODO: fix Composer bug when copying local paths and destination exists @@ -589,38 +579,37 @@ def _validate_cfg(icl_cfg: Dict[str, Any]): hf_parsing_map = icl_cfg.get('hf_parsing_map', {}) hf_loading_vars = icl_cfg.get('hf_loading_vars', {}) - early_stopping_criteria = icl_cfg.get( 'early_stopping_criteria', - None, + [], ) + # TODO: fix manual removal of non-constructor fields + icl_constructor_kwargs = copy.deepcopy(icl_cfg) + icl_constructor_kwargs.pop('label', None) + icl_constructor_kwargs.pop('metric_names', None) + icl_constructor_kwargs.pop('icl_task_type', None) + icl_constructor_kwargs.pop('batch_size', None) + icl_constructor_kwargs.pop('has_categories', None) + + # Add custom constructor arguments + icl_constructor_kwargs['pad_tok_id'] = pad_tok_id + icl_constructor_kwargs['num_fewshot'] = num_fewshot + assert early_stopping_criteria is None or isinstance( early_stopping_criteria, list, ) + dataloaders = get_icl_task_dataloader( - icl_cfg['icl_task_type'], - icl_cfg['dataset_uri'], - tokenizer, + icl_task_type=icl_cfg['icl_task_type'], + 
dataset_uri=icl_cfg['dataset_uri'], + tokenizer=tokenizer, batch_size=icl_cfg['batch_size'], - max_seq_len=icl_cfg['max_seq_len'], - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=icl_cfg['prompt_string'], - example_delimiter=icl_cfg['example_delimiter'], hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - continuation_delimiter=icl_cfg['continuation_delimiter'], - question_prelimiter=icl_cfg.get('question_prelimiter', ''), - destination_path=destination_path, - fewshot_random_seed=icl_cfg['fewshot_random_seed'], - pass_at_k=icl_cfg['pass_at_k'], - generations_per_sample=icl_cfg['generations_per_sample'], has_categories=icl_cfg.get('has_categories', False), - cot_delimiter=icl_cfg.get('cot_delimiter', ''), - generation_kwargs=icl_cfg.get('generation_kwargs', {}), - early_stopping_criteria=early_stopping_criteria, - do_normalization=icl_cfg.get('do_normalization', True), + destination_path=destination_path, + kwargs=icl_constructor_kwargs, ) if 'has_categories' in icl_cfg and icl_cfg[ 'has_categories'] and isinstance(dataloaders, dict): diff --git a/tests/eval/test_in_context_learning_datasets.py b/tests/eval/test_in_context_learning_datasets.py index a3c3e88364..b5eacdeb0f 100644 --- a/tests/eval/test_in_context_learning_datasets.py +++ b/tests/eval/test_in_context_learning_datasets.py @@ -1090,15 +1090,22 @@ def test_mc_task_dataloader_subcategories( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=2, - prompt_string= - 'The following are multiple choice questions (with answers).\n', - example_delimiter='\n', - continuation_delimiter='Answer: ', - destination_path=str(tmp_path / 'icl.jsonl'), has_categories=True, + destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'num_fewshot': + 2, + 'max_seq_len': + seqlen, + 'pad_tok_id': + tokenizer.eos_token_id, + 'prompt_string': + 'The following are multiple choice questions (with answers).\n', + 'example_delimiter': + '\n', + 'continuation_delimiter': + 'Answer: ', + }, ) assert isinstance(dls, dict) @@ -1142,13 +1149,15 @@ def test_lm_task_dataloader_extra_space( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=10, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 10, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ' ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1192,13 +1201,15 @@ def test_lm_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=0, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 0, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': '', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1241,14 +1252,16 @@ def test_schema_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - 
example_delimiter='\n', - question_prelimiter=prelimiter, - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 1, + 'prompt_string': '', + 'example_delimiter': '\n', + 'prelimiter': prelimiter, + 'continuation_delimiter': '', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) @@ -1300,13 +1313,15 @@ def test_schema_task_dataloader_sentpiece_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 1, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ' ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) @@ -1358,13 +1373,15 @@ def test_lm_task_dataloader_opt_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': '', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1410,13 +1427,15 @@ def test_mc_task_dataloader_opt_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1473,13 +1492,15 @@ def test_mc_split_batch( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1550,13 +1571,15 @@ def test_qa_split_batch( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=8, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=0, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 0, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dl, DataSpec) # pyright @@ -1612,14 +1635,16 @@ def test_qa_task_dataloader_w_null_eos( dataset_uri, tokenizer, batch_size, 
- max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - question_prelimiter='Q: ', - continuation_delimiter='\nA:', destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': prompt_string, + 'example_delimiter': '\n', + 'prelimiter': 'Q: ', + 'continuation_delimiter': '\nA:', + }, ) @@ -1647,14 +1672,16 @@ def test_qa_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - question_prelimiter='Q: ', - continuation_delimiter='\nA:', destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': prompt_string, + 'example_delimiter': '\n', + 'prelimiter': 'Q: ', + 'continuation_delimiter': '\nA:', + }, ) assert isinstance(dl, DataSpec) @@ -1714,15 +1741,17 @@ def test_qa_task_with_cot_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - question_prelimiter='Q: ', - continuation_delimiter="\nA: Let's think step by step. ", - cot_delimiter=' #### ', destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'prelimiter': 'Q: ', + 'continuation_delimiter': "\nA: Let's think step by step. 
", + 'cot_delimiter': ' #### ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1779,14 +1808,16 @@ def test_mc_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - question_prelimiter=prelimiter, - example_delimiter=example_delimiter, - continuation_delimiter='\nA: ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 1, + 'prompt_string': '', + 'prelimiter': prelimiter, + 'example_delimiter': example_delimiter, + 'continuation_delimiter': '\nA: ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1851,13 +1882,15 @@ def test_lm_task_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': '', + }, ) evaluator = Evaluator( @@ -1903,13 +1936,15 @@ def test_schema_task_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) evaluator = Evaluator( @@ -1968,14 +2003,16 @@ def test_mc_task_evaluation_subcategories( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=max_seq_len, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), has_categories=True, + kwargs={ + 'max_seq_len': max_seq_len, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dls, dict) @@ -2039,13 +2076,15 @@ def test_mc_task_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=64, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 64, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) evaluator = Evaluator( @@ -2107,13 +2146,15 @@ def test_qa_task_evaluation_opt_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 
num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) evaluator = Evaluator( @@ -2168,14 +2209,16 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter="A: Let's think step by step. ", - cot_delimiter=' #### ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': "A: Let's think step by step. ", + 'cot_delimiter': ' #### ', + }, ) evaluator = Evaluator( @@ -2228,13 +2271,15 @@ def test_qa_task_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) evaluator = Evaluator( @@ -2288,14 +2333,16 @@ def test_qa_task_with_cot_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter="A: Let's think step by step", - cot_delimiter=' #### ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': "A: Let's think step by step", + 'cot_delimiter': ' #### ', + }, ) evaluator = Evaluator( @@ -2339,13 +2386,15 @@ def test_lm_spacing_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' UNIQUE ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 1, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ' UNIQUE ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -2409,15 +2458,17 @@ def test_hf_dataloading_lm_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=0, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' ', destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 0, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ' ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -2490,16 +2541,18 @@ def test_hf_dataloading_custom_parsing( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - 
pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - question_prelimiter='Orbs: ', - continuation_delimiter='\nSpell:', destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'), hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': prompt_string, + 'example_delimiter': '\n', + 'prelimiter': 'Orbs: ', + 'continuation_delimiter': '\nSpell:', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright diff --git a/tests/test_registry.py b/tests/test_registry.py index 87881450d4..3bdf5a800f 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -42,6 +42,7 @@ def test_expected_registries_exist(): 'attention_classes', 'attention_implementations', 'fcs', + 'icl_datasets', } assert existing_registries == expected_registry_names
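With the registry and the consolidated `kwargs` entrypoint in place, downstream code can register a custom ICL dataset class and build a dataloader for it without touching `build_icl_dataloader`. The following is a hedged usage sketch, not part of the patch: the task name `my_custom_task`, the `MyCustomICLDataset` subclass, and all paths and parameter values are hypothetical; only the `icl_datasets.register` and `get_icl_task_dataloader` APIs come from this diff.

```python
# Hypothetical usage sketch for the API introduced in this patch.
import transformers

from llmfoundry.eval.datasets.in_context_learning_evaluation import (
    InContextLearningDataset,
    get_icl_task_dataloader,
)
from llmfoundry.registry import icl_datasets


class MyCustomICLDataset(InContextLearningDataset):
    """Stand-in subclass; a real task would customize prompting/collation."""


# Same registration pattern the patch uses for the built-in task types.
icl_datasets.register('my_custom_task', func=MyCustomICLDataset)

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
dl = get_icl_task_dataloader(
    icl_task_type='my_custom_task',  # resolved through registry.icl_datasets
    dataset_uri='eval/local_data/my_task.jsonl',  # hypothetical local path
    tokenizer=tokenizer,
    batch_size=8,
    destination_path='/tmp/my_task_icl.jsonl',
    kwargs={  # forwarded to the dataset constructor via construct_from_registry
        'max_seq_len': 1024,
        'pad_tok_id': tokenizer.eos_token_id,
        'num_fewshot': 0,
    },
)
```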