diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index 96ffde63..26e365cb 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -33,10 +33,7 @@
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import (
-    AutoModelForCausalLM,
-    default_data_collator,
-)
+from transformers import AutoModelForCausalLM, default_data_collator
 from transformers.models.auto.tokenization_auto import AutoTokenizer
 from transformers.optimization import get_linear_schedule_with_warmup
 import numpy as np
@@ -931,10 +928,7 @@ def _get_data_loaders_from_stream(
         torch.utils.data.DataLoader
             DataLoader to be used for training / evaluating the stream data.
         """
-        (
-            tokenize_function,
-            _,
-        ) = base_model.build_task_tokenize_closure(
+        (tokenize_function, _,) = base_model.build_task_tokenize_closure(
             tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0
         )
         mapped_stream = train_stream.map(tokenize_function)
diff --git a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
index 6c138726..27eaa769 100644
--- a/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
+++ b/caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
@@ -19,13 +19,13 @@
 from typing import List, Union
 
 # Third Party
-import torch
 from transformers import (
     AutoModelForCausalLM,
     BatchEncoding,
     DataCollatorForLanguageModeling,
 )
 from transformers.models.auto import modeling_auto
+import torch
 
 # First Party
 from caikit.core.data_model import DataStream
@@ -37,7 +37,6 @@
 from ...data_model import GenerationTrainRecord, PromptOutputModelType
 from ...toolkit.verbalizer_utils import render_verbalizer
 from .base import PretrainedModelBase
-from .hf_auto_seq2seq_lm import HFAutoSeq2SeqLM
 
 log = alog.use_channel("HFRCLM")
 error = error_handler.get(log)
@@ -139,7 +138,6 @@ def tokenize_function(
             drop_remainder=drop_remainder,
         )
 
-
     def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
         """Function to return appropriate data collator based on resource.
 
@@ -169,11 +167,21 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
             tokenizer=self._tokenizer, return_tensors="pt", **collator_kwargs
         )
 
-
     ### Tokenization strategy implementations
     # Chunked causal language modeling
     @classmethod
-    def _causal_lm_as_chunked(cls, tokenizer, source, target, max_source_length, max_target_length, batched_mode, task_ids, chunk_size, drop_remainder):
+    def _causal_lm_as_chunked(
+        cls,
+        tokenizer,
+        source,
+        target,
+        max_source_length,
+        max_target_length,
+        batched_mode,
+        task_ids,
+        chunk_size,
+        drop_remainder,
+    ):
         source_ids = tokenizer(source, max_length=max_source_length, truncation=True)
         target_ids = tokenizer(target, max_length=max_target_length, truncation=True)
 
@@ -197,7 +205,6 @@ def generator_func():
         # onto using batch encodings the way that they are intended to be
         return chunk_stream
 
-
     @staticmethod
     def _force_to_batch_encoding_list_of_chunks(
         source_ids: BatchEncoding,
         target_ids: BatchEncoding,
@@ -260,13 +267,11 @@ def _force_to_batch_encoding_list_of_chunks(
             encodings += chunks
         return encodings
 
-
     @staticmethod
     def _concatenate_encodings(left, right):
         for k in left.keys():
             left[k] = left[k] + right[k]
 
-
     @staticmethod
     def _split_encoding_into_chunks(
         encoding: dict, chunk_size: int, drop_remainder: bool = False, task_ids=None
@@ -304,7 +309,6 @@ def _split_encoding_into_chunks(
                 enc["task_ids"] = task_ids
         return chunked_encodings
 
-
     @staticmethod
     def _collapse_stream_into_encoding(
         stream: DataStream[BatchEncoding],
@@ -335,7 +339,6 @@ def _collapse_stream_into_encoding(
                 new_encoding[k].append(enc[k])
         return new_encoding
 
-
     # Causal language modeling as a sequence to sequence problem
     @staticmethod
     def _causal_lm_padding_as_seq2seq(
@@ -350,7 +353,7 @@ def _causal_lm_padding_as_seq2seq(
         what seq2seq tokenization is doing, but some care needs be taken to
         ensure the labels are the same length as the input sequence because of
         the shifting mechanism implemented in most causal language models.
-        
+
         Collator compatability is extremely important here; because we are setting
         the labels directly, we should NOT use the causal lm collator, otherwise
         it will clobber it with a shifted input sequence.
@@ -385,9 +388,7 @@ def _causal_lm_padding_as_seq2seq(
         label_input_ids = labels["input_ids"]
 
         model_inputs = tokenizer.pad(
-            model_inputs,
-            padding="max_length",
-            max_length=max_concat_length
+            model_inputs, padding="max_length", max_length=max_concat_length
         )
 
         if tokenizer.padding_side.lower() == "left":
diff --git a/tests/resources/test_pretrained_model.py b/tests/resources/test_pretrained_model.py
index c9251068..7ebf7868 100644
--- a/tests/resources/test_pretrained_model.py
+++ b/tests/resources/test_pretrained_model.py
@@ -10,9 +10,9 @@
 
 # Third Party
 from datasets import IterableDataset as TransformersIterableDataset
+from torch.utils.data import DataLoader
 import pytest
 import torch
-from torch.utils.data import DataLoader
 import transformers
 
 # First Party
@@ -229,6 +229,7 @@ def get(train_stream):
     for k in indiv_res:
         assert indiv_res[k] == batched_res[k]
 
+
 ### 2. Tests for causal LM framed as a seq2seq problem
 # NOTE: For these tests, we should be careful to always test left and right padding
 @pytest.mark.parametrize(
@@ -245,7 +246,9 @@ def test_causal_lm_as_a_sequence_problem_no_truncation(models_cache_dir, padding
     max_lengths = 20
     # First, build the output we expect for left / right respectively...
     input_tok = causal_lm.tokenizer.encode(sample.input)
-    output_tok = causal_lm.tokenizer.encode(sample.output) + [causal_lm.tokenizer.eos_token_id]
+    output_tok = causal_lm.tokenizer.encode(sample.output) + [
+        causal_lm.tokenizer.eos_token_id
+    ]
     concat_res = input_tok + output_tok
     masked_res = ([-100] * len(input_tok)) + output_tok
 
@@ -254,11 +257,15 @@
     assert len(output_tok) < (max_lengths + 1)
     pads_needed = (1 + 2 * max_lengths) - len(concat_res)
     if causal_lm.tokenizer.padding_side.lower() == "left":
-        expected_input_ids = torch.tensor([causal_lm.tokenizer.pad_token_id] * pads_needed + concat_res)
+        expected_input_ids = torch.tensor(
+            [causal_lm.tokenizer.pad_token_id] * pads_needed + concat_res
+        )
         expected_attn_mask = torch.tensor([0] * pads_needed + [1] * len(concat_res))
         expected_labels = torch.tensor([-100] * pads_needed + masked_res)
     else:
-        expected_input_ids = torch.tensor(concat_res + [causal_lm.tokenizer.pad_token_id] * pads_needed)
+        expected_input_ids = torch.tensor(
+            concat_res + [causal_lm.tokenizer.pad_token_id] * pads_needed
+        )
         expected_attn_mask = torch.tensor([1] * len(concat_res) + [0] * pads_needed)
         expected_labels = torch.tensor(masked_res + [-100] * pads_needed)
 
@@ -327,6 +334,7 @@ def test_seq2seq_tok_output_correctness(models_cache_dir):
     assert hasattr(tok_sample, "task_ids")
     assert tok_sample["task_ids"] == 0
 
+
 ### Tests for collator compatability
 # These tests should validate that we can use our tokenization function to
 # build torch loaders around datasets using different collators.
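
Note (illustrative, not part of the patch): the seq2seq-style causal LM padding exercised by the tests above can be sketched standalone as below. This assumes a GPT-2-style tokenizer with pad_token set to eos_token; the sample strings and the max_lengths value are made-up examples, and the snippet only mirrors the expected-value construction from test_causal_lm_as_a_sequence_problem_no_truncation rather than calling the caikit_nlp tokenization code itself.

    # Illustrative sketch only; mirrors the expectations built in the test above.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed stand-in model
    tokenizer.pad_token = tokenizer.eos_token

    sample_input, sample_output = "@foo what a cute dog!", "no complaint"  # made-up sample
    max_lengths = 20

    # Source and (target + EOS) are concatenated; the source portion of the labels
    # is masked with -100 so that only target tokens contribute to the loss.
    input_tok = tokenizer.encode(sample_input)
    output_tok = tokenizer.encode(sample_output) + [tokenizer.eos_token_id]
    concat_res = input_tok + output_tok
    masked_res = [-100] * len(input_tok) + output_tok

    # Everything is padded to a fixed length of 2 * max_lengths + 1
    # (max source + max target + EOS), on the tokenizer's padding side.
    pads_needed = (1 + 2 * max_lengths) - len(concat_res)
    if tokenizer.padding_side.lower() == "left":
        input_ids = [tokenizer.pad_token_id] * pads_needed + concat_res
        attention_mask = [0] * pads_needed + [1] * len(concat_res)
        labels = [-100] * pads_needed + masked_res
    else:
        input_ids = concat_res + [tokenizer.pad_token_id] * pads_needed
        attention_mask = [1] * len(concat_res) + [0] * pads_needed
        labels = masked_res + [-100] * pads_needed

    # Because labels are set explicitly here, a shifting causal LM collator must
    # not be applied downstream, or it would overwrite them.
    assert len(input_ids) == len(attention_mask) == len(labels) == 2 * max_lengths + 1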