Update causal lm docstrings and type hints
Signed-off-by: Alex-Brooks <[email protected]>
alex-jw-brooks committed Sep 29, 2023
1 parent 560d5bc commit 511e561
Showing 1 changed file with 120 additions and 32 deletions.
152 changes: 120 additions & 32 deletions caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
@@ -69,7 +69,7 @@ def tokenize_function(
use_seq2seq_approach: bool = True,
chunk_size: int = 128,
drop_remainder: bool = False,
) -> DataStream[BatchEncoding]:
) -> Union[DataStream[BatchEncoding], BatchEncoding]:
"""Tokenization function to be used for causallm training; this function consumes a
GenerationTrainRecord object and applies the verbalizer to it followed by
the model tokenizer. Due to the nature of our training data with src/target seqs,
@@ -89,15 +89,21 @@ def tokenize_function(
Verbalizer to be rendered into each text.
task_ids: Union[None, int]
Task IDs to be used for multiprompt tuning.
use_seq2seq_approach: bool
Indicates whether we should use the sequence to sequence style approach (True)
or the chunking approach configured by the parameters below (False).
chunk_size: int
unsigned int value to be used for chunk size.
Only used if use_seq2seq_approach=False.
drop_remainder: bool
Whether or not to drop the residual tokens rather than keeping them as an
extra chunk if the total number of tokens is not divisible by the chunk size.
Only used if use_seq2seq_approach=False.
Returns:
DataStream[transformers.tokenization_utils_base.BatchEncoding]
stream of encoded tokenization output corresponding to the input example.
Union[DataStream[BatchEncoding], BatchEncoding]
stream of encoded tokenization output corresponding to the input example
or a single batch encoding object containing 1+ tokenized results.
"""
### Things common to all Causal LM tokenization approaches
# Extract the source & target from our provided inputs
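
Aside (illustrative only, not part of this diff): a minimal sketch of the two return shapes the updated annotation allows, using a Hugging Face GPT-2 tokenizer and a plain list in place of a caikit DataStream; the toy text, chunk size, and flag values are assumptions.

from transformers import AutoTokenizer, BatchEncoding

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any causal LM tokenizer works here
ids = tokenizer("Question: hi Answer: hello there")["input_ids"]

use_seq2seq_approach = False  # toggles between the two tokenization paths
if use_seq2seq_approach:
    # Seq2seq-style path: one BatchEncoding per training example
    result = BatchEncoding({"input_ids": ids})
else:
    # Chunking path: one BatchEncoding per fixed-size chunk; the real code wraps
    # these in a DataStream (or collapses them when running in batched mode)
    chunk_size = 4
    result = [
        BatchEncoding({"input_ids": ids[i : i + chunk_size]})
        for i in range(0, len(ids), chunk_size)
    ]
print(type(result))
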
@@ -148,13 +154,18 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
NOTE: If mlm (masked language modeling) is not passed in kwargs,
this function will automatically set it to `False`.
FIXME: This should be consolidated with what is in the prompt tuning
module, which currently does its own collator management outside of the
resource classes.
Args:
**kwargs:
All the keyword arguments passed to this function
will be filtered down to those that are applicable
to the implemented data collator.
Returns:
transformers.DataCollator
Collator to be used for causal language modeling.
"""

applicable_args = ["mlm", "pad_to_multiple_of"]
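
Aside (illustrative only, not part of this diff): a hedged sketch of the collator described above, with mlm defaulting to False for causal LM; the GPT-2 tokenizer, the example texts, and the pad_to_multiple_of value are assumptions.

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

# mlm=False yields causal LM labels; pad_to_multiple_of is the other argument
# that the filtering above would pass through.
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False, pad_to_multiple_of=8
)

batch = collator([tokenizer("hello world"), tokenizer("hi")])
print(batch["input_ids"].shape, batch["labels"].shape)
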
@@ -172,16 +183,47 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
@classmethod
def _causal_lm_as_chunked(
cls,
tokenizer,
source,
target,
max_source_length,
max_target_length,
batched_mode,
task_ids,
chunk_size,
drop_remainder,
):
tokenizer: "AutoTokenizer",
source: str,
target: str,
max_source_length: int,
max_target_length: int,
batched_mode: bool,
task_ids: Union[None, int],
chunk_size: int,
drop_remainder: bool,
) -> Union[DataStream[BatchEncoding], BatchEncoding]:
"""Given a source and target string, build the chunked concatenated sequence and formulate
the batch encoded chunks for the sequence. If running in batch mode, the chunks will be
collapsed into a single batch encoding for the whole sequence. Otherwise, each chunk will
be placed in its own BatchEncoding and encapsulated within a datastream.
Args:
tokenizer: AutoTokenizer
Tokenizer object to be applied to input records.
source: str
Raw source string.
target: str
Raw target string.
max_source_length: int
Maximum length for input sequences.
max_target_length: int
Maximum length for output sequences.
batched_mode: bool
Whether we should collapse the chunks into a single encoding representing
the whole sequence (True) or produce a stream of per-chunk encodings (False).
task_ids: Union[None, int]
Task IDs to be used for multiprompt tuning.
chunk_size: int
unsigned int value to be used for chunk size.
drop_remainder: bool
Whether or not to drop the residual tokens rather than keeping them as an
extra chunk if the total number of tokens is not divisible by the chunk size.
Returns:
Union[DataStream[BatchEncoding], BatchEncoding]
Encoded chunked sequence as a stream or batch encoding object.
"""
source_ids = tokenizer(source, max_length=max_source_length, truncation=True)
target_ids = tokenizer(target, max_length=max_target_length, truncation=True)
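
Aside (illustrative only, not part of this diff): a small sketch of the step above, with a GPT-2 tokenizer, toy strings, and toy max lengths as assumptions; the per-key lists from the two encodings are later concatenated and handed to the chunking helpers shown further down.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
source_ids = tokenizer("Translate to French: cheese", max_length=32, truncation=True)
target_ids = tokenizer(" fromage", max_length=16, truncation=True)

# Each encoding holds parallel lists (input_ids, attention_mask) for one example
combined = {k: source_ids[k] + target_ids[k] for k in source_ids.keys()}
print(len(source_ids["input_ids"]), len(target_ids["input_ids"]), len(combined["input_ids"]))
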

@@ -228,15 +270,23 @@ def _force_to_batch_encoding_list_of_chunks(
Whether or not we are processing a batch.
task_ids: Union[None, int]
Optional task IDs for MPT to be propagated to produced encodings.
chunk_size: int
unsigned int value to be used for chunk size.
drop_remainder: bool
Whether or not to drop the residual tokens rather than keeping them as an
extra chunk if the total number of tokens is not divisible by the chunk size.
Returns:
List[BatchEncoding]
List of batch encodings, each of which encapsulates the contents
of a single chunk.
"""
if not batch_mode:
HFAutoCausalLM._concatenate_encodings(source_ids, target_ids)
chunks = HFAutoCausalLM._split_encoding_into_chunks(
encoding=source_ids,
chunk_size=chunk_size,
drop_remainder=drop_remainder,
task_ids=task_ids,
)
return chunks
@@ -268,20 +318,43 @@ def _force_to_batch_encoding_list_of_chunks(
return encodings
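
Aside (illustrative only, not part of this diff): a hedged sketch of the batched case this helper handles, where a single tokenizer call over several texts returns lists of lists that are rebuilt into one BatchEncoding per example; the GPT-2 tokenizer and example texts are assumptions, and the real implementation may differ in details.

from transformers import AutoTokenizer, BatchEncoding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
batched = tokenizer(["first example", "the second one"], truncation=True)

per_example = [
    BatchEncoding({k: batched[k][idx] for k in batched.keys()})
    for idx in range(len(batched["input_ids"]))
]
# Each per-example encoding can then be concatenated with its target and chunked
print(len(per_example), per_example[0]["input_ids"])
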

@staticmethod
def _concatenate_encodings(left, right):
def _concatenate_encodings(left: BatchEncoding, right: BatchEncoding) -> None:
"""Given two batch encodings, combine their entries into a single encoding.
Args:
left: BatchEncoding
Encoding representing left sequence, which will be updated in place.
Corresponds to source.
right: BatchEncoding
Encoding representing right sequence, which will be stacked onto the left
encoding. Corresponds to target.
"""
for k in left.keys():
left[k] = left[k] + right[k]

@staticmethod
def _split_encoding_into_chunks(
encoding: dict, chunk_size: int, drop_remainder: bool = False, task_ids=None
):
"""Fetch the chunked batch encoding objects from source/target encoding(s).
If no target encoding is provided, it's assumed that the source and target
have already been concatenated.
If drop remainder is enabled, do not yield uneven chunks. For now, this parameter
is not exposed.
encoding: BatchEncoding,
chunk_size: int,
drop_remainder: bool,
task_ids: Union[None, int],
) -> List[BatchEncoding]:
"""Fetch the chunked batch encoding objects from the concatenated encoding.
Args:
encoding: BatchEncoding
BatchEncoding holding the concatenated source/target for one example.
chunk_size: int
unsigned int value to be used for chunk size.
drop_remainder: bool
Whether or not to drop the residual tokens rather than keeping them as an
extra chunk if the total number of tokens is not divisible by the chunk size.
task_ids: Union[None, int]
Optional task IDs for MPT to be propagated to produced encodings.
Returns:
List[BatchEncoding]
List of encodings, where each encoding represents one chunk.
"""
chunked_encodings = []
# all encoding keys have the same length list values; we just use input ids
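
Aside (illustrative only, not part of this diff): a standalone sketch of the chunk split and the drop_remainder behavior, using toy values; the real helper also propagates task_ids to each produced chunk when they are provided.

from transformers import BatchEncoding

encoding = BatchEncoding({"input_ids": list(range(10)), "attention_mask": [1] * 10})
chunk_size, drop_remainder = 4, True

num_tokens = len(encoding["input_ids"])  # every key has a list of this length
chunks = []
for start in range(0, num_tokens, chunk_size):
    if drop_remainder and start + chunk_size > num_tokens:
        break  # skip the short trailing chunk
    chunks.append(
        BatchEncoding({k: v[start : start + chunk_size] for k, v in encoding.items()})
    )
print([c["input_ids"] for c in chunks])  # [[0, 1, 2, 3], [4, 5, 6, 7]]; [8, 9] dropped
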
@@ -342,13 +415,13 @@ def _collapse_stream_into_encoding(
# Causal language modeling as a sequence to sequence problem
@staticmethod
def _causal_lm_padding_as_seq2seq(
tokenizer,
source,
target,
max_source_length,
max_target_length,
task_ids,
):
tokenizer: "AutoTokenizer",
source: str,
target: str,
max_source_length: int,
max_target_length: int,
task_ids: Union[None, int],
) -> BatchEncoding:
"""Tokenize the example as a seq2seq type problem; this is conceptually similar to
what seq2seq tokenization is doing, but some care needs to be taken to ensure the labels
are the same length as the input sequence because of the shifting mechanism implemented
@@ -358,9 +431,24 @@ def _causal_lm_padding_as_seq2seq(
directly, we should NOT use the causal lm collator, otherwise it will clobber it with a
shifted input sequence.
For now, this is a logical port of the old tokenization logic.
NOTE: In this tokenization strategy, where we concat the texts, the concatenated sequence
length is max_source_length + max_target_length + 1.
Args:
tokenizer: AutoTokenizer
Tokenizer object to be applied to input records.
source: str
Raw source string.
target: str
Raw target string.
max_source_length: int
Maximum length for input sequences.
max_target_length: int
Maximum length for output sequences.
task_ids: Union[None, int]
Optional task IDs for MPT to be propagated to produced encodings.
Returns:
BatchEncoding
BatchEncoding object corresponding to this example, where the input_ids,
attention_mask, and labels all have the same length, i.e.,
[max_source_length + max_target_length + 1].
"""
IGNORE_ID = -100
# ID of the token to append after our target string; this should generally be pad / EOS
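
Aside (illustrative only, not part of this diff): a hedged, toy reconstruction of the seq2seq-style padding described in the docstring; the GPT-2 tokenizer, right-side padding, the example strings, and the tiny max lengths are assumptions, so the real implementation may differ in details, but the key property holds: input_ids, attention_mask, and labels all end up with length max_source_length + max_target_length + 1, with source and padding positions in labels masked out.

from transformers import AutoTokenizer, BatchEncoding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

IGNORE_ID = -100  # positions with this label are ignored by the loss
max_source_length, max_target_length = 8, 4
total_len = max_source_length + max_target_length + 1

src = tokenizer("Translate to French: cheese", max_length=max_source_length, truncation=True)["input_ids"]
# Append EOS after the target so generation learns where to stop
tgt = tokenizer(" fromage", max_length=max_target_length, truncation=True)["input_ids"] + [tokenizer.eos_token_id]

input_ids = (src + tgt)[:total_len]
labels = ([IGNORE_ID] * len(src) + tgt)[:total_len]  # only target tokens contribute to the loss
attention_mask = [1] * len(input_ids)

# Pad everything out to the fixed total length
pad_len = total_len - len(input_ids)
input_ids += [tokenizer.pad_token_id] * pad_len
attention_mask += [0] * pad_len
labels += [IGNORE_ID] * pad_len

example = BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})
assert len(example["input_ids"]) == len(example["labels"]) == total_len
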
