
Commit

linting and formatting
Signed-off-by: Alex-Brooks <[email protected]>
alex-jw-brooks committed Sep 29, 2023
1 parent 579db79 commit 560d5bc
Showing 3 changed files with 29 additions and 26 deletions.
10 changes: 2 additions & 8 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -33,10 +33,7 @@
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
default_data_collator,
)
from transformers import AutoModelForCausalLM, default_data_collator
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.optimization import get_linear_schedule_with_warmup
import numpy as np
@@ -931,10 +928,7 @@ def _get_data_loaders_from_stream(
torch.utils.data.DataLoader
DataLoader to be used for training / evaluating the stream data.
"""
(
tokenize_function,
_,
) = base_model.build_task_tokenize_closure(
(tokenize_function, _,) = base_model.build_task_tokenize_closure(
tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0
)
mapped_stream = train_stream.map(tokenize_function)
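Note: the reformatted call above builds a tokenize closure and maps it over the training stream before it is wrapped in a torch DataLoader. A minimal sketch of that flow, assuming the positional arguments shown in the hunk, an iterable DataStream, and illustrative length limits and batch size:

```python
from torch.utils.data import DataLoader
from transformers import default_data_collator

def get_train_dataloader(base_model, tokenizer, train_stream, verbalizer, batch_size=8):
    # The closure captures the tokenizer, length limits, and verbalizer so the
    # stream map below stays a plain one-argument function
    tokenize_function, _ = base_model.build_task_tokenize_closure(
        tokenizer, 256, 128, verbalizer, task_ids=0
    )
    mapped_stream = train_stream.map(tokenize_function)
    # Assumes the mapped DataStream is iterable and yields dict-like encodings
    return DataLoader(
        mapped_stream, collate_fn=default_data_collator, batch_size=batch_size
    )
```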
29 changes: 15 additions & 14 deletions caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
@@ -19,13 +19,13 @@
from typing import List, Union

# Third Party
import torch
from transformers import (
AutoModelForCausalLM,
BatchEncoding,
DataCollatorForLanguageModeling,
)
from transformers.models.auto import modeling_auto
import torch

# First Party
from caikit.core.data_model import DataStream
@@ -37,7 +37,6 @@
from ...data_model import GenerationTrainRecord, PromptOutputModelType
from ...toolkit.verbalizer_utils import render_verbalizer
from .base import PretrainedModelBase
from .hf_auto_seq2seq_lm import HFAutoSeq2SeqLM

log = alog.use_channel("HFRCLM")
error = error_handler.get(log)
@@ -139,7 +138,6 @@ def tokenize_function(
drop_remainder=drop_remainder,
)


def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
"""Function to return appropriate data collator based on resource.
@@ -169,11 +167,21 @@ def _get_data_collator(self, **kwargs) -> "transformers.DataCollator":
tokenizer=self._tokenizer, return_tensors="pt", **collator_kwargs
)
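The collator construction above shows only its tail; a rough sketch of what a causal-LM `_get_data_collator` along these lines might do, assuming `mlm=False` as the default and a simple whitelist of supported kwargs (both are assumptions, not taken from the diff):

```python
from transformers import DataCollatorForLanguageModeling

def get_causal_lm_collator(tokenizer, **kwargs):
    # Keep only kwargs the collator understands; assume masked LM is disabled
    # for causal language modeling unless the caller overrides it
    applicable_args = {"mlm", "pad_to_multiple_of"}
    collator_kwargs = {k: v for k, v in kwargs.items() if k in applicable_args}
    collator_kwargs.setdefault("mlm", False)
    return DataCollatorForLanguageModeling(
        tokenizer=tokenizer, return_tensors="pt", **collator_kwargs
    )
```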


### Tokenization strategy implementations
# Chunked causal language modeling
@classmethod
def _causal_lm_as_chunked(cls, tokenizer, source, target, max_source_length, max_target_length, batched_mode, task_ids, chunk_size, drop_remainder):
def _causal_lm_as_chunked(
cls,
tokenizer,
source,
target,
max_source_length,
max_target_length,
batched_mode,
task_ids,
chunk_size,
drop_remainder,
):
source_ids = tokenizer(source, max_length=max_source_length, truncation=True)
target_ids = tokenizer(target, max_length=max_target_length, truncation=True)
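The hunk above only reformats the long `_causal_lm_as_chunked` signature. The idea behind it, sketched loosely here (the helper name and the plain generator are illustrative, not the module's actual implementation), is to tokenize source and target, concatenate them, and emit fixed-size chunks:

```python
def causal_lm_chunks(tokenizer, source, target, max_source_length,
                     max_target_length, chunk_size, drop_remainder=False):
    # Tokenize source and target separately, then treat the concatenation as
    # one long causal-LM sequence to be sliced into chunk_size pieces
    source_ids = tokenizer(source, max_length=max_source_length, truncation=True)
    target_ids = tokenizer(target, max_length=max_target_length, truncation=True)
    input_ids = source_ids["input_ids"] + target_ids["input_ids"]
    attention_mask = source_ids["attention_mask"] + target_ids["attention_mask"]
    for start in range(0, len(input_ids), chunk_size):
        chunk = {
            "input_ids": input_ids[start : start + chunk_size],
            "attention_mask": attention_mask[start : start + chunk_size],
        }
        if drop_remainder and len(chunk["input_ids"]) < chunk_size:
            continue
        yield chunk
```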

@@ -197,7 +205,6 @@ def generator_func():
# onto using batch encodings the way that they are intended to be
return chunk_stream


@staticmethod
def _force_to_batch_encoding_list_of_chunks(
source_ids: BatchEncoding,
@@ -260,13 +267,11 @@ def _force_to_batch_encoding_list_of_chunks(
encodings += chunks
return encodings


@staticmethod
def _concatenate_encodings(left, right):
for k in left.keys():
left[k] = left[k] + right[k]


@staticmethod
def _split_encoding_into_chunks(
encoding: dict, chunk_size: int, drop_remainder: bool = False, task_ids=None
@@ -304,7 +309,6 @@ def _split_encoding_into_chunks(
enc["task_ids"] = task_ids
return chunked_encodings
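A small sketch of the per-key slicing that `_split_encoding_into_chunks` appears to be doing around the lines shown here; the body is an approximation under that assumption, not a verbatim copy of the module:

```python
def split_encoding_into_chunks(encoding, chunk_size, drop_remainder=False, task_ids=None):
    num_tokens = len(encoding["input_ids"])
    chunked_encodings = []
    for start in range(0, num_tokens, chunk_size):
        end = start + chunk_size
        if drop_remainder and end > num_tokens:
            break
        # Slice every key (input_ids, attention_mask, ...) the same way so the
        # chunk stays internally consistent
        enc = {key: val[start:end] for key, val in encoding.items()}
        if task_ids is not None:
            enc["task_ids"] = task_ids
        chunked_encodings.append(enc)
    return chunked_encodings
```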


@staticmethod
def _collapse_stream_into_encoding(
stream: DataStream[BatchEncoding],
@@ -335,7 +339,6 @@ def _collapse_stream_into_encoding(
new_encoding[k].append(enc[k])
return new_encoding
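And the inverse direction, collapsing a stream of chunk encodings back into one mapping of lists, might look roughly like this (the key set is assumed):

```python
def collapse_stream_into_encoding(stream, keys=("input_ids", "attention_mask")):
    # Accumulate each chunk's values per key so downstream code sees a single
    # batch-encoding-like dict of lists
    new_encoding = {key: [] for key in keys}
    for enc in stream:
        for key in keys:
            new_encoding[key].append(enc[key])
    return new_encoding
```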


# Causal language modeling as a sequence to sequence problem
@staticmethod
def _causal_lm_padding_as_seq2seq(
@@ -350,7 +353,7 @@ def _causal_lm_padding_as_seq2seq(
what seq2seq tokenization is doing, but some care needs to be taken to ensure the labels
are the same length as the input sequence because of the shifting mechanism implemented
in most causal language models.
Collator compatibility is extremely important here; because we are setting the labels
directly, we should NOT use the causal lm collator, otherwise it will clobber it with a
shifted input sequence.
@@ -385,9 +388,7 @@ def _causal_lm_padding_as_seq2seq(

label_input_ids = labels["input_ids"]
model_inputs = tokenizer.pad(
model_inputs,
padding="max_length",
max_length=max_concat_length
model_inputs, padding="max_length", max_length=max_concat_length
)

if tokenizer.padding_side.lower() == "left":
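To make the docstring's point concrete, here is a hand-rolled sketch of the seq2seq-style framing, assuming no truncation and a pad/eos token on the tokenizer (lengths and names are illustrative): source tokens are masked out with -100 in the labels, and input_ids, attention_mask, and labels are all padded to the same fixed length on the tokenizer's padding side, so a default collator can batch them without re-shifting the labels.

```python
import torch

def causal_lm_padding_as_seq2seq_sketch(tokenizer, source, target,
                                        max_source_length=20, max_target_length=20):
    input_ids = tokenizer.encode(source)
    label_ids = tokenizer.encode(target) + [tokenizer.eos_token_id]
    concat_ids = input_ids + label_ids
    # Labels mirror the input length: source positions are ignored via -100
    labels = [-100] * len(input_ids) + label_ids
    max_concat_length = max_source_length + max_target_length + 1
    pads_needed = max_concat_length - len(concat_ids)
    if tokenizer.padding_side.lower() == "left":
        attention_mask = [0] * pads_needed + [1] * len(concat_ids)
        labels = [-100] * pads_needed + labels
        concat_ids = [tokenizer.pad_token_id] * pads_needed + concat_ids
    else:
        attention_mask = [1] * len(concat_ids) + [0] * pads_needed
        labels = labels + [-100] * pads_needed
        concat_ids = concat_ids + [tokenizer.pad_token_id] * pads_needed
    return {
        "input_ids": torch.tensor(concat_ids),
        "attention_mask": torch.tensor(attention_mask),
        "labels": torch.tensor(labels),
    }
```

This mirrors the expectations built in the tests below, where the padded length works out to 1 + 2 * max_lengths.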
16 changes: 12 additions & 4 deletions tests/resources/test_pretrained_model.py
@@ -10,9 +10,9 @@

# Third Party
from datasets import IterableDataset as TransformersIterableDataset
from torch.utils.data import DataLoader
import pytest
import torch
from torch.utils.data import DataLoader
import transformers

# First Party
@@ -229,6 +229,7 @@ def get(train_stream):
for k in indiv_res:
assert indiv_res[k] == batched_res[k]


### 2. Tests for causal LM framed as a seq2seq problem
# NOTE: For these tests, we should be careful to always test left and right padding
@pytest.mark.parametrize(
@@ -245,7 +246,9 @@ def test_causal_lm_as_a_sequence_problem_no_truncation(models_cache_dir, padding
max_lengths = 20
# First, build the output we expect for left / right respectively...
input_tok = causal_lm.tokenizer.encode(sample.input)
output_tok = causal_lm.tokenizer.encode(sample.output) + [causal_lm.tokenizer.eos_token_id]
output_tok = causal_lm.tokenizer.encode(sample.output) + [
causal_lm.tokenizer.eos_token_id
]
concat_res = input_tok + output_tok
masked_res = ([-100] * len(input_tok)) + output_tok

@@ -254,11 +257,15 @@
assert len(output_tok) < (max_lengths + 1)
pads_needed = (1 + 2 * max_lengths) - len(concat_res)
if causal_lm.tokenizer.padding_side.lower() == "left":
expected_input_ids = torch.tensor([causal_lm.tokenizer.pad_token_id] * pads_needed + concat_res)
expected_input_ids = torch.tensor(
[causal_lm.tokenizer.pad_token_id] * pads_needed + concat_res
)
expected_attn_mask = torch.tensor([0] * pads_needed + [1] * len(concat_res))
expected_labels = torch.tensor([-100] * pads_needed + masked_res)
else:
expected_input_ids = torch.tensor(concat_res + [causal_lm.tokenizer.pad_token_id] * pads_needed)
expected_input_ids = torch.tensor(
concat_res + [causal_lm.tokenizer.pad_token_id] * pads_needed
)
expected_attn_mask = torch.tensor([1] * len(concat_res) + [0] * pads_needed)
expected_labels = torch.tensor(masked_res + [-100] * pads_needed)

@@ -327,6 +334,7 @@ def test_seq2seq_tok_output_correctness(models_cache_dir):
assert hasattr(tok_sample, "task_ids")
assert tok_sample["task_ids"] == 0


### Tests for collator compatibility
# These tests should validate that we can use our tokenization function to
# build torch loaders around datasets using different collators.
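A hypothetical shape for such a compatibility check, assuming the tokenizer has a pad token and that per-example encodings carry only input_ids and attention_mask (the actual fixtures and collators used by the tests are not shown in this hunk):

```python
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling

def check_collator_compatibility(tokenizer, texts, batch_size=2):
    # Tokenize a handful of strings, then make sure a torch DataLoader can
    # batch them through a language-modeling collator
    samples = [tokenizer(text, truncation=True, max_length=32) for text in texts]
    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False, return_tensors="pt"
    )
    loader = DataLoader(samples, collate_fn=collator, batch_size=batch_size)
    for batch in loader:
        assert {"input_ids", "attention_mask", "labels"} <= set(batch.keys())
```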
