diff --git a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
index 27eaa769..87cc543c 100644
--- a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
+++ b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
@@ -69,7 +69,7 @@ def tokenize_function(
         use_seq2seq_approach: bool = True,
         chunk_size: int = 128,
         drop_remainder: bool = False,
-    ) -> DataStream[BatchEncoding]:
+    ) -> Union[DataStream[BatchEncoding], BatchEncoding]:
         """Tokenization function to be used for causallm training; this function consumes a
         GenerationTrainRecord object and applies the verbalizer to it followed by
         the model tokenizer. Due to the nature of our training data with src/target seqs,
@@ -89,15 +89,21 @@ def tokenize_function(
                 Verbalizer to be rendered into each text.
             task_ids: Union[None, int]
                 Task IDs to be used for multiprompt tuning.
+            use_seq2seq_approach: bool
+                Indicates whether or not we should use a sequence style approach
+                or use chunking parameters.
             chunk_size: int
                 unsigned int value to be used for chunk size.
+                Only used if use_seq2seq_approach=False.
             drop_remainder: bool
                 Whether or not to keep the residual as an extra chunk if the
                 total number of tokens is not divisible by the chunk size.
+                Only used if use_seq2seq_approach=False.

         Returns:
-            DataStream[transformers.tokenization_utils_base.BatchEncoding]
-                stream of encoded tokenization output corresponding to the input example.
+            Union[DataStream[BatchEncoding], BatchEncoding]
+                stream of encoded tokenization output corresponding to the input example
+                or a single batch encoding object containing 1+ tokenized results.
         """
         ### Things common to all Causal LM tokenization approaches
         # Extract the source & target from our provided inputs
@@ -148,6 +154,10 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
         NOTE: If mlm (masked language modeling) is not passed in kwargs, this function
         will automatically set it to `False`.

+        FIXME: This should be consolidated with what is in the prompt tuning
+        module, which currently does its own collator management outside of the
+        resource classes.
+
         Args:
             **kwargs:
                 All the keyword arguments passed to this function
@@ -155,6 +165,7 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
                 applicable to implemented data collator.
         Returns:
             transformers.DataCollator
+                Collator to be used for causal language modeling.
         """

         applicable_args = ["mlm", "pad_to_multiple_of"]
@@ -172,16 +183,47 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
     @classmethod
     def _causal_lm_as_chunked(
         cls,
-        tokenizer,
-        source,
-        target,
-        max_source_length,
-        max_target_length,
-        batched_mode,
-        task_ids,
-        chunk_size,
-        drop_remainder,
-    ):
+        tokenizer: "AutoTokenizer",
+        source: str,
+        target: str,
+        max_source_length: int,
+        max_target_length: int,
+        batched_mode: bool,
+        task_ids: Union[None, int],
+        chunk_size: int,
+        drop_remainder: bool,
+    ) -> Union[DataStream[BatchEncoding], BatchEncoding]:
+        """Given a source and target string, build the chunked concatenated sequence and formulate
+        the batch encoded chunks for the sequence. If running in batch mode, the chunks will be
+        collapsed into a single batch encoding for the whole sequence. Otherwise, each chunk will be
+        placed in its own BatchEncoding and encapsulated within a datastream.
+
+        Args:
+            tokenizer: AutoTokenizer
+                Tokenizer object to be applied to input records.
+            source: str
+                Raw source string.
+            target: str
+                Raw target string.
+            max_source_length: int
+                Maximum length for input sequences.
+            max_target_length: int
+                Maximum length for output sequences.
+            batched_mode: bool
+                Whether or not we should produce a stream of encodings or a single
+                encoding representing all of the chunked sequence.
+            task_ids: Union[None, int]
+                Task IDs to be used for multiprompt tuning.
+            chunk_size: int
+                unsigned int value to be used for chunk size.
+            drop_remainder: bool
+                Whether or not to keep the residual as an extra chunk if the
+                total number of tokens is not divisible by the chunk size.
+
+        Returns:
+            Union[DataStream[BatchEncoding], BatchEncoding]
+                Encoded chunked sequence as a stream or batch encoding object.
+        """
         source_ids = tokenizer(source, max_length=max_source_length, truncation=True)
         target_ids = tokenizer(target, max_length=max_target_length, truncation=True)

@@ -228,15 +270,23 @@ def _force_to_batch_encoding_list_of_chunks(
                 Whether or not we are processing a batch.
             task_ids: Union[None, int]
                 Optional task IDs for MPT to be propagated to produced encodings.
+            chunk_size: int
+                unsigned int value to be used for chunk size.
+            drop_remainder: bool
+                Whether or not to keep the residual as an extra chunk if the
+                total number of tokens is not divisible by the chunk size.

         Returns:
             List[BatchEncoding]
+                List of batch encodings, each of which encapsulates the contents
+                of a single chunk.
         """
         if not batch_mode:
             HFAutoCausalLM._concatenate_encodings(source_ids, target_ids)
             chunks = HFAutoCausalLM._split_encoding_into_chunks(
                 encoding=source_ids,
                 chunk_size=chunk_size,
+                drop_remainder=drop_remainder,
                 task_ids=task_ids,
             )
             return chunks
@@ -268,20 +318,43 @@ def _force_to_batch_encoding_list_of_chunks(
         return encodings

     @staticmethod
-    def _concatenate_encodings(left, right):
+    def _concatenate_encodings(left: BatchEncoding, right: BatchEncoding) -> None:
+        """Given two batch encodings, combine their entries into a single encoding.
+
+        Args:
+            left: BatchEncoding
+                Encoding representing left sequence, which will be updated in place.
+                Corresponds to source.
+            right: BatchEncoding
+                Encoding representing right sequence, which will be stacked onto the left
+                encoding. Corresponds to target.
+        """
         for k in left.keys():
             left[k] = left[k] + right[k]

     @staticmethod
     def _split_encoding_into_chunks(
-        encoding: dict, chunk_size: int, drop_remainder: bool = False, task_ids=None
-    ):
-        """Fetch the chunked batch encoding objects from source/target encoding(s).
-        If no target encoding is provided, it's assumed that the source and target
-        have already been concatenated.
-
-        If drop remainder is enabled, do not yield uneven chunks. For now, this parameter
-        is not exposed.
+        encoding: BatchEncoding,
+        chunk_size: int,
+        drop_remainder: bool,
+        task_ids: Union[None, int],
+    ) -> List[BatchEncoding]:
+        """Fetch the chunked batch encoding objects from the concatenated encoding.
+
+        Args:
+            encoding: BatchEncoding
+                BatchEncoding holding the concatenated source/target for one example.
+            chunk_size: int
+                unsigned int value to be used for chunk size.
+            drop_remainder: bool
+                Whether or not to keep the residual as an extra chunk if the
+                total number of tokens is not divisible by the chunk size.
+            task_ids: Union[None, int]
+                Optional task IDs for MPT to be propagated to produced encodings.
+
+        Returns:
+            List[BatchEncoding]
+                List of encodings, where each encoding represents one chunk.
""" chunked_encodings = [] # all encoding keys have the same length list values; we just use input ids @@ -342,13 +415,13 @@ def _collapse_stream_into_encoding( # Causal language modeling as a sequence to sequence problem @staticmethod def _causal_lm_padding_as_seq2seq( - tokenizer, - source, - target, - max_source_length, - max_target_length, - task_ids, - ): + tokenizer: "AutoTokenizer", + source: str, + target: str, + max_source_length: int, + max_target_length: int, + task_ids: Union[None, int], + ) -> BatchEncoding: """Tokenize the example as a seq2seq type problem; this is conceptually similar to what seq2seq tokenization is doing, but some care needs be taken to ensure the labels are the same length as the input sequence because of the shifting mechanism implemented @@ -358,9 +431,24 @@ def _causal_lm_padding_as_seq2seq( directly, we should NOT use the causal lm collator, otherwise it will clobber it with a shifted input sequence. - For now, this is a logical port of the old tokenization logic. - NOTE: In this tokenization strategy, where we concat the texts, the concatenated sequence - length is max_source_length + max_target_length + 1. + Args: + tokenizer: AutoTokenizer + Tokenizer object to be applied to input records. + source: str + Raw source string. + target: str + Raw target string. + max_source_length: int + Maximum length for input sequences. + max_target_length: int + Maximum length for output sequences. + task_ids: Union[None, int] + Optional task IDs for MPT to be propagated to produced encodings. + Returns: + BatchEncoding + BatchEncoding object corresponding to this example, where the input_ids, + attention_mask, and labels all have the same length, i.e., + [max_source_length + max_target_length + 1]. """ IGNORE_ID = -100 # ID of the token to append after our target string; this should generally be pad / EOS