diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 6ca10fcd47..89aa917809 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -23,6 +23,12 @@ jobs: - name: "2.3.0_cu121_flash2_aws" base_image: mosaicml/pytorch:2.3.0_cu121-python3.11-ubuntu20.04-aws dep_groups: "[gpu-flash2]" + - name: "2.3.1_cu121" + base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 + dep_groups: "[gpu]" + - name: "2.3.1_cu121_aws" + base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04-aws + dep_groups: "[gpu]" steps: - name: Maximize Build Space on Worker uses: easimon/maximize-build-space@v4 diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 93612b7983..78faea8e44 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -23,6 +23,10 @@ jobs: container: mosaicml/pytorch:2.3.0_cpu-python3.11-ubuntu20.04 markers: "not gpu" pytest_command: "coverage run -m pytest" + - name: "cpu-2.3.1" + container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04 + markers: "not gpu" + pytest_command: "coverage run -m pytest" name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' with: diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 31af66e51f..07e811e244 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -24,6 +24,11 @@ jobs: markers: "gpu" pytest_command: "coverage run -m pytest" pip_deps: "[all]" + - name: "gpu-2.3.1" + container: mosaicml/llm-foundry:2.3.1_cu121-latest + markers: "gpu" + pytest_command: "coverage run -m pytest" + pip_deps: "[all]" name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' with: diff --git a/Dockerfile b/Dockerfile index ca684dca2a..9366d7dbcd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,7 +13,7 @@ ADD https://raw.githubusercontent.com/mosaicml/llm-foundry/$BRANCH_NAME/setup.py RUN rm setup.py # Install TransformerEngine -RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=3 MAX_JOBS=3 pip install git+https://github.com/NVIDIA/TransformerEngine.git@05eb6deb31c1b48e9f4380d18fe95f3c38e84335 +RUN NVTE_FRAMEWORK=pytorch CMAKE_BUILD_PARALLEL_LEVEL=4 MAX_JOBS=4 pip install git+https://github.com/NVIDIA/TransformerEngine.git@b5a7c9f # Install and uninstall foundry to cache foundry requirements RUN git clone -b $BRANCH_NAME https://github.com/mosaicml/llm-foundry.git diff --git a/README.md b/README.md index 70436271dd..c92c252395 100644 --- a/README.md +++ b/README.md @@ -230,7 +230,7 @@ python data_prep/convert_dataset_hf.py \ # Train an MPT-125m model for 10 batches composer train/train.py \ train/yamls/pretrain/mpt-125m.yaml \ - data_local=my-copy-c4 \ + variables.data_local=my-copy-c4 \ train_loader.dataset.split=train_small \ eval_loader.dataset.split=val_small \ max_duration=10ba \ diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py index 4712de5d5e..496e905e13 100644 --- a/llmfoundry/callbacks/__init__.py +++ b/llmfoundry/callbacks/__init__.py @@ -11,6 +11,7 @@ OptimizerMonitor, RuntimeEstimator, SpeedMonitor, + SystemMetricsMonitor, ) from llmfoundry.callbacks.async_eval_callback import AsyncEval @@ -35,6 +36,7 @@ from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector from llmfoundry.registry import callbacks, callbacks_with_config +callbacks.register('system_metrics_monitor', func=SystemMetricsMonitor) callbacks.register('lr_monitor', func=LRMonitor) callbacks.register('memory_monitor', func=MemoryMonitor) 
callbacks.register('memory_snapshot', func=MemorySnapshot) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 28b33b43d8..d80060d6f6 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -17,7 +17,7 @@ import numpy as np import torch import torch.nn as nn -from composer.core import Callback, Event, State, Time, TimeUnit +from composer.core import Callback, Event, Precision, State, Time, TimeUnit from composer.core.state import fsdp_state_dict_type_context from composer.loggers import Logger, MLFlowLogger from composer.models import HuggingFaceModel @@ -37,6 +37,12 @@ from llmfoundry.utils.huggingface_hub_utils import \ edit_files_for_hf_compatibility +try: + import transformer_engine.pytorch as te + is_te_imported = True +except ModuleNotFoundError: + is_te_imported = False + log = logging.getLogger(__name__) __all__ = ['HuggingFaceCheckpointer'] @@ -486,9 +492,19 @@ def dtensor_to_tensor_hook( ) log.debug('Saving Hugging Face checkpoint to disk') - new_model_instance.save_pretrained(temp_save_dir) + # This context manager casts the TE extra state in io.BytesIO format to tensor format + # Needed for proper hf ckpt saving. + context_manager = te.onnx_export( + True, + ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext( + ) + with context_manager: + new_model_instance.save_pretrained(temp_save_dir) if original_tokenizer is not None: - assert isinstance(original_tokenizer, PreTrainedTokenizerBase) + assert isinstance( + original_tokenizer, + PreTrainedTokenizerBase, + ) original_tokenizer.save_pretrained(temp_save_dir) # Only need to edit files for MPT because it has custom code diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 639beba6f0..160e9bfe3b 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -222,7 +222,7 @@ def build_finetuning_dataloader( cache_limit=dataset_cfg.get('cache_limit', None), partition_algo=dataset_cfg.get('partition_algo', 'relaxed'), num_canonical_nodes=dataset_cfg.get('num_canonical_nodes', None), - batch_size=dataset_batch_size, + batch_size=dataloader_batch_size, shuffle=dataset_cfg.get('shuffle', False), shuffle_algo=dataset_cfg.get('shuffle_algo', 'py1e'), shuffle_seed=dataset_cfg.get('shuffle_seed', 9176), @@ -233,6 +233,7 @@ def build_finetuning_dataloader( max_seq_len=dataset_cfg['max_seq_len'], allow_unsafe_types=dataset_cfg.get('allow_unsafe_types', False), replication=replication_factor, + packing_ratio=dataloader_batch_size / dataset_batch_size, ) else: @@ -390,6 +391,7 @@ def _validate_config( 'allow_pad_trimming', 'seq_parallel_replication', 'auto_packing_replication', + 'max_leftover_bins_to_keep', } if not set(kwargs.keys()).issubset(allowed_additional_kwargs): raise ValueError( diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 40f178fb6e..9a0f680bd7 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -592,6 +592,7 @@ def __init__( max_seq_len: int = 2048, allow_unsafe_types: bool = False, replication: Optional[int] = None, + packing_ratio: Optional[float] = None, **kwargs: Any, ): @@ -644,6 +645,7 @@ def __init__( self.tokenizer = tokenizer self.max_seq_len = max_seq_len + self.packing_ratio = packing_ratio # How to process a sample def __getitem__(self, idx: int) -> Dict[str, Any]: @@ -675,6 +677,16 @@ def __getitem__(self, idx: 
int) -> Dict[str, Any]: return {'turns': [sample]} return tokenize_formatted_example(sample, tokenizer=self.tokenizer) + def state_dict(self, num_samples: int, + from_beginning: bool) -> Dict[str, Any]: + if self.packing_ratio is not None: + num_samples = int(self.packing_ratio * num_samples) + + return super().state_dict( + num_samples=num_samples, + from_beginning=from_beginning, + ) + class DatasetConstructor: diff --git a/llmfoundry/eval/datasets/__init__.py b/llmfoundry/eval/datasets/__init__.py index 02a2b88b21..a3a36053da 100644 --- a/llmfoundry/eval/datasets/__init__.py +++ b/llmfoundry/eval/datasets/__init__.py @@ -22,6 +22,18 @@ tokenizer_needs_prefix_space, trim_context, ) +from llmfoundry.registry import icl_datasets + +icl_datasets.register( + 'multiple_choice', + func=InContextLearningMultipleChoiceTaskDataset, +) +icl_datasets.register('schema', func=InContextLearningSchemaTaskDataset) +icl_datasets.register('language_modeling', func=InContextLearningLMTaskDataset) +icl_datasets.register( + 'generation_task_with_answers', + func=InContextLearningGenerationTaskWithAnswersDataset, +) __all__ = [ 'InContextLearningDataset', diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py index debb0dbc6f..c87b38b09a 100644 --- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py +++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py @@ -19,6 +19,7 @@ from datasets import IterableDataset, load_dataset from torch.utils.data import DataLoader, Dataset +from llmfoundry import registry from llmfoundry.eval.datasets.utils import ( convert_tokens_to_tensors, get_continuation_span, @@ -29,6 +30,7 @@ tokenizer_needs_prefix_space, trim_context, ) +from llmfoundry.utils.registry_utils import construct_from_registry log = logging.getLogger(__name__) @@ -114,11 +116,11 @@ def __init__( max_seq_len: int, pad_tok_id: int, num_fewshot: int, - fewshot_random_seed: int, - prompt_string: str, - example_delimiter: str, - continuation_delimiter: str, destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', prelimiter: str = '', context_key: str = 'context', answer_key: str = 'answer', @@ -189,6 +191,20 @@ def __len__(self) -> int: def get_num_samples_in_batch(self, batch: Dict) -> int: return batch['input_ids'].shape[0] + def get_effective_batch_size(self, batch_size: int) -> int: + r"""Returns effective batch size computed for given ICL task. + + The effective batch size may not be equal to the configured evaluation + batch size because for certain ICL tasks, >1 prompts can get created + for every input query depending on the number of choices/continuations. + This requires the effective batch size to be reduced to prevent larger batches than expected during eval. For example, + check InContextLearningMultipleChoiceTaskDataset. + + Args: + batch_size (int): Original batch size configured for ICL evaluations + """ + return batch_size + def update_generation_kwargs(self, generation_kwargs: Dict) -> None: r"""Updates self.base_batch with the passed in generation_kwargs. @@ -519,46 +535,12 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id) return batch - def split_batch(self, batch: Any, - microbatch_size: Union[int, float]) -> Sequence[Any]: - """Handling for certain specialty columns that must be split into. 
- - batches in different formats. - - Args: - batch (Dict): Batch of data - microbatch_size (int | float): Size of microbatches - - Returns: - List: List of chunked batches - """ - # Don't split kwargs that don't change - # Normally split torch tensors - # List split lists of strings - if isinstance(microbatch_size, float): - raise ValueError( - 'split_batch does not support floating point microbatch_size.', - ) - chunked = {} - for k, v in batch.items(): - if k in self.static_keys: - # Defer broadcasting until we know num_chunks - pass - elif k in self.list_keys: - chunked[k] = _split_list(v, microbatch_size) - elif k in self.tensor_keys: - chunked[k] = _default_split_batch(v, microbatch_size) - else: - raise ValueError(f'Unexpected key {k} in batch splitting') - num_chunks = len(chunked['input_ids']) - for k, v in batch.items(): - if k in self.static_keys: - chunked[k] = [v] * num_chunks - - batched_list = [{k: v[idx] - for k, v in chunked.items()} - for idx in range(num_chunks)] - return batched_list + def split_batch( + self, + batch: Any, + microbatch_size: Union[int, float], + ) -> Sequence[Any]: + return _default_split_batch(batch, microbatch_size) class InContextLearningGenerationTaskWithAnswersDataset( @@ -584,13 +566,31 @@ class InContextLearningGenerationTaskWithAnswersDataset( def __init__( self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + context_key: str = 'context', + answer_key: str = 'answer', + strip_dataset: bool = True, + padding_size: Optional[int] = None, + base_batch: Optional[Dict] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, cot_delimiter: str = '', early_stopping_criteria: Optional[List[str]] = None, do_normalization: bool = True, - *args: Any, - **kwargs: Any, ): - if kwargs['tokenizer'].eos_token_id is None: + if tokenizer.eos_token_id is None: raise ValueError( '`InContextLearningGenerationTaskWithAnswersDataset` tokenizer must have non-null `eos_token_id`', ) @@ -607,13 +607,32 @@ def __init__( tensor_keys = ['input_ids', 'attention_mask'] list_keys = ['labels'] super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + context_key=context_key, + answer_key=answer_key, + strip_dataset=strip_dataset, + padding_size=padding_size, + base_batch=base_batch, + batch_mapping=batch_mapping, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + # specific to ICL dataset padding_side='left', tokenize_labels=False, static_keys=static_keys, list_keys=list_keys, tensor_keys=tensor_keys, - *args, - **kwargs, ) # NOTE: set these after init call because they take class vars self.early_stopping_criteria = early_stopping_criteria @@ -635,8 +654,8 @@ def __init__( 'input_ids': self.context_key, 'labels': 'aliases', } - if 'generation_kwargs' in kwargs: - self.update_generation_kwargs(kwargs['generation_kwargs']) + if 
generation_kwargs: + self.update_generation_kwargs(generation_kwargs) def read_dataset( self, @@ -765,6 +784,45 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: batch['generation_kwargs']['stopping_criteria'] = stopping_criteria return batch + def split_batch(self, batch: Any, + microbatch_size: Union[int, float]) -> Sequence[Any]: + """Split batch handling for special columns. + + Args: + batch (Dict): Batch of data + microbatch_size (int | float): Size of microbatches + + Returns: + List: List of chunked batches + """ + # Don't split kwargs that don't change + # Normally split torch tensors + # List split lists of strings + if isinstance(microbatch_size, float): + raise ValueError( + 'split_batch does not support floating point microbatch_size.', + ) + chunked = {} + for k, v in batch.items(): + if k in self.static_keys: + # Defer broadcasting until we know num_chunks + pass + elif k in self.list_keys: + chunked[k] = _split_list(v, microbatch_size) + elif k in self.tensor_keys: + chunked[k] = _default_split_batch(v, microbatch_size) + else: + raise ValueError(f'Unexpected key {k} in batch splitting') + num_chunks = len(chunked['input_ids']) + for k, v in batch.items(): + if k in self.static_keys: + chunked[k] = [v] * num_chunks + + batched_list = [{k: v[idx] + for k, v in chunked.items()} + for idx in range(num_chunks)] + return batched_list + class InContextLearningLMTaskDataset(InContextLearningDataset): """A dataset that constructs batches for in-context learning language. @@ -779,8 +837,50 @@ class InContextLearningLMTaskDataset(InContextLearningDataset): See InContextLearningDataset for more details. """ - def __init__(self, *args: Any, **kwargs: Any): + def __init__( + self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + context_key: str = 'context', + strip_dataset: bool = True, + tokenize_labels: bool = True, + padding_size: Optional[int] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + static_keys: Optional[List] = None, + list_keys: Optional[List] = None, + ): super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + context_key=context_key, + strip_dataset=strip_dataset, + tokenize_labels=tokenize_labels, + padding_size=padding_size, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + list_keys=list_keys, + # specific to ICL dataset answer_key='continuation', static_keys=['mode'], tensor_keys=[ @@ -800,8 +900,6 @@ def __init__(self, *args: Any, **kwargs: Any): 'labels': 'context', }, padding_side='right', - *args, - **kwargs, ) @@ -833,13 +931,33 @@ class InContextLearningMultipleChoiceTaskDataset(InContextLearningDataset): def __init__( self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + 
prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + context_key: str = 'query', + tensor_keys: Optional[List] = None, + answer_key: str = 'answer', + strip_dataset: bool = True, + tokenize_labels: bool = True, + padding_size: Optional[int] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + list_keys: Optional[List] = None, choices_key: str = 'choices', static_keys: Optional[List] = None, list_of_tensors_keys: Optional[List] = None, list_of_tuples_keys: Optional[List] = None, list_of_primitives: Optional[List] = None, - *args: Any, - **kwargs: Any, ): self.choices_key = choices_key base_batch = { @@ -850,25 +968,42 @@ def __init__( 'gold_indices': [], 'choice_groupings': [], } - context_key = kwargs.pop('context_key', 'query') - static_keys = kwargs.pop('static_keys', ['mode', 'generation_kwargs']) - tensor_keys = kwargs.pop( - 'tensor_keys', - ['input_ids', 'labels', 'attention_mask'], - ) + if not static_keys: + static_keys = ['mode', 'generation_kwargs'] + if not tensor_keys: + tensor_keys = ['input_ids', 'labels', 'attention_mask'] self.list_of_tensors_keys = list_of_tensors_keys or [ 'continuation_indices', ] self.list_of_tuples_keys = list_of_tuples_keys or ['choice_groupings'] self.list_of_primitives = list_of_primitives or ['gold_indices'] super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + answer_key=answer_key, + strip_dataset=strip_dataset, + tokenize_labels=tokenize_labels, + padding_size=padding_size, + batch_mapping=batch_mapping, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + list_keys=list_keys, + # specific to ICL dataset context_key=context_key, base_batch=base_batch, static_keys=static_keys, tensor_keys=tensor_keys, padding_side='right', - *args, - **kwargs, ) self.num_choices = len(self.dataset[0][self.choices_key]) self.batch_mapping_per_choice = { @@ -877,6 +1012,11 @@ def __init__( } self.batch_map_per_example = {'gold_indices': 'gold'} + def get_effective_batch_size(self, batch_size: int) -> int: + batch_size = max(self.num_choices, batch_size) + effective_batchsize = batch_size // self.num_choices + return effective_batchsize + def get_answer_from_example( self, example: Dict, @@ -1095,21 +1235,58 @@ class InContextLearningSchemaTaskDataset( def __init__( self, + dataset_uri: str, + tokenizer: transformers.PreTrainedTokenizerBase, + max_seq_len: int, + pad_tok_id: int, + num_fewshot: int, + destination_path: str, + fewshot_random_seed: int = 1234, + prompt_string: str = '', + example_delimiter: str = '\n', + continuation_delimiter: str = ' ', + prelimiter: str = '', + answer_key: str = 'answer', + strip_dataset: bool = True, + tokenize_labels: bool = True, + padding_size: Optional[int] = None, + batch_mapping: Optional[Dict] = None, + hf_loading_vars: Optional[Dict] = None, + hf_parsing_map: Optional[Dict] = None, + generation_kwargs: Optional[Dict] = None, + list_keys: Optional[List] = None, choices_key: str = 'context_options', - *args: Any, - **kwargs: Any, ): static_keys = ['mode'] tensor_keys = 
['input_ids', 'labels', 'attention_mask'] list_of_tensors_keys = ['continuation_indices'] super().__init__( + dataset_uri=dataset_uri, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + pad_tok_id=pad_tok_id, + num_fewshot=num_fewshot, + fewshot_random_seed=fewshot_random_seed, + prompt_string=prompt_string, + example_delimiter=example_delimiter, + continuation_delimiter=continuation_delimiter, + destination_path=destination_path, + prelimiter=prelimiter, + answer_key=answer_key, + strip_dataset=strip_dataset, + tokenize_labels=tokenize_labels, + padding_size=padding_size, + batch_mapping=batch_mapping, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, + generation_kwargs=generation_kwargs, + list_keys=list_keys, + # specific to ICL dataset choices_key=choices_key, context_key=choices_key, static_keys=static_keys, tensor_keys=tensor_keys, list_of_tensors_keys=list_of_tensors_keys, - *args, - **kwargs, ) self.base_batch = { 'input_ids': [], @@ -1120,6 +1297,11 @@ def __init__( 'choice_groupings': [], } + def get_effective_batch_size(self, batch_size: int) -> int: + batch_size = max(self.num_choices, batch_size) + effective_batchsize = batch_size // self.num_choices + return effective_batchsize + def construct_context( self, example: Dict[str, Any], @@ -1294,23 +1476,10 @@ def build_icl_dataloader( dataset_uri: str, tokenizer: transformers.PreTrainedTokenizerBase, batch_size: int, - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, # e.g. 'translate english to french:' - example_delimiter: str, # e.g. '\n' - continuation_delimiter: str, # e.g. '' hf_loading_vars: Dict, hf_parsing_map: Dict, - destination_path: str, - prelimiter: str, # e.g. 'Question: ' - cot_delimiter: str, # e.g. ' ### ' - fewshot_random_seed: int, - pass_at_k: int, - generations_per_sample: int, - generation_kwargs: Dict, - early_stopping_criteria: Optional[List[str]] = None, - do_normalization: bool = True, + destination_path: str = '', + kwargs: Optional[Dict[str, Any]] = None, ) -> DataSpec: """Factory method that builds the specific dataset for the specified. @@ -1323,108 +1492,36 @@ def build_icl_dataloader( this might be different) 3. 
set the `split_batch` function if necessary """ - if icl_task_type == 'multiple_choice': - dataset = InContextLearningMultipleChoiceTaskDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - ) - batch_size = max(dataset.num_choices, batch_size) - effective_batchsize = batch_size // dataset.num_choices - elif icl_task_type == 'schema': - dataset = InContextLearningSchemaTaskDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - ) - batch_size = max(dataset.num_choices, batch_size) - effective_batchsize = batch_size // dataset.num_choices - elif icl_task_type == 'language_modeling': - dataset = InContextLearningLMTaskDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - ) - effective_batchsize = batch_size - elif icl_task_type == 'generation_task_with_answers': - dataset = InContextLearningGenerationTaskWithAnswersDataset( - dataset_uri=dataset_uri, - tokenizer=tokenizer, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, - destination_path=destination_path, - prelimiter=prelimiter, - fewshot_random_seed=fewshot_random_seed, - hf_loading_vars=hf_loading_vars, - hf_parsing_map=hf_parsing_map, - cot_delimiter=cot_delimiter, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization, - generation_kwargs=generation_kwargs, - ) - effective_batchsize = batch_size - else: - raise Exception(f'Unrecognized ICL task type: {icl_task_type}') - + # Add named parameters to kwargs + if kwargs is None: + kwargs = {} + kwargs.update({ + 'dataset_uri': dataset_uri, + 'tokenizer': tokenizer, + 'hf_loading_vars': hf_loading_vars, + 'hf_parsing_map': hf_parsing_map, + 'destination_path': destination_path, + }) + dataset = construct_from_registry( + name=icl_task_type, + registry=registry.icl_datasets, + partial_function=False, + pre_validation_function=None, + post_validation_function=None, + kwargs=kwargs, + ) sampler = dist.get_sampler(dataset, drop_last=False, shuffle=False) - split_batch = None - if isinstance( - dataset, - ( - InContextLearningMultipleChoiceTaskDataset, - InContextLearningGenerationTaskWithAnswersDataset, - ), - ): - split_batch = dataset.split_batch - return DataSpec( DataLoader( dataset, - 
batch_size=effective_batchsize, + batch_size=dataset.get_effective_batch_size(batch_size), sampler=sampler, collate_fn=dataset.collate_fn, ), device_transforms=None, get_num_samples_in_batch=dataset.get_num_samples_in_batch, - split_batch=split_batch, + split_batch=dataset.split_batch, ) @@ -1514,24 +1611,11 @@ def get_icl_task_dataloader( tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast], batch_size: int, - max_seq_len: int, - pad_tok_id: int, - num_fewshot: int, - prompt_string: str, # e.g. 'translate english to french:' - example_delimiter: str, # e.g. '\n' - continuation_delimiter: str = '', - destination_path: str = '', - question_prelimiter: str = '', # e.g. 'Question: ' - fewshot_random_seed: int = 1234, - pass_at_k: int = 1, - generations_per_sample: int = 1, - cot_delimiter: str = '', has_categories: bool = False, hf_loading_vars: Optional[Dict] = None, hf_parsing_map: Optional[Dict] = None, - generation_kwargs: Optional[Dict] = None, - early_stopping_criteria: Optional[List[str]] = None, - do_normalization: bool = True, + destination_path: str = '', + kwargs: Optional[Dict[str, Any]] = None, ) -> Union[DataSpec, Dict[str, DataSpec]]: r"""Constructs a dataloader (or dataloaders if has_categories is True) @@ -1588,28 +1672,12 @@ def get_icl_task_dataloader( The default keys expected are "context" and "answer". tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to map between strings and token ids. batch_size (int): Size of a batch used for eval - max_seq_len (int): The maximum sequence length supported by the model. - pad_tok_id (int): The special token used for padding batches. - num_fewshot (int): The number of complete fewshot examples to prepend before each test example. These are not identical across examples. - prompt_string (str, default = ''): Prompt string to put once before all fewshot examples/test examples (e.g. 'Translate english to french.'). - example_delimiter (str, default = '\\n'): Separator inserted before (context, answer) pairs (e.g. '\\n') for fewshot sampling and prompting. - continuation_delimiter: (str, default = ' '): Separator inserted between context and answer in each example (e.g. '\\nA: '). - destination_path: (str, default = ''): This is the local file where remote datasets will be saved. - question_prelimiter: (str, default = ''): Text to be prepended before each context, including few shot examples (e.g. "Question: "). - fewshot_random_seed (int, default = 1234): Random seed to use for fewshot sampling - pass_at_k (int): k for how many chances the model gets to write passing code. - generations_per_sample (int): How many outputs to generate per prompt. Passed in generation_kwargs under "num_return_sequences" and overwritten by generation_kwargs dict. - cot_delimiter (str): Delimiter to place between chain of thoughts and continuations. has_categories: (bool): If ``True``, we will search the dataset file for a category key, and partition the dataset into a separate dataloader for each category occurring in the data. hf_loading_vars (Dict, default = None): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. hf_parsing_map (Dict, default = None): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset. 
- generation_kwargs (Dict, default = None): A dictionary containing keyword arguments to be passed along to the model's generate function. Overwrites any previously specified generation - keyword args in this function (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig - for more details) - early_stopping (List, default = None): A list of strings that, when found in a model's output, will be treated as a stopping criteria at metric computation time. - Used in generation tasks with CoT - do_normalization (bool, default = True): Whether or not to normalize the outputs and labels in InContextLearningGenerationTaskWithAnswersDataset. Only used in generation tasks. + kwargs (Dict[str, Any], default=None): Dictionary containing a mapping + from ICL dataset constructor's parameter names and their desired values. Returns: DataLoader: A dataloader used for performing in-context learning evaluation on the dataset provided. @@ -1618,11 +1686,6 @@ def get_icl_task_dataloader( hf_loading_vars = {} if hf_parsing_map is None: hf_parsing_map = {} - if generation_kwargs is None: - generation_kwargs = {} - if early_stopping_criteria is None: - early_stopping_criteria = [] - if has_categories: result_dls = {} output_files = partition_dataset_by_category( @@ -1639,23 +1702,10 @@ def get_icl_task_dataloader( dataset_uri=partition_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, - continuation_delimiter=continuation_delimiter, destination_path=partition_uri + '_tmp', - prelimiter=question_prelimiter, - cot_delimiter=cot_delimiter, - fewshot_random_seed=fewshot_random_seed, - pass_at_k=pass_at_k, - generations_per_sample=generations_per_sample, hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - generation_kwargs=generation_kwargs, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization, + kwargs=kwargs, ) return result_dls else: @@ -1664,21 +1714,8 @@ def get_icl_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=max_seq_len, - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter=example_delimiter, hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - continuation_delimiter=continuation_delimiter, destination_path=destination_path, - prelimiter=question_prelimiter, - cot_delimiter=cot_delimiter, - fewshot_random_seed=fewshot_random_seed, - pass_at_k=pass_at_k, - generations_per_sample=generations_per_sample, - generation_kwargs=generation_kwargs, - early_stopping_criteria=early_stopping_criteria, - do_normalization=do_normalization, + kwargs=kwargs, ) diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 0c8e64b759..f36f53fffa 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -8,6 +8,7 @@ from composer.optim import ComposerScheduler from torch.optim import Optimizer from torch.utils.data import DataLoader as TorchDataloader +from torch.utils.data import Dataset from torchmetrics import Metric from transformers import PreTrainedTokenizerBase @@ -206,6 +207,21 @@ description=_metrics_description, ) +_icl_datasets_description = ( + 'The ICL datasets registry is used to register a torch.utils.data.Dataset class which can be used for ICL tasks.'
+) +icl_datasets = create_registry( + 'llmfoundry', + 'icl_datasets', + # TODO: Change type from Dataset to + # llmfoundry.eval.InContextLearningDataset. + # Using ICL dataset here introduces a circular import dependency between + # the registry and eval packages right now, thus needs some refactoring. + generic_type=Type[Dataset], + entry_points=True, + description=_icl_datasets_description, +) + __all__ = [ 'loggers', 'callbacks', @@ -228,4 +244,5 @@ 'attention_classes', 'attention_implementations', 'fcs', + 'icl_datasets', ] diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 73eb026d98..ada553c52f 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +import copy import functools import logging import os @@ -545,22 +546,10 @@ def _validate_cfg(icl_cfg: Dict[str, Any]): f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg["icl_task_type"]}.', ) - if 'prompt_string' not in icl_cfg: - icl_cfg['prompt_string'] = '' - if 'example_delimiter' not in icl_cfg: - icl_cfg['example_delimiter'] = '\n' - if 'continuation_delimiter' not in icl_cfg: - icl_cfg['continuation_delimiter'] = ' ' if 'max_seq_len' not in icl_cfg: icl_cfg['max_seq_len'] = default_max_seq_len if 'batch_size' not in icl_cfg: icl_cfg['batch_size'] = default_batch_size - if 'pass_at_k' not in icl_cfg: - icl_cfg['pass_at_k'] = 1 - if 'fewshot_random_seed' not in icl_cfg: - icl_cfg['fewshot_random_seed'] = 1234 - if 'generations_per_sample' not in icl_cfg: - icl_cfg['generations_per_sample'] = 1 if 'num_beams' in icl_cfg: raise ValueError( @@ -579,6 +568,7 @@ def _validate_cfg(icl_cfg: Dict[str, Any]): pad_tok_id = tokenizer.eos_token_id else: pad_tok_id = tokenizer.pad_token_id + label = f'{icl_cfg["label"]}/{num_fewshot}-shot' metric_names = list(icl_cfg['metric_names']) # TODO: fix Composer bug when copying local paths and destination exists @@ -589,38 +579,51 @@ def _validate_cfg(icl_cfg: Dict[str, Any]): hf_parsing_map = icl_cfg.get('hf_parsing_map', {}) hf_loading_vars = icl_cfg.get('hf_loading_vars', {}) - early_stopping_criteria = icl_cfg.get( 'early_stopping_criteria', - None, + [], ) + # TODO: fix manual removal of non-constructor fields + icl_constructor_kwargs = copy.deepcopy(icl_cfg) + icl_constructor_kwargs.pop('label', None) + icl_constructor_kwargs.pop('metric_names', None) + icl_constructor_kwargs.pop('icl_task_type', None) + icl_constructor_kwargs.pop('batch_size', None) + icl_constructor_kwargs.pop('has_categories', None) + + # Add custom constructor arguments + icl_constructor_kwargs['pad_tok_id'] = pad_tok_id + icl_constructor_kwargs['num_fewshot'] = num_fewshot + + # Support backwards compatibility for the naming of "prelimiter" as "question_prelimiter" + if 'question_prelimiter' in icl_constructor_kwargs: + if 'prelimiter' in icl_constructor_kwargs: + raise ValueError( + 'Both "question_prelimiter" and "prelimiter" are specified in the ICL task config. 
' + + + 'Please only specify one of them, as they map to the same argument.', + ) + else: + icl_constructor_kwargs['prelimiter' + ] = icl_constructor_kwargs.pop( + 'question_prelimiter', + ) + assert early_stopping_criteria is None or isinstance( early_stopping_criteria, list, ) + dataloaders = get_icl_task_dataloader( - icl_cfg['icl_task_type'], - icl_cfg['dataset_uri'], - tokenizer, + icl_task_type=icl_cfg['icl_task_type'], + dataset_uri=icl_cfg['dataset_uri'], + tokenizer=tokenizer, batch_size=icl_cfg['batch_size'], - max_seq_len=icl_cfg['max_seq_len'], - pad_tok_id=pad_tok_id, - num_fewshot=num_fewshot, - prompt_string=icl_cfg['prompt_string'], - example_delimiter=icl_cfg['example_delimiter'], hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - continuation_delimiter=icl_cfg['continuation_delimiter'], - question_prelimiter=icl_cfg.get('question_prelimiter', ''), - destination_path=destination_path, - fewshot_random_seed=icl_cfg['fewshot_random_seed'], - pass_at_k=icl_cfg['pass_at_k'], - generations_per_sample=icl_cfg['generations_per_sample'], has_categories=icl_cfg.get('has_categories', False), - cot_delimiter=icl_cfg.get('cot_delimiter', ''), - generation_kwargs=icl_cfg.get('generation_kwargs', {}), - early_stopping_criteria=early_stopping_criteria, - do_normalization=icl_cfg.get('do_normalization', True), + destination_path=destination_path, + kwargs=icl_constructor_kwargs, ) if 'has_categories' in icl_cfg and icl_cfg[ 'has_categories'] and isinstance(dataloaders, dict): diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 5c1ec9114a..ef894b562e 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -115,7 +115,7 @@ class TrainConfig: # Cuda allocation configuration max_split_size_mb: Optional[int] = None - expandable_segments: bool = False + expandable_segments: bool = True cuda_load_lazy: bool = False # Distributed training parameters @@ -593,8 +593,28 @@ def _process_data_source( # Check for HF path elif 'hf_name' in dataset and dataset['hf_name']: hf_path = dataset['hf_name'] - backend, _, _ = parse_uri(hf_path) - if backend: + backend, _, uc_path = parse_uri(hf_path) + unsupported_file = True + if backend == 'dbfs': + assert cfg_split + from llmfoundry.data.finetuning.tasks import SUPPORTED_EXTENSIONS + possible_files = [ + f'{cfg_split}{ext}' for ext in SUPPORTED_EXTENSIONS + ] + for file in possible_files: + path = os.path.join(uc_path, file) + # Ensure path starts with '/' + if not path.startswith('/'): + path = '/' + path + if _verify_uc_path(path): + data_paths.append(('uc_volume', path, true_split)) + unsupported_file = False + break + if unsupported_file: + log.warning( + f'{hf_path} does not contain a supported file extension.', + ) + elif backend: hf_path = os.path.join(hf_path, cfg_split) if cfg_split else hf_path data_paths.append((backend, hf_path, true_split)) elif os.path.exists(hf_path): @@ -665,3 +685,52 @@ def log_dataset_uri(cfg: Dict[str, Any]) -> None: mlflow.log_input( mlflow.data.meta_dataset.MetaDataset(source, name=split), ) + + +def _verify_uc_path(path: str) -> bool: + """Verify a UC path exists. 
+ + Args: + path (str): Unity Catalog path + Returns: + (bool): Whether the path exists + """ + from databricks.sdk.errors.platform import NotFound, PermissionDenied + w = None + try: + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + except ImportError: + log.warning( + 'Cannot verify the path of `UCVolumeDatasetSource` because of missing' + \ + '`databricks-sdk`. Please install `databricks-sdk` via ' + \ + '`pip install -U databricks-sdk`. This does not block creating ' + \ + '`UCVolumeDatasetSource`, but your `UCVolumeDatasetSource` might be invalid.', + ) + return False + except Exception as e: + log.warning( + f'Error occurred when attempting to connect with Databricks WorkspaceClient. ' + \ + f'Error details: {str(e)}. This does not block creating `UCVolumeDatasetSource`, ' + \ + f'but your `UCVolumeDatasetSource` might be invalid.', + ) + + if w: + try: + w.files.get_metadata(path) + except (NotFound, PermissionDenied): + try: + # Check if `path` points to a valid UC directory. + w.files.get_directory_metadata(path) + return True + except (NotFound, PermissionDenied): + # Neither file nor directory exists, so the path is invalid. + return False + except Exception as e: + log.warning( + f'Error occurred when verifying path of `UCVolumeDatasetSource`. ' + \ + f'Error details: {str(e)}. This does not block creating `UCVolumeDatasetSource`, ' + \ + f'but your `UCVolumeDatasetSource` might be invalid.', + ) + return False diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index d871761803..f63f1b0027 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -463,6 +463,7 @@ def validate_and_get_cluster_info( if res is None: raise ClusterDoesNotExistError(cluster_id) + assert res.spark_version is not None stripped_runtime = re.sub( r'[a-zA-Z]', '', diff --git a/setup.py b/setup.py index f81b1cd0f1..a836aec27f 100644 --- a/setup.py +++ b/setup.py @@ -54,8 +54,8 @@ ] install_requires = [ - 'mosaicml[libcloud,wandb,oci,gcs]>=0.23.2,<0.24', - 'mlflow>=2.12.1,<2.13', + 'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.23.2,<0.24', + 'mlflow>=2.13.2,<2.14', 'accelerate>=0.25,<0.26', # for HF inference `device_map` 'transformers>=4.40,<4.41', 'mosaicml-streaming>=0.7.6,<0.8', diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 0577e13a1f..1b2f791995 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import contextlib import json import math import os @@ -468,6 +469,7 @@ def _get_model_and_tokenizer( model: str, max_seq_len: int, tie_word_embeddings: bool, + precision: str, ): if model == 'mpt': model_cfg = { @@ -482,6 +484,7 @@ def _get_model_and_tokenizer( 'attn_config': { 'attn_impl': 'torch', }, + 'fc_type': 'te' if precision == 'amp_fp8' else 'torch', 'loss_fn': 'torch_crossentropy', 'tie_word_embeddings': tie_word_embeddings, } @@ -783,8 +786,9 @@ def _assert_checkpoint_equivalence( ) @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None]) @pytest.mark.parametrize( -
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints,trainer_precision', + [('1ba', '1ba', '1ba', 1, 1, 'amp_bf16'), + ('1ba', '1ba', '1ba', 1, 1, 'amp_fp8')], ) @patch('os.cpu_count', MagicMock(return_value=1)) @patch( @@ -801,10 +805,30 @@ def test_huggingface_conversion_callback( max_duration: str, expected_hf_checkpoints: int, expected_normal_checkpoints: int, + trainer_precision: str, peft_config: Optional[dict], ): if model == 'mptmoe' and fsdp_state_dict_type is None: pytest.skip('mptmoe requires FSDP') + if trainer_precision == 'amp_fp8': + # Check if transformer-engine is installed for FP8. + try: + import transformer_engine.pytorch as te + except ImportError: + pytest.skip( + 'Precision amp_fp8 requires transformer-engine to be installed', + ) + + # Check we are using mpt models only for FP8. + if (model == 'neo' or model == 'llama2'): + pytest.skip( + 'Precision amp_fp8 works only for mpt models, not hf models', + ) + + # Check that we are using H100 or later for FP8. + if not (torch.cuda.get_device_capability() >= (8, 9)): + pytest.skip('Amp FP8 requires a GPU with compute capability >= 8.9') + delete_transformers_cache() dist.initialize_dist(get_device('gpu')) @@ -825,9 +849,10 @@ def test_huggingface_conversion_callback( # Get small version of each model model_cfg, tokenizer_name = _get_model_and_tokenizer( - model, - max_seq_len, - tie_word_embeddings, + model=model, + max_seq_len=max_seq_len, + tie_word_embeddings=tie_word_embeddings, + precision=trainer_precision, ) assert model_cfg is not None assert tokenizer_name is not None @@ -883,7 +908,7 @@ def test_huggingface_conversion_callback( trainer = Trainer( model=original_model, device='gpu', - precision='amp_bf16', + precision=trainer_precision, fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None, train_dataloader=train_dataloader, save_folder=os.path.join(tmp_path, 'checkpoints'), @@ -900,24 +925,29 @@ def test_huggingface_conversion_callback( # summon full params to check equivalence from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - with FSDP.summon_full_params( - trainer.state.model, - writeback=False, - recurse=True, - ): - _assert_checkpoint_equivalence( - tmp_path=tmp_path, - expected_normal_checkpoints=expected_normal_checkpoints, - expected_hf_checkpoints=expected_hf_checkpoints, - trainer=trainer, - batches_per_epoch=batches_per_epoch, - original_model=original_model, - precision=precision, - model=model, - tokenizer=tokenizer, - fsdp_state_dict_type=fsdp_state_dict_type, - peft_config=peft_config, - ) + + context_manager = te.onnx_export( # type: ignore + True, + ) if trainer_precision == 'amp_fp8' else contextlib.nullcontext() + with context_manager: + with FSDP.summon_full_params( + trainer.state.model, + writeback=False, + recurse=True, + ): + _assert_checkpoint_equivalence( + tmp_path=tmp_path, + expected_normal_checkpoints=expected_normal_checkpoints, + expected_hf_checkpoints=expected_hf_checkpoints, + trainer=trainer, + batches_per_epoch=batches_per_epoch, + original_model=original_model, + precision=precision, + model=model, + tokenizer=tokenizer, + fsdp_state_dict_type=fsdp_state_dict_type, + peft_config=peft_config, + ) dist.barrier() delete_transformers_cache() diff --git a/tests/callbacks/test_system_metrics_monitor.py b/tests/callbacks/test_system_metrics_monitor.py new file mode 100644 index 0000000000..47095604eb --- /dev/null +++ b/tests/callbacks/test_system_metrics_monitor.py @@ -0,0 +1,15 @@ +# Copyright 2024 
MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from composer.callbacks import SystemMetricsMonitor + +from llmfoundry.utils.builders import build_callback + + +def test_system_metrics_monitor_callback_builds(): + callback = build_callback( + 'system_metrics_monitor', + kwargs={}, + train_config={'train_loader': {}}, + ) + assert isinstance(callback, SystemMetricsMonitor) diff --git a/tests/data/test_packing.py b/tests/data/test_packing.py index b910b8c5ff..d181dbde0b 100644 --- a/tests/data/test_packing.py +++ b/tests/data/test_packing.py @@ -14,6 +14,7 @@ from torch.utils.data import DataLoader from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader +from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio from llmfoundry.utils.builders import build_tokenizer @@ -206,6 +207,15 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path): if batch_ix >= 3: break + assert isinstance(loader, DataLoader) + assert isinstance(loader.dataset, StreamingFinetuningDataset) + assert loader.dataset.packing_ratio is not None + assert isinstance(loader.batch_size, int) + assert loader.dataset.packing_ratio == int(loader.batch_size / 6) + + state_dict = loader.dataset.state_dict(num_samples=2, from_beginning=False) + assert state_dict['sample_in_epoch'] == 2 * loader.dataset.packing_ratio + @pytest.mark.parametrize('packing_ratio', ['auto', 2.0]) @patch( diff --git a/tests/eval/test_in_context_learning_datasets.py b/tests/eval/test_in_context_learning_datasets.py index a3c3e88364..81769a18e6 100644 --- a/tests/eval/test_in_context_learning_datasets.py +++ b/tests/eval/test_in_context_learning_datasets.py @@ -37,6 +37,7 @@ InContextLearningLMAccuracy, InContextLearningMultipleChoiceAccuracy, ) +from llmfoundry.utils.builders import build_icl_evaluators def test_strip_data(): @@ -1090,15 +1091,22 @@ def test_mc_task_dataloader_subcategories( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=2, - prompt_string= - 'The following are multiple choice questions (with answers).\n', - example_delimiter='\n', - continuation_delimiter='Answer: ', - destination_path=str(tmp_path / 'icl.jsonl'), has_categories=True, + destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'num_fewshot': + 2, + 'max_seq_len': + seqlen, + 'pad_tok_id': + tokenizer.eos_token_id, + 'prompt_string': + 'The following are multiple choice questions (with answers).\n', + 'example_delimiter': + '\n', + 'continuation_delimiter': + 'Answer: ', + }, ) assert isinstance(dls, dict) @@ -1142,13 +1150,15 @@ def test_lm_task_dataloader_extra_space( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=10, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 10, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ' ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1192,13 +1202,15 @@ def test_lm_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=0, - prompt_string='', - 
example_delimiter='\n', - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 0, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': '', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1241,14 +1253,16 @@ def test_schema_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - example_delimiter='\n', - question_prelimiter=prelimiter, - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 1, + 'prompt_string': '', + 'example_delimiter': '\n', + 'prelimiter': prelimiter, + 'continuation_delimiter': '', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) @@ -1300,13 +1314,15 @@ def test_schema_task_dataloader_sentpiece_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=' ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 1, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ' ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) @@ -1358,13 +1374,15 @@ def test_lm_task_dataloader_opt_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': '', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1410,13 +1428,15 @@ def test_mc_task_dataloader_opt_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1473,13 +1493,15 @@ def test_mc_split_batch( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1550,13 +1572,15 @@ def test_qa_split_batch( 
dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=8, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=0, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 0, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dl, DataSpec) # pyright @@ -1612,14 +1636,16 @@ def test_qa_task_dataloader_w_null_eos( dataset_uri, tokenizer, batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - question_prelimiter='Q: ', - continuation_delimiter='\nA:', destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': prompt_string, + 'example_delimiter': '\n', + 'prelimiter': 'Q: ', + 'continuation_delimiter': '\nA:', + }, ) @@ -1647,14 +1673,16 @@ def test_qa_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string=prompt_string, - example_delimiter='\n', - question_prelimiter='Q: ', - continuation_delimiter='\nA:', destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': prompt_string, + 'example_delimiter': '\n', + 'prelimiter': 'Q: ', + 'continuation_delimiter': '\nA:', + }, ) assert isinstance(dl, DataSpec) @@ -1714,15 +1742,17 @@ def test_qa_task_with_cot_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - question_prelimiter='Q: ', - continuation_delimiter="\nA: Let's think step by step. ", - cot_delimiter=' #### ', destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'prelimiter': 'Q: ', + 'continuation_delimiter': "\nA: Let's think step by step. 
", + 'cot_delimiter': ' #### ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1779,14 +1809,16 @@ def test_mc_task_dataloader( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=seqlen, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=1, - prompt_string='', - question_prelimiter=prelimiter, - example_delimiter=example_delimiter, - continuation_delimiter='\nA: ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': seqlen, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 1, + 'prompt_string': '', + 'prelimiter': prelimiter, + 'example_delimiter': example_delimiter, + 'continuation_delimiter': '\nA: ', + }, ) assert isinstance(dl, DataSpec) assert isinstance(dl.dataloader, DataLoader) # pyright @@ -1851,13 +1883,15 @@ def test_lm_task_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter='', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': '', + }, ) evaluator = Evaluator( @@ -1903,13 +1937,15 @@ def test_schema_task_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(tmp_path / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) evaluator = Evaluator( @@ -1968,14 +2004,16 @@ def test_mc_task_evaluation_subcategories( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=max_seq_len, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), has_categories=True, + kwargs={ + 'max_seq_len': max_seq_len, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) assert isinstance(dls, dict) @@ -2039,13 +2077,15 @@ def test_mc_task_evaluation( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=64, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 64, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': num_fewshot, + 'prompt_string': '', + 'example_delimiter': '\n', + 'continuation_delimiter': ': ', + }, ) evaluator = Evaluator( @@ -2107,13 +2147,15 @@ def test_qa_task_evaluation_opt_tokenizer( dataset_uri=dataset_uri, tokenizer=tokenizer, batch_size=batch_size, - max_seq_len=1024, - pad_tok_id=tokenizer.eos_token_id, - num_fewshot=num_fewshot, - prompt_string='', - example_delimiter='\n', - continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), + kwargs={ + 'max_seq_len': 1024, + 'pad_tok_id': tokenizer.eos_token_id, + 'num_fewshot': 
@@ -2107,13 +2147,15 @@ def test_qa_task_evaluation_opt_tokenizer(
         dataset_uri=dataset_uri,
         tokenizer=tokenizer,
         batch_size=batch_size,
-        max_seq_len=1024,
-        pad_tok_id=tokenizer.eos_token_id,
-        num_fewshot=num_fewshot,
-        prompt_string='',
-        example_delimiter='\n',
-        continuation_delimiter=': ',
         destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+        kwargs={
+            'max_seq_len': 1024,
+            'pad_tok_id': tokenizer.eos_token_id,
+            'num_fewshot': num_fewshot,
+            'prompt_string': '',
+            'example_delimiter': '\n',
+            'continuation_delimiter': ': ',
+        },
     )

     evaluator = Evaluator(
@@ -2168,14 +2210,16 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer(
         dataset_uri=dataset_uri,
         tokenizer=tokenizer,
         batch_size=batch_size,
-        max_seq_len=1024,
-        pad_tok_id=tokenizer.eos_token_id,
-        num_fewshot=num_fewshot,
-        prompt_string='',
-        example_delimiter='\n',
-        continuation_delimiter="A: Let's think step by step. ",
-        cot_delimiter=' #### ',
         destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+        kwargs={
+            'max_seq_len': 1024,
+            'pad_tok_id': tokenizer.eos_token_id,
+            'num_fewshot': num_fewshot,
+            'prompt_string': '',
+            'example_delimiter': '\n',
+            'continuation_delimiter': "A: Let's think step by step. ",
+            'cot_delimiter': ' #### ',
+        },
     )

     evaluator = Evaluator(
@@ -2228,13 +2272,15 @@ def test_qa_task_evaluation(
         dataset_uri=dataset_uri,
         tokenizer=tokenizer,
         batch_size=batch_size,
-        max_seq_len=1024,
-        pad_tok_id=tokenizer.eos_token_id,
-        num_fewshot=num_fewshot,
-        prompt_string='',
-        example_delimiter='\n',
-        continuation_delimiter=': ',
         destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+        kwargs={
+            'max_seq_len': 1024,
+            'pad_tok_id': tokenizer.eos_token_id,
+            'num_fewshot': num_fewshot,
+            'prompt_string': '',
+            'example_delimiter': '\n',
+            'continuation_delimiter': ': ',
+        },
     )

     evaluator = Evaluator(
@@ -2288,14 +2334,16 @@ def test_qa_task_with_cot_evaluation(
         dataset_uri=dataset_uri,
         tokenizer=tokenizer,
         batch_size=batch_size,
-        max_seq_len=1024,
-        pad_tok_id=tokenizer.eos_token_id,
-        num_fewshot=num_fewshot,
-        prompt_string='',
-        example_delimiter='\n',
-        continuation_delimiter="A: Let's think step by step",
-        cot_delimiter=' #### ',
         destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+        kwargs={
+            'max_seq_len': 1024,
+            'pad_tok_id': tokenizer.eos_token_id,
+            'num_fewshot': num_fewshot,
+            'prompt_string': '',
+            'example_delimiter': '\n',
+            'continuation_delimiter': "A: Let's think step by step",
+            'cot_delimiter': ' #### ',
+        },
     )

     evaluator = Evaluator(
@@ -2339,13 +2387,15 @@ def test_lm_spacing_dataloader(
         dataset_uri=dataset_uri,
         tokenizer=tokenizer,
         batch_size=batch_size,
-        max_seq_len=seqlen,
-        pad_tok_id=tokenizer.eos_token_id,
-        num_fewshot=1,
-        prompt_string='',
-        example_delimiter='\n',
-        continuation_delimiter=' UNIQUE ',
         destination_path=str(tmp_path / 'icl.jsonl'),
+        kwargs={
+            'max_seq_len': seqlen,
+            'pad_tok_id': tokenizer.eos_token_id,
+            'num_fewshot': 1,
+            'prompt_string': '',
+            'example_delimiter': '\n',
+            'continuation_delimiter': ' UNIQUE ',
+        },
     )

     assert isinstance(dl, DataSpec)
     assert isinstance(dl.dataloader, DataLoader)  # pyright
@@ -2409,15 +2459,17 @@ def test_hf_dataloading_lm_dataloader(
         dataset_uri=dataset_uri,
         tokenizer=tokenizer,
         batch_size=batch_size,
-        max_seq_len=seqlen,
-        pad_tok_id=tokenizer.eos_token_id,
-        num_fewshot=0,
-        prompt_string='',
-        example_delimiter='\n',
-        continuation_delimiter=' ',
         destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
         hf_loading_vars=hf_loading_vars,
         hf_parsing_map=hf_parsing_map,
+        kwargs={
+            'max_seq_len': seqlen,
+            'pad_tok_id': tokenizer.eos_token_id,
+            'num_fewshot': 0,
+            'prompt_string': '',
+            'example_delimiter': '\n',
+            'continuation_delimiter': ' ',
+        },
    )

     assert isinstance(dl, DataSpec)
     assert isinstance(dl.dataloader, DataLoader)  # pyright
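Inside the new kwargs dict the key is always 'prelimiter'; the legacy 'question_prelimiter' spelling remains accepted at the icl_tasks config level, as the test_bc_question_prelimiter test added below verifies. An illustrative config entry under that assumption, mirroring the test fixture (the dataset path is illustrative):

# Either spelling should resolve to the dataset's `prelimiter` attribute.
icl_tasks = [
    {
        'dataset_uri': 'local_data/piqa_small.jsonl',
        'label': 'piqa',
        'icl_task_type': 'multiple_choice',
        'num_fewshot': [0],
        'continuation_delimiter': ': ',
        'question_prelimiter': 'This is a question: ',  # legacy key, still honored
    },
]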
@@ -2490,16 +2542,18 @@ def test_hf_dataloading_custom_parsing(
         dataset_uri=dataset_uri,
         tokenizer=tokenizer,
         batch_size=batch_size,
-        max_seq_len=seqlen,
-        pad_tok_id=tokenizer.eos_token_id,
-        num_fewshot=num_fewshot,
-        prompt_string=prompt_string,
-        example_delimiter='\n',
-        question_prelimiter='Orbs: ',
-        continuation_delimiter='\nSpell:',
         destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
         hf_loading_vars=hf_loading_vars,
         hf_parsing_map=hf_parsing_map,
+        kwargs={
+            'max_seq_len': seqlen,
+            'pad_tok_id': tokenizer.eos_token_id,
+            'num_fewshot': num_fewshot,
+            'prompt_string': prompt_string,
+            'example_delimiter': '\n',
+            'prelimiter': 'Orbs: ',
+            'continuation_delimiter': '\nSpell:',
+        },
     )
     assert isinstance(dl, DataSpec)
     assert isinstance(dl.dataloader, DataLoader)  # pyright
@@ -2535,3 +2589,42 @@ def test_hf_dataloading_custom_parsing(
     )
     assert decoded_batch[0].endswith('Orbs: quas wex exort\nSpell:')
     assert decoded_batch[1].endswith('Orbs: quas quas quas\nSpell:')
+
+
+@pytest.mark.parametrize(
+    'prelimiter_key_name',
+    ['prelimiter', 'question_prelimiter'],
+)
+def test_bc_question_prelimiter(
+    mpt_tokenizer: transformers.PreTrainedTokenizerBase,
+    prelimiter_key_name: str,
+):
+    local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+
+    dataset_uri = f'{local_data}/piqa_small.jsonl'
+
+    icl_tasks = [
+        {
+            'dataset_uri': dataset_uri,
+            'label': 'piqa',
+            'icl_task_type': 'multiple_choice',
+            'max_seq_len': 64,
+            'pad_tok_id': mpt_tokenizer.eos_token_id,
+            'num_fewshot': [0],
+            'prompt_string': '',
+            'example_delimiter': '\n',
+            'continuation_delimiter': ': ',
+            prelimiter_key_name: 'This is a question: ',
+        },
+    ]
+
+    evaluators, _ = build_icl_evaluators(
+        icl_tasks=icl_tasks,
+        tokenizer=mpt_tokenizer,
+        default_batch_size=2,
+        default_max_seq_len=128,
+    )
+
+    assert len(evaluators) == 1
+    evaluator = evaluators[0]
+    assert evaluator.dataloader.dataloader.dataset.prelimiter == 'This is a question: '  # type: ignore
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 87881450d4..3bdf5a800f 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -42,6 +42,7 @@ def test_expected_registries_exist():
         'attention_classes',
         'attention_implementations',
         'fcs',
+        'icl_datasets',
     }

     assert existing_registries == expected_registry_names
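The tests/test_registry.py change registers a new 'icl_datasets' registry alongside the existing ones. A sketch of how a user-defined ICL dataset might be exposed through it, following the register(name, func=...) pattern this diff already uses for callbacks; the class name and task label are hypothetical, and InContextLearningDataset as the base class is an assumption about llmfoundry's eval datasets module:

# Hypothetical registration of a custom ICL dataset via the new registry.
from llmfoundry.eval.datasets import InContextLearningDataset  # assumed base class
from llmfoundry.registry import icl_datasets


class MyCustomICLDataset(InContextLearningDataset):
    """Toy dataset type; override construction/tokenization hooks as needed."""


# Once registered, the task type name can be referenced from icl_tasks configs.
icl_datasets.register('my_custom_task', func=MyCustomICLDataset)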