diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index 15121018..96ffde63 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -35,7 +35,6 @@
 from tqdm import tqdm
 from transformers import (
     AutoModelForCausalLM,
-    DataCollatorForLanguageModeling,
     default_data_collator,
 )
 from transformers.models.auto.tokenization_auto import AutoTokenizer
@@ -890,13 +889,8 @@ def _get_collate_fn(tokenizer: AutoTokenizer, task_type: str) -> Callable:
             Callable
                 collate_fn to be used for processing batches from our datasets.
         """
-        # HACK: Do NOT use the causal LM collator (for now) because we want to set the labels ourselves...
-        # if task_type == "CAUSAL_LM":
-        #     return DataCollatorForLanguageModeling(
-        #         tokenizer=tokenizer,
-        #         return_tensors="pt",
-        #         mlm=False,
-        #     )
+        # HACK: Do NOT use the causal LM collator (for now) because we
+        # want to set the labels ourselves. TODO: centralize collator management.
         return default_data_collator
 
     @staticmethod
@@ -944,8 +938,7 @@ def _get_data_loaders_from_stream(
             tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0
         )
         mapped_stream = train_stream.map(tokenize_function)
-        # if requires_unwrapping:
-        #     mapped_stream = mapped_stream.flatten()
+        # TODO: Deprecate and remove stream wrapper & use trainer
         wrapped_stream = SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle)
         dataloader = DataLoader(
             wrapped_stream, collate_fn=collate_fn, batch_size=batch_size
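
Context for the collator hack above, as a rough sketch rather than code from this repo: `DataCollatorForLanguageModeling(mlm=False)` builds `labels` by cloning `input_ids` (masking only pad tokens to -100), so it would overwrite the source-masked labels that the tokenize closure sets per example, whereas `default_data_collator` simply tensorizes and stacks whatever keys each feature already carries. The GPT-2 tokenizer and token ids below are illustrative only:

```python
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    default_data_collator,
)

# Illustrative tokenizer; any causal LM tokenizer with a pad token works.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# A feature whose labels were already set upstream: source tokens are
# masked with -100 so only the target contributes to the loss.
features = [
    {"input_ids": [10, 11, 12], "attention_mask": [1, 1, 1], "labels": [-100, -100, 12]}
]

# default_data_collator just tensorizes and stacks, preserving our labels.
batch = default_data_collator(features)
print(batch["labels"])  # tensor([[-100, -100,   12]])

# The causal LM collator instead clones input_ids into labels, which
# would clobber the -100 source mask set above.
lm_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False, return_tensors="pt"
)
lm_batch = lm_collator([{"input_ids": [10, 11, 12], "attention_mask": [1, 1, 1]}])
print(lm_batch["labels"])  # tensor([[10, 11, 12]])
```

This is why the diff keeps `default_data_collator` for both task types until collator management is centralized per the TODO.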