Merge pull request caikit#206 from alex-jw-brooks/causal_lm_tok_toggle
Causal LM tokenization: Chunking and seq2seq Forward
alex-jw-brooks authored Oct 2, 2023
2 parents f10f415 + 52f9910 commit 24de8fb
Showing 4 changed files with 531 additions and 167 deletions.
caikit_nlp/modules/text_generation/peft_prompt_tuning.py (5 additions, 17 deletions)
@@ -33,11 +33,7 @@
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import (
-    AutoModelForCausalLM,
-    DataCollatorForLanguageModeling,
-    default_data_collator,
-)
+from transformers import AutoModelForCausalLM, default_data_collator
 from transformers.models.auto.tokenization_auto import AutoTokenizer
 from transformers.optimization import get_linear_schedule_with_warmup
 import numpy as np
@@ -901,12 +897,8 @@ def _get_collate_fn(tokenizer: AutoTokenizer, task_type: str) -> Callable:
         Callable
             collate_fn to be used for processing batches from our datasets.
         """
-        if task_type == "CAUSAL_LM":
-            return DataCollatorForLanguageModeling(
-                tokenizer=tokenizer,
-                return_tensors="pt",
-                mlm=False,
-            )
+        # HACK: Do NOT use the causal LM collator (for now) because
+        # want to set labels ourselves. TODO: centralize collator management.
         return default_data_collator
 
     @staticmethod
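
The removed branch relied on DataCollatorForLanguageModeling to clone input_ids into labels; default_data_collator only tensorizes the keys already present in each feature, so every tokenized example now has to carry its own labels. A minimal standalone sketch of that difference (illustration only, not code from this PR):

from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    default_data_collator,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

enc = tokenizer(["hello world", "hi"], padding=True)
features = [{k: enc[k][i] for k in enc} for i in range(2)]

# The causal LM collator clones input_ids into labels, masking pad
# positions with -100.
lm_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False, return_tensors="pt"
)
print("labels" in lm_collator(features))  # True

# default_data_collator adds no labels key; we set labels ourselves,
# masking padded positions by hand.
print("labels" in default_data_collator(features))  # False
for f in features:
    f["labels"] = [
        tok if mask else -100
        for tok, mask in zip(f["input_ids"], f["attention_mask"])
    ]
print("labels" in default_data_collator(features))  # True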
@@ -947,15 +939,11 @@ def _get_data_loaders_from_stream(
         torch.utils.data.DataLoader
             DataLoader to be used for training / evaluating the stream data.
         """
-        (
-            tokenize_function,
-            requires_unwrapping,
-        ) = base_model.build_task_tokenize_closure(
+        (tokenize_function, _,) = base_model.build_task_tokenize_closure(
             tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0
         )
         mapped_stream = train_stream.map(tokenize_function)
-        if requires_unwrapping:
-            mapped_stream = mapped_stream.flatten()
+        # TODO: Deprecate and remove stream wrapper & use trainer
         wrapped_stream = SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle)
         dataloader = DataLoader(
             wrapped_stream, collate_fn=collate_fn, batch_size=batch_size
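For context on the shape of this pipeline: a tokenize closure is mapped over the training stream, and the wrapped stream is handed to a torch DataLoader. A rough sketch with a plain IterableDataset standing in for caikit's SimpleIterableStreamWrapper (StreamDataset and the toy closure below are hypothetical, not the caikit_nlp implementation):

from torch.utils.data import DataLoader, IterableDataset

class StreamDataset(IterableDataset):
    """Hypothetical stand-in for SimpleIterableStreamWrapper."""

    def __init__(self, records, tokenize_fn):
        self.records = records
        self.tokenize_fn = tokenize_fn

    def __iter__(self):
        # Re-apply the tokenize closure lazily on every pass, as a stream would.
        return (self.tokenize_fn(r) for r in self.records)

records = [
    {"input": "hello", "output": "world"},
    {"input": "foo", "output": "bar"},
]
# Toy "tokenize" closure; the real closure binds a tokenizer, max lengths,
# a verbalizer, and task ids.
tokenize_fn = lambda r: {"ids": [len(r["input"]), len(r["output"])]}

loader = DataLoader(StreamDataset(records, tokenize_fn), batch_size=2)
for batch in loader:
    print(batch["ids"])  # collated by torch's default collate_fn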
(next changed file; filename not captured)
@@ -593,7 +593,7 @@ def _preprocess_function(
         mapped_dataset = dataset.map(
             base_model.tokenize_function,
             fn_kwargs=fn_kwargs,
-            batched=base_model.REQUIRES_TOKEN_UNWRAPPING,
+            batched=False,
             # Drop the input / output columns; we need to do this for dimensions to play
             # happily when operating on batched inputs for causal language modeling.
             remove_columns=["input", "output"],
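With the flag fixed at batched=False, tokenize_function receives one record at a time, and remove_columns drops the raw text fields from the mapped result. A small illustration of those datasets.map semantics (toy mapper, not the real tokenize_function):

from datasets import Dataset

ds = Dataset.from_dict({"input": ["hello", "foo"], "output": ["world", "bar"]})

# batched=False: the function sees a single record (a plain dict) per call.
mapped = ds.map(
    lambda ex: {"n_chars": len(ex["input"]) + len(ex["output"])},
    batched=False,
    remove_columns=["input", "output"],
)
print(mapped.column_names)  # ['n_chars'] -- the raw text columns are gone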