diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index 5f9417d5..a9ea7f69 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -33,11 +33,7 @@
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import (
-    AutoModelForCausalLM,
-    DataCollatorForLanguageModeling,
-    default_data_collator,
-)
+from transformers import AutoModelForCausalLM, default_data_collator
 from transformers.models.auto.tokenization_auto import AutoTokenizer
 from transformers.optimization import get_linear_schedule_with_warmup
 import numpy as np
@@ -901,12 +897,8 @@ def _get_collate_fn(tokenizer: AutoTokenizer, task_type: str) -> Callable:
             Callable collate_fn to be used for processing batches from our datasets.
         """
-        if task_type == "CAUSAL_LM":
-            return DataCollatorForLanguageModeling(
-                tokenizer=tokenizer,
-                return_tensors="pt",
-                mlm=False,
-            )
+        # HACK: Do NOT use the causal LM collator (for now) because we want to
+        # set the labels ourselves. TODO: centralize collator management.
         return default_data_collator

     @staticmethod
@@ -947,15 +939,11 @@ def _get_data_loaders_from_stream(
             torch.utils.data.DataLoader
                 DataLoader to be used for training / evaluating the stream data.
         """
-        (
-            tokenize_function,
-            requires_unwrapping,
-        ) = base_model.build_task_tokenize_closure(
+        (tokenize_function, _,) = base_model.build_task_tokenize_closure(
             tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0
         )
         mapped_stream = train_stream.map(tokenize_function)
-        if requires_unwrapping:
-            mapped_stream = mapped_stream.flatten()
+        # TODO: Deprecate and remove stream wrapper & use trainer
         wrapped_stream = SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle)
         dataloader = DataLoader(
             wrapped_stream, collate_fn=collate_fn, batch_size=batch_size
diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py
index e9bfcdb1..df612683 100644
--- a/caikit_nlp/modules/text_generation/text_generation_local.py
+++ b/caikit_nlp/modules/text_generation/text_generation_local.py
@@ -593,7 +593,7 @@ def _preprocess_function(
         mapped_dataset = dataset.map(
             base_model.tokenize_function,
             fn_kwargs=fn_kwargs,
-            batched=base_model.REQUIRES_TOKEN_UNWRAPPING,
+            batched=False,
             # Drop the input / output columns; we need to do this for dimensions to play
             # happily when operating on batched inputs for causal language modeling.
             remove_columns=["input", "output"],
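The collator change above works because the causal LM tokenization (added below in hf_auto_causal_lm.py) now builds the labels itself, and the docstrings note that DataCollatorForLanguageModeling would clobber them with a shifted copy of the inputs. transformers.default_data_collator, by contrast, simply stacks whatever keys the features already carry. A minimal sketch with toy token IDs (no real tokenizer or caikit code involved):

    # Illustrative sketch only: default_data_collator stacks pre-built features
    # as-is, so labels set by the tokenize function survive collation untouched.
    from transformers import default_data_collator

    features = [
        {"input_ids": [5, 6, 7, 0], "attention_mask": [1, 1, 1, 0], "labels": [-100, 6, 7, -100]},
        {"input_ids": [8, 9, 1, 0], "attention_mask": [1, 1, 1, 0], "labels": [-100, 9, 1, -100]},
    ]
    batch = default_data_collator(features)  # dict of stacked torch tensors
    print(batch["labels"].shape)             # torch.Size([2, 4]); values are exactly what we set
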
diff --git a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
index 897cdb31..fc09734e 100644
--- a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
+++ b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
@@ -16,8 +16,7 @@
 """
 # Standard
 from collections.abc import Mapping
-from copy import copy
-from typing import Callable, List, Tuple, Union
+from typing import List, Union

 # Third Party
 from transformers import (
@@ -26,6 +25,7 @@
     DataCollatorForLanguageModeling,
 )
 from transformers.models.auto import modeling_auto
+import torch

 # First Party
 from caikit.core.data_model import DataStream
@@ -66,7 +66,10 @@ def tokenize_function(
         max_target_length: int,
         verbalizer: Union[None, str] = None,
         task_ids: Union[None, int] = None,
-    ) -> DataStream[BatchEncoding]:
+        use_seq2seq_approach: bool = True,
+        chunk_size: int = 128,
+        drop_remainder: bool = False,
+    ) -> Union[DataStream[BatchEncoding], BatchEncoding]:
         """Tokenization function to be used for causallm training; this function
         consumes a GenerationTrainRecord object and applies the verbalizer to it
         followed by the model tokenizer. Due to the nature of our training data with src/target seqs,
@@ -86,11 +89,23 @@
                 Verbalizer to be rendered into each text.
             task_ids: Union[None, int]
                 Task IDs to be used for multiprompt tuning.
+            use_seq2seq_approach: bool
+                Indicates whether we should frame the example as a seq2seq style
+                problem or use the chunking approach driven by the parameters below.
+            chunk_size: int
+                unsigned int value to be used for chunk size.
+                Only used if use_seq2seq_approach=False.
+            drop_remainder: bool
+                Whether or not to keep the residual as an extra chunk if the
+                total number of tokens is not divisible by the chunk size.
+                Only used if use_seq2seq_approach=False.
         Returns:
-            DataStream[transformers.tokenization_utils_base.BatchEncoding]
-                stream of encoded tokenization output corresponding to the input example.
+            Union[DataStream[BatchEncoding], BatchEncoding]
+                stream of encoded tokenization output corresponding to the input example
+                or a single batch encoding object containing 1+ tokenized results.
         """
+        ### Things common to all Causal LM tokenization approaches
         # Extract the source & target from our provided inputs
         source, target = cls.decompose_example_io(example)
         # Determine if our mapped inputs are in batched mode or not
@@ -104,50 +119,31 @@
         source = (
             source if verbalizer is None else render_verbalizer(verbalizer, example)
         )
-
-        source_ids = tokenizer(source, max_length=max_source_length, truncation=True)
-        target_ids = tokenizer(target, max_length=max_target_length, truncation=True)
-
-        # Force everything to a list of batch encodings; for non-batch mode, this just
-        # puts it into a list. For batch mode, we get a list of batch encodings,
-        # allowing us to standardize subsequent processing a bit.
-        source_ids, num_target_samples = cls._force_to_batch_encoding_list(
-            source_ids, target_ids, batched_mode, task_ids
+        # Treat this as a seq2seq type problem. Note that this implementation is different
+        # from the seq2seq tokenization function even though it is conceptually similar due
+        # to sequence length / padding requirements assumed internally by causal LMs.
+        if use_seq2seq_approach:
+            return cls._causal_lm_padding_as_seq2seq(
+                tokenizer=tokenizer,
+                source=source,
+                target=target,
+                max_source_length=max_source_length,
+                max_target_length=max_target_length,
+                task_ids=task_ids,
+            )
+        # Do causal language model chunking
+        return cls._causal_lm_as_chunked(
+            tokenizer=tokenizer,
+            source=source,
+            target=target,
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+            batched_mode=batched_mode,
+            task_ids=task_ids,
+            chunk_size=chunk_size,
+            drop_remainder=drop_remainder,
         )

-        def build_generator_func(
-            source_ids: BatchEncoding, num_target_samples: int
-        ) -> Callable:
-            """Builds a generator that can be applied to a single batch encoding and its
-            corresponding original number of target samples.
-
-            source_ids: BatchEncoding
-                Source ID to generate different samples from.
-            num_target_samples: int
-                Number of target IDs; used for attention mask creation.
-            """
-
-            def single_generator_func():
-                for idx in range(num_target_samples):
-                    ret_source_ids = copy(source_ids)
-                    ret_source_ids["attention_mask"] = cls._get_attention_mask(
-                        source_ids,
-                        idx,
-                        num_target_samples,
-                    )
-                    yield ret_source_ids
-
-            return single_generator_func
-
-        if not batched_mode:
-            return DataStream(build_generator_func(source_ids, num_target_samples))
-        streams = [
-            DataStream(build_generator_func(s_ids, n_target_samples))
-            for s_ids, n_target_samples in zip(source_ids, num_target_samples)
-        ]
-        encoding_keys = source_ids[0].keys()
-        return cls._collapse_streams_into_encoding(streams, encoding_keys)
-
     def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
         """Function to return appropriate data collator based on resource.
@@ -158,6 +154,10 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
         NOTE: If mlm (masked language modeling) is not passed in kwargs, this function
         will automatically set it to `False`.

+        FIXME: This should be consolidated with what is in the prompt tuning
+        module, which currently does its own collator management outside of the
+        resource classes.
+
         Args:
             **kwargs:
                 All the keyword arguments passed to this function
@@ -165,6 +165,7 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
                 applicable to implemented data collator.
         Returns:
             transformers.DataCollator
+                Collator to be used for causal language modeling.
         """

         applicable_args = ["mlm", "pad_to_multiple_of"]
@@ -177,13 +178,92 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
             tokenizer=self._tokenizer, return_tensors="pt", **collator_kwargs
         )

+    ### Tokenization strategy implementations
+    # Chunked causal language modeling
+    @classmethod
+    def _causal_lm_as_chunked(
+        cls,
+        tokenizer: "AutoTokenizer",
+        source: str,
+        target: str,
+        max_source_length: int,
+        max_target_length: int,
+        batched_mode: bool,
+        task_ids: Union[None, int],
+        chunk_size: int,
+        drop_remainder: bool,
+    ) -> Union[DataStream[BatchEncoding], BatchEncoding]:
+        """Given a source and target string, build the chunked concatenated sequence and formulate
+        the batch encoded chunks for the sequence. If running in batch mode, the chunks will be
+        collapsed into a single batch encoding for the whole sequence. Otherwise, each chunk will
+        be placed in its own BatchEncoding and encapsulated within a datastream.
+
+        Args:
+            tokenizer: AutoTokenizer
+                Tokenizer object to be applied to input records.
+            source: str
+                Raw source string.
+            target: str
+                Raw target string.
+            max_source_length: int
+                Maximum length for input sequences.
+            max_target_length: int
+                Maximum length for output sequences.
+            batched_mode: bool
+                If True, collapse the chunks into a single encoding covering the whole
+                sequence; otherwise produce a stream of per-chunk encodings.
+            task_ids: Union[None, int]
+                Task IDs to be used for multiprompt tuning.
+            chunk_size: int
+                unsigned int value to be used for chunk size.
+            drop_remainder: bool
+                Whether or not to keep the residual as an extra chunk if the
+                total number of tokens is not divisible by the chunk size.
+
+        Returns:
+            Union[DataStream[BatchEncoding], BatchEncoding]
+                Encoded chunked sequence as a stream or batch encoding object.
+        """
+        source_ids = tokenizer(source, max_length=max_source_length, truncation=True)
+        target_ids = tokenizer(target, max_length=max_target_length, truncation=True)
+
+        # Force everything to a list of batch encodings; for non-batch mode, this just
+        # puts it into a list. For batch mode, we get a list of batch encodings,
+        # allowing us to standardize subsequent processing a bit.
+        #
+        # For example, given chunk size 2, we might have something like:
+        # [
+        #     {'input_ids': [31, 48], 'attention_mask': [1, 1]},
+        #     {'input_ids': [47, 1], 'attention_mask': [1, 1]},
+        #     ...
+        # ]
+        # (where the above objects are batch encodings, which are a subclass of dict)
+        source_id_chunks = cls._force_to_batch_encoding_list_of_chunks(
+            source_ids, target_ids, batched_mode, task_ids, chunk_size, drop_remainder
+        )
+
+        def generator_func():
+            for chunk in source_id_chunks:
+                yield chunk
+
+        chunk_stream = DataStream(generator_func)
+        # If it's batch mode, collapse down into one encoding batch object
+        if batched_mode:
+            return cls._collapse_stream_into_encoding(chunk_stream)
+        # Otherwise just produce the stream to be chained
+        # NOTE: it might be a good idea to deprecate this to force standardization
+        # onto using batch encodings the way that they are intended to be
+        return chunk_stream
+
     @staticmethod
-    def _force_to_batch_encoding_list(
+    def _force_to_batch_encoding_list_of_chunks(
         source_ids: BatchEncoding,
         target_ids: BatchEncoding,
         batch_mode: bool,
         task_ids: Union[None, int],
-    ) -> Tuple[Union[BatchEncoding, List[BatchEncoding]], Union[int, List[int]]]:
+        chunk_size: int,
+        drop_remainder: bool,
+    ) -> List[BatchEncoding]:
         """Forces our inputs into either a single batch encoding (if we aren't running in
         batch mode), or a list of Batch Encodings. I.e., a list of dicts instead of a dict of
         lists. The primary reason that we do this is to allow us to easily map a common generator
@@ -198,19 +278,29 @@
                 Whether or not we are processing a batch.
             task_ids: Union[None, int]
                 Optional task IDs for MPT to be propagated to produced encodings.
+            chunk_size: int
+                unsigned int value to be used for chunk size.
+            drop_remainder: bool
+                Whether or not to keep the residual as an extra chunk if the
+                total number of tokens is not divisible by the chunk size.

         Returns:
-            Tuple[Union[BatchEncoding, List[BatchEncoding]], Union[int, List]]
+            List[BatchEncoding]
+                List of batch encodings, each of which encapsulates the contents
+                of a single chunk.
""" if not batch_mode: - source_ids["input_ids"] = source_ids.input_ids + target_ids.input_ids - source_ids["task_ids"] = task_ids - num_target_samples = len(target_ids.input_ids) - return source_ids, num_target_samples + HFAutoCausalLM._concatenate_encodings(source_ids, target_ids) + chunks = HFAutoCausalLM._split_encoding_into_chunks( + encoding=source_ids, + chunk_size=chunk_size, + drop_remainder=drop_remainder, + task_ids=task_ids, + ) + return chunks # Otherwise we need to expand the dict along its keys, # mapping all of its encapsulated objects to new items. encodings = [] - num_target_samples = [] id_keys = source_ids.keys() key = None error.value_check( @@ -218,49 +308,93 @@ def _force_to_batch_encoding_list( source_ids.keys(), "Source ID batch encoding must have keys", ) + for batch_idx in range(len(source_ids.input_ids)): new_encoding = BatchEncoding() for key in id_keys: - if key == "input_ids": - new_encoding[key] = ( - source_ids[key][batch_idx] + target_ids[key][batch_idx] - ) - else: - new_encoding[key] = source_ids[key][batch_idx] - num_target_samples.append(len(target_ids[key][batch_idx])) - new_encoding["task_ids"] = task_ids - encodings.append(new_encoding) - return encodings, num_target_samples + new_encoding[key] = ( + source_ids[key][batch_idx] + target_ids[key][batch_idx] + ) + chunks = HFAutoCausalLM._split_encoding_into_chunks( + encoding=new_encoding, + chunk_size=chunk_size, + drop_remainder=drop_remainder, + task_ids=task_ids, + ) + # Chunks are held as a list of lists + encodings += chunks + return encodings @staticmethod - def _get_attention_mask( - source_ids: BatchEncoding, idx: int, num_target_samples: int - ) -> List[int]: - """Get the attention mask for a given target token from some source encoding. + def _concatenate_encodings(left: BatchEncoding, right: BatchEncoding) -> None: + """Given two batch encodings, combine their entries into a single encoding. Args: - source_ids: BatchEncoding - Source encoding that requires an attention mask. - idx: int - Index of the output token we attend up to. - num_target_samples: int - Length of the original target seequence being considered. + left: BatchEncoding + Encoding representing left sequence, which will be updated in place. + Corresponds to source. + right: BatchEncoding + Encoding representing right sequence, which will be stacked onto the left + encoding. Corresponds to target. + """ + for k in left.keys(): + left[k].extend(right[k]) + + @staticmethod + def _split_encoding_into_chunks( + encoding: BatchEncoding, + chunk_size: int, + drop_remainder: bool, + task_ids: Union[None, int], + ) -> List[BatchEncoding]: + """Fetch the chunked batch encoding objects from the concatenated encoding. + + Args: + encoding: BatchEncoding + BatchEncoding holding the concatenated source/target for one example. + chunk_size: int + unsigned int value to be used for chunk size. + drop_remainder: bool + Whether or not to keep the residual as an extra chunk if the + total number of tokens is not divisible by the chunk size. + task_ids: Union[None, int] + Optional task IDs for MPT to be propagated to produced encodings. Returns: - List[int] - Binary attention mask. + List[BatchEncoding] + List of encodings, where each encoding represents one chunk. 
""" - return ( - source_ids["attention_mask"] - + [1] * (idx + 1) - + [0] * (num_target_samples - idx - 1) - ) + chunked_encodings = [] + # all encoding keys have the same length list values; we just use input ids + tok_len = len(encoding["input_ids"]) + # Build a batch encoding for every chunk; for each data, + # use the slice for all keys inside of the source_encoding. + if tok_len >= chunk_size: + slice_len = (tok_len // chunk_size) * chunk_size + # If we have a remainder and we don't want to drop it, add a new chunk + if not drop_remainder and slice_len != tok_len: + slice_len += chunk_size + # We just have one big chunk + else: + slice_len = tok_len + chunked_encodings = [ + BatchEncoding( + data={ + k: v[chunk_num : chunk_num + chunk_size] + for k, v in encoding.items() + } + ) + for chunk_num in range(0, slice_len, chunk_size) + ] + for enc in chunked_encodings: + enc["task_ids"] = task_ids + return chunked_encodings @staticmethod - def _collapse_streams_into_encoding( - streams: List[DataStream[BatchEncoding]], encoding_keys: "dict_keys" + def _collapse_stream_into_encoding( + stream: DataStream[BatchEncoding], ) -> BatchEncoding: - """Given a list of streams of batch encodings, collapse them back into + """Given a stream batch encodings, collapse them back into one encoding, i.e., the return value of the batch encoding. Args: @@ -271,14 +405,105 @@ def _collapse_streams_into_encoding( Returns: BatchEncoding - Collapsed batch encoding to be returned from tokenizatino func. + Collapsed batch encoding to be returned from tokenization func. """ + encoding_keys = None new_encoding = BatchEncoding() - for k in encoding_keys: - new_encoding[k] = [] # Now build the individual lists lists for each entry - for stream in streams: - for enc in stream: + for enc in stream: + # Initialize the existing keys in the new encoding + if encoding_keys is None: + encoding_keys = enc.keys() for k in encoding_keys: - new_encoding[k].append(enc[k]) + new_encoding[k] = [] + for k in encoding_keys: + new_encoding[k].append(enc[k]) return new_encoding + + # Causal language modeling as a sequence to sequence problem + @staticmethod + def _causal_lm_padding_as_seq2seq( + tokenizer: "AutoTokenizer", + source: str, + target: str, + max_source_length: int, + max_target_length: int, + task_ids: Union[None, int], + ) -> BatchEncoding: + """Tokenize the example as a seq2seq type problem; this is conceptually similar to + what seq2seq tokenization is doing, but some care needs be taken to ensure the labels + are the same length as the input sequence because of the shifting mechanism implemented + in most causal language models. + + Collator compatability is extremely important here; because we are setting the labels + directly, we should NOT use the causal lm collator, otherwise it will clobber it with a + shifted input sequence. + + Args: + tokenizer: AutoTokenizer + Tokenizer object to be applied to input records. + source: str + Raw source string. + target: str + Raw target string. + max_source_length: int + Maximum length for input sequences. + max_target_length: int + Maximum length for output sequences. + task_ids: Union[None, int] + Optional task IDs for MPT to be propagated to produced encodings. + Returns: + BatchEncoding + BatchEncoding object corresponding to this example, where the input_ids, + attention_mask, and labels all have the same length, i.e., + [max_source_length + max_target_length + 1]. 
+ """ + IGNORE_ID = -100 + # ID of the token to append after our target string; this should generally be pad / EOS + FINAL_TOK_ID = tokenizer.eos_token_id + max_concat_length = max_source_length + max_target_length + 1 + + # Truncate based on max source or max target length before considering as a joined sequence + model_inputs = tokenizer(source, truncation=True, max_length=max_source_length) + labels = tokenizer(target, truncation=True, max_length=max_target_length + 1) + + # Combine the source + target strings into the source input IDs + # This makes the source and target the same length, and then masks the source out of the + # target IDs, and updates the length of the attention vector to be evenly spread on the + # whole combined sequence + sample_input_ids = model_inputs["input_ids"] + label_input_ids = labels["input_ids"] + [FINAL_TOK_ID] + model_inputs["input_ids"] = sample_input_ids + label_input_ids + labels["input_ids"] = [IGNORE_ID] * len(sample_input_ids) + label_input_ids + model_inputs["attention_mask"] = [1] * len(model_inputs["input_ids"]) + # Now we have to update everything to be the max length of the tokenizer, then pad & + # ensure all of the padded stuff we have added has attention weights of 0. + sample_input_ids = model_inputs[ + "input_ids" + ] # NOTE - combined source + target + + + label_input_ids = labels["input_ids"] + model_inputs = tokenizer.pad( + model_inputs, padding="max_length", max_length=max_concat_length + ) + + if tokenizer.padding_side.lower() == "left": + labels["input_ids"] = [IGNORE_ID] * ( + max_concat_length - len(sample_input_ids) + ) + label_input_ids + else: + labels["input_ids"] = label_input_ids + [IGNORE_ID] * ( + max_concat_length - len(sample_input_ids) + ) + + model_inputs["input_ids"] = torch.tensor( + model_inputs["input_ids"][:max_concat_length] + ) + model_inputs["attention_mask"] = torch.tensor( + model_inputs["attention_mask"][:max_concat_length] + ) + + labels["input_ids"] = torch.tensor(labels["input_ids"][:max_concat_length]) + model_inputs["labels"] = labels["input_ids"] + model_inputs["task_ids"] = task_ids + return model_inputs diff --git a/tests/resources/test_pretrained_model.py b/tests/resources/test_pretrained_model.py index 8828544c..056a9a8e 100644 --- a/tests/resources/test_pretrained_model.py +++ b/tests/resources/test_pretrained_model.py @@ -10,6 +10,7 @@ # Third Party from datasets import IterableDataset as TransformersIterableDataset +from torch.utils.data import DataLoader import pytest import torch import transformers @@ -102,7 +103,8 @@ def test_boostrap_model_path(models_cache_dir): ] ) - +# Causal LM tokenization strategies +### 1. 
diff --git a/tests/resources/test_pretrained_model.py b/tests/resources/test_pretrained_model.py
index 8828544c..056a9a8e 100644
--- a/tests/resources/test_pretrained_model.py
+++ b/tests/resources/test_pretrained_model.py
@@ -10,6 +10,7 @@
 # Third Party
 from datasets import IterableDataset as TransformersIterableDataset
+from torch.utils.data import DataLoader
 import pytest
 import torch
 import transformers
@@ -102,7 +103,8 @@ def test_boostrap_model_path(models_cache_dir):
         ]
     )

-
+# Causal LM tokenization strategies
+### 1. Tests for Causal LM tokenization chunking
 def test_causal_lm_tokenize_func_contains_wrapped_stream(models_cache_dir):
     """Ensure the Causal LM tokenize func produces a wrapped stream that can be flattened."""
     causal_lm = HFAutoCausalLM.bootstrap(
@@ -113,6 +115,7 @@
         max_source_length=100,
         max_target_length=100,
         verbalizer="{{input}}",
+        use_seq2seq_approach=False,
     )
     map_stream = SAMPLE_TRAINING_DATA.map(tok_func)
     # Since tok_func for causal lm creates a datastream, we should get a stream
@@ -127,13 +130,21 @@
     )


-def test_causal_lm_tok_output_correctness(models_cache_dir):
-    """Validate the correctness of the attention mask for the language modeling objective."""
+# Key cases here are:
+# 1 - simplest and minimal case
+# 3 - because the concat sequence is length 17, so we have a remainder
+# 100 - which is much larger than the concatenated seq and should yield one chunk
+@pytest.mark.parametrize(
+    "chunk_size,drop_remainder",
+    [(1, True), (1, False), (3, True), (3, False), (100, True), (100, False)],
+)
+def test_causal_lm_tok_output_correctness(models_cache_dir, chunk_size, drop_remainder):
+    """Validate the tokenized results for the chunked language modeling objective."""
     causal_lm = HFAutoCausalLM.bootstrap(
         model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL
     )
     sample = GenerationTrainRecord(
-        input="This len does not matter", output="but this one does!"
+        input="Hello world", output="How are you doing today?!"
     )
     (tok_func, _) = causal_lm.build_task_tokenize_closure(
         tokenizer=causal_lm.tokenizer,
@@ -141,35 +152,138 @@
         max_target_length=100,
         verbalizer="{{input}}",
         task_ids=0,
+        use_seq2seq_approach=False,
+        chunk_size=chunk_size,
+        drop_remainder=drop_remainder,
     )
     input_tok = causal_lm.tokenizer.encode(sample.input)
     output_tok = causal_lm.tokenizer.encode(sample.output)
+    concat_tok = input_tok + output_tok
     tok_stream = tok_func(sample)
     # Ensure we get one token per output in our stream
     assert isinstance(tok_stream, caikit.core.data_model.DataStream)
-    assert len(tok_stream) == len(output_tok)
-    for idx, tok_sample in enumerate(tok_stream):
-        # We expect by default, everything is in order, and each attention mask grows the tokens
-        # we attend to in the target by one, until we are paying attention to the whole sequence.
-        expected_target_mask = torch.tensor(
-            ([1] * (idx + 1)) + [0] * (len(output_tok) - idx - 1)
-        )
-        actual_target_mask = torch.tensor(
-            tok_sample["attention_mask"][-len(output_tok) :]
-        )
-        assert bool(torch.all(expected_target_mask == actual_target_mask))
-        # Check the source mask; we should always attend to the whole source sequence
-        actual_source_mask = torch.tensor(
-            tok_sample["attention_mask"][: len(input_tok)]
+    # Figure out how many chunks we should have, including if we have a remainder
+    has_remainder = False
+    if len(concat_tok) > chunk_size:
+        num_expected_chunks = len(concat_tok) // chunk_size
+        # Should only care about the remainder if we are not dropping it
+        if num_expected_chunks * chunk_size != len(concat_tok) and not drop_remainder:
+            has_remainder = True
+    else:
+        num_expected_chunks = 1
+        chunk_size = len(concat_tok)
+    tok_list = list(tok_stream)
+    assert len(tok_list) == num_expected_chunks + has_remainder
+    # Check all full chunks. Note that we always attend to everything
+    for idx in range(num_expected_chunks):
+        assert len(tok_list[idx]["attention_mask"]) == chunk_size
+        assert len(tok_list[idx]["input_ids"]) == chunk_size
+        assert all(atn == 1 for atn in tok_list[idx]["attention_mask"])
+        assert tok_list[idx]["task_ids"] == 0
+    # Check the remainder; lists should be the same length, but less than the chunk size
+    if has_remainder:
+        remainder = tok_list[-1]
+        assert len(remainder["attention_mask"]) == len(remainder["input_ids"])
+        assert len(remainder["input_ids"]) < chunk_size
+        assert all(atn == 1 for atn in remainder["attention_mask"])
+
+
+def test_causal_lm_batch_tokenization(models_cache_dir):
+    """Ensure that we can batch process causal lm inputs correctly."""
+    causal_lm = HFAutoCausalLM.bootstrap(
+        model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL
+    )
+    train_stream = DataStream.from_iterable(
+        [
+            GenerationTrainRecord(input="hello there", output="world"),
+            GenerationTrainRecord(input="how", output="today"),
+        ]
+    )
+    fn_kwargs = {
+        "tokenizer": causal_lm.tokenizer,
+        "max_source_length": 10,
+        "max_target_length": 10,
+        "use_seq2seq_approach": False,
+    }
+    # Create an iterable dataset by batching...
+    def get(train_stream):
+        for data in train_stream:
+            yield {"input": data.input, "output": data.output}
+
+    dataset = TransformersIterableDataset.from_generator(
+        get, gen_kwargs={"train_stream": train_stream}
+    )
+    batched_dataset = dataset.map(
+        causal_lm.tokenize_function,
+        fn_kwargs=fn_kwargs,
+        batched=True,
+        remove_columns=["input", "output"],
+    )
+
+    # Do the same thing with no batching via tokenize closure + unwrapping
+    tok_func = causal_lm.build_task_tokenize_closure(**fn_kwargs)[0]
+    mapped_indiv_stream = train_stream.map(tok_func).flatten()
+    for indiv_res, batched_res in zip(mapped_indiv_stream, batched_dataset):
+        # All keys should match (input ids, attention mask)
+        assert indiv_res.keys() == batched_res.keys()
+        # And all of their values should be the same
+        for k in indiv_res:
+            assert indiv_res[k] == batched_res[k]
+
+
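The seq2seq-framing tests that follow assert on the exact layout produced by _causal_lm_padding_as_seq2seq in hf_auto_causal_lm.py above; as a reminder, here is a minimal sketch of that layout with toy token IDs and made-up PAD/EOS values (right padding shown, left padding mirrors it):

    # Sketch of the seq2seq-style causal LM encoding (toy IDs; PAD/EOS are invented here).
    IGNORE_ID, PAD_ID, EOS_ID = -100, 0, 2

    source_ids = [11, 12, 13]          # tokenized source, already truncated
    target_ids = [21, 22] + [EOS_ID]   # tokenized target plus the final EOS token
    max_concat_length = 10             # max_source_length + max_target_length + 1

    input_ids = source_ids + target_ids
    labels = [IGNORE_ID] * len(source_ids) + target_ids  # loss ignores the source half
    attention_mask = [1] * len(input_ids)

    pads = max_concat_length - len(input_ids)
    input_ids += [PAD_ID] * pads       # right padding; padded positions get attention 0
    attention_mask += [0] * pads
    labels += [IGNORE_ID] * pads

    assert len(input_ids) == len(attention_mask) == len(labels) == max_concat_length
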
+### 2. Tests for causal LM framed as a seq2seq problem
+# NOTE: For these tests, we should be careful to always test left and right padding
+@pytest.mark.parametrize(
+    "padding_side",
+    ["left", "right"],
+)
+def test_causal_lm_as_a_sequence_problem_no_truncation(models_cache_dir, padding_side):
+    causal_lm = HFAutoCausalLM.bootstrap(
+        model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL
+    )
+    causal_lm.tokenizer.padding_side = padding_side
+    sample = GenerationTrainRecord(
+        input="Hello world", output="How are you doing today?!"
+    )
+    max_lengths = 20
+    # First, build the output we expect for left / right respectively...
+    input_tok = causal_lm.tokenizer.encode(sample.input)
+    output_tok = causal_lm.tokenizer.encode(sample.output) + [
+        causal_lm.tokenizer.eos_token_id
+    ]
+    concat_res = input_tok + output_tok
+    masked_res = ([-100] * len(input_tok)) + output_tok
+
+    # These must be true; otherwise the sequences would be truncated instead of padded
+    assert len(input_tok) < max_lengths
+    assert len(output_tok) < (max_lengths + 1)
+    pads_needed = (1 + 2 * max_lengths) - len(concat_res)
+    if causal_lm.tokenizer.padding_side.lower() == "left":
+        expected_input_ids = torch.tensor(
+            [causal_lm.tokenizer.pad_token_id] * pads_needed + concat_res
         )
-        assert bool(torch.all(torch.tensor([1] * len(input_tok)) == actual_source_mask))
-        # Also, the number of tokens we attend to should be the sum of toks in input/output
-        assert (len(actual_target_mask) + len(actual_source_mask)) == len(
-            tok_sample["attention_mask"]
+        expected_attn_mask = torch.tensor([0] * pads_needed + [1] * len(concat_res))
+        expected_labels = torch.tensor([-100] * pads_needed + masked_res)
+    else:
+        expected_input_ids = torch.tensor(
+            concat_res + [causal_lm.tokenizer.pad_token_id] * pads_needed
         )
-        # Ensure we support MPT
-        assert hasattr(tok_sample, "task_ids")
-        assert tok_sample["task_ids"] == 0
+        expected_attn_mask = torch.tensor([1] * len(concat_res) + [0] * pads_needed)
+        expected_labels = torch.tensor(masked_res + [-100] * pads_needed)
+
+    # Now build the analogous tokenizer closure and compare the tensors
+    (tok_func, _) = causal_lm.build_task_tokenize_closure(
+        tokenizer=causal_lm.tokenizer,
+        max_source_length=max_lengths,
+        max_target_length=max_lengths,
+        verbalizer="{{input}}",
+        task_ids=0,
+        use_seq2seq_approach=True,
+    )
+    tok_res = tok_func(sample)
+    assert tok_res["task_ids"] == 0
+    assert torch.all(tok_res["input_ids"] == expected_input_ids)
+    assert torch.all(tok_res["attention_mask"] == expected_attn_mask)
+    assert torch.all(tok_res["labels"] == expected_labels)


 ### Tests for Seq2Seq tokenization
@@ -222,43 +336,80 @@ def test_seq2seq_tok_output_correctness(models_cache_dir):
     assert tok_sample["task_ids"] == 0


-def test_causal_lm_batch_tokenization(models_cache_dir):
-    """Ensure that we can batch process causal lm inputs correctly."""
-    causal_lm = HFAutoCausalLM.bootstrap(
-        model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL
-    )
+### Tests for collator compatibility
+# These tests should validate that we can use our tokenization function to
+# build torch loaders around datasets using different collators.
+# TODO: Expand to cover transformer datasets, i.e., what is produced by
+# text gen preprocessing functions. For now, they only check the minimal
+# case with the default data collator.
+@pytest.mark.parametrize(
+    "collator_fn",
+    [transformers.default_data_collator],
+)
+def test_loader_can_batch_list_of_seq2seq_outputs(collator_fn):
+    # Build the dataset
     train_stream = DataStream.from_iterable(
         [
-            GenerationTrainRecord(input="hello there", output="world"),
-            GenerationTrainRecord(input="how", output="today"),
+            GenerationTrainRecord(input="hello world", output="how are you today?"),
+            GenerationTrainRecord(input="goodbye", output="world"),
+            GenerationTrainRecord(input="good morning", output="have a good day"),
+            GenerationTrainRecord(input="good night", output="have nice dreams"),
         ]
     )
-    fn_kwargs = {
-        "tokenizer": causal_lm.tokenizer,
-        "max_source_length": 10,
-        "max_target_length": 10,
-    }
-    # Create an iterable dataset by batching...
-    def get(train_stream):
-        for data in train_stream:
-            yield {"input": data.input, "output": data.output}
-
-    dataset = TransformersIterableDataset.from_generator(
-        get, gen_kwargs={"train_stream": train_stream}
+    seq2seq = HFAutoSeq2SeqLM.bootstrap(
+        model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL
     )
-    batched_dataset = dataset.map(
-        causal_lm.tokenize_function,
-        fn_kwargs=fn_kwargs,
-        batched=True,
-        remove_columns=["input", "output"],
+    (tok_func, _) = seq2seq.build_task_tokenize_closure(
+        tokenizer=seq2seq.tokenizer,
+        max_source_length=20,
+        max_target_length=20,
+        verbalizer="{{input}}",
+        task_ids=0,
+    )
+    tok_results = [tok_func(x) for x in list(train_stream)]
+    dl = DataLoader(
+        tok_results,
+        shuffle=False,
+        batch_size=2,
+        collate_fn=collator_fn,
     )
+    # Loader should create 2 batches
+    loader_list = list(dl)
+    assert len(loader_list) == 2

-    # Do the same thing with no batching via tokenize closure + unwrapping
-    tok_func = causal_lm.build_task_tokenize_closure(**fn_kwargs)[0]
-    mapped_indiv_stream = train_stream.map(tok_func).flatten()
-    for indiv_res, batched_res in zip(mapped_indiv_stream, batched_dataset):
-        # All keys should match (input ids, attention mask)
-        assert indiv_res.keys() == batched_res.keys()
-        # And all of their values should be the same
-        for k in indiv_res:
-            assert indiv_res[k] == batched_res[k]
+
+@pytest.mark.parametrize(
+    "collator_fn",
+    [transformers.default_data_collator],
+)
+def test_loader_can_batch_list_of_causal_lm_outputs(collator_fn):
+    # Build the dataset
+    train_stream = DataStream.from_iterable(
+        [
+            GenerationTrainRecord(input="hello world", output="how are you today?"),
+            GenerationTrainRecord(input="goodbye", output="world"),
+            GenerationTrainRecord(input="good morning", output="have a good day"),
+            GenerationTrainRecord(input="good night", output="have nice dreams"),
+        ]
+    )
+    causal_lm = HFAutoCausalLM.bootstrap(
+        model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL
+    )
+    (tok_func, _) = causal_lm.build_task_tokenize_closure(
+        tokenizer=causal_lm.tokenizer,
+        max_source_length=20,
+        max_target_length=20,
+        verbalizer="{{input}}",
+        task_ids=0,
+        use_seq2seq_approach=True,
+    )
+    tok_results = [tok_func(x) for x in list(train_stream)]
+    dl = DataLoader(
+        tok_results,
+        shuffle=False,
+        batch_size=2,
+        collate_fn=collator_fn,
+    )
+    # Loader should create 2 batches
+    loader_list = list(dl)
+    assert len(loader_list) == 2
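Finally, a hedged end-to-end sketch of how the new tokenization flag and the default collator fit together outside the test suite. The model path and import paths are assumptions for illustration only; the call pattern mirrors test_loader_can_batch_list_of_causal_lm_outputs above:

    # Sketch: tokenize records with the seq2seq-style causal LM approach, then batch
    # them with the default collator. "my-causal-lm" is a placeholder path, and the
    # import locations below are assumed rather than confirmed by this patch.
    from torch.utils.data import DataLoader
    import transformers

    from caikit_nlp.data_model import GenerationTrainRecord
    from caikit_nlp.resources.pretrained_model import HFAutoCausalLM

    causal_lm = HFAutoCausalLM.bootstrap(
        model_name="my-causal-lm", tokenizer_name="my-causal-lm"
    )
    tok_func, _ = causal_lm.build_task_tokenize_closure(
        tokenizer=causal_lm.tokenizer,
        max_source_length=20,
        max_target_length=20,
        verbalizer="{{input}}",
        task_ids=0,
        use_seq2seq_approach=True,
    )
    records = [
        GenerationTrainRecord(input="hello world", output="how are you today?"),
        GenerationTrainRecord(input="goodbye", output="world"),
    ]
    loader = DataLoader(
        [tok_func(r) for r in records],
        batch_size=2,
        collate_fn=transformers.default_data_collator,
    )
    for batch in loader:
        # Each batch holds equal-length input_ids / attention_mask / labels tensors
        print(batch["input_ids"].shape, batch["labels"].shape)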