Added uncertainty estimation scripts for ByT5 architecture models #63

Open · wants to merge 1 commit into base: main
101 changes: 101 additions & 0 deletions src/greynirseq/utils/infer_file_byt5.py
@@ -0,0 +1,101 @@
"""
"""
import copy
import itertools
import time
from pathlib import Path

import click
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

MAX_LENGTH = 512


def batch_by_n(iterable, batch_size):
# collect into batches of n items, yielding after each batch
_iter = iter(iterable)
while True:
# the list call is necessary since islice is a lazily evaluated generator
batch = list(itertools.islice(_iter, batch_size))
if not batch:
break
yield batch
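
# Quick illustration of the helper above (a sketch, not part of the script):
#   list(batch_by_n(range(5), 2)) == [[0, 1], [2, 3], [4]]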


# fmt: off
@click.command()
@click.option("--checkpoint-path", type=click.Path(exists=True, path_type=Path), required=True)
@click.option("--input-path", type=click.Path(exists=True, path_type=Path), required=True)
@click.option("--output-path", type=click.Path(exists=False, path_type=Path), required=True)
@click.option("--use-cpu", is_flag=True, default=False)
@click.option("--num-beams", type=int, default=5)
@click.option("--dropout", type=float, default=0.0)
@click.option("--batch-size", type=int, default=64)
def main(checkpoint_path, input_path, output_path, use_cpu, num_beams, dropout, batch_size):
# fmt: on
use_gpu = not use_cpu
if use_gpu:
assert torch.cuda.is_available()

tokenizer = AutoTokenizer.from_pretrained("google/byt5-base")
# Note: inference-time dropout (handled below) may require the model not to be in eval mode,
# in which case the computation graph may be cached for gradients.

model = T5ForConditionalGeneration.from_pretrained(str(checkpoint_path))
if dropout > 0.0:
# Inference-time dropout.
# To make sure the correct dropout value is propagated properly,
# we do a "round trip" by reconstructing it with a modified config.
config = copy.deepcopy(model.config)
config.dropout_rate = dropout
model = T5ForConditionalGeneration.from_pretrained(str(checkpoint_path), config=config)
model = model.train()
else:
model = model.eval()
model = model.to("cuda").half() if use_gpu else model
Collaborator comment:
Is half() okay here? Do we train with FP16? I think we do, but I know the original models use bfloat16.
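If bfloat16 turns out to be the safer match for the original checkpoints, a minimal alternative could look like this (a sketch, not part of this PR, assuming the target GPU reports bfloat16 support):

if use_gpu:
    model = model.to("cuda")
    # prefer bfloat16 when available, to match the original ByT5 checkpoints
    model = model.to(torch.bfloat16) if torch.cuda.is_bf16_supported() else model.half()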


def prepare_model_inputs(data_lines):
model_inputs = tokenizer(
data_lines,
truncation=True,
padding=True,
max_length=MAX_LENGTH,
return_tensors="pt",
)
if use_gpu:
model_inputs.to("cuda")
return model_inputs

with open(input_path) as f:
# strip is important
test_lines = [line.strip() for line in f]

# parameters for model.generate
default_generate_kwargs = {
"max_length": MAX_LENGTH,
"num_beams": num_beams,
"output_scores": True,
"return_dict_in_generate": False,
}

start_time = time.time()
# make sure output dir exists
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as out_f:
for batch_lines in batch_by_n(test_lines, batch_size=batch_size):
model_inputs = prepare_model_inputs(batch_lines)

model_output = model.generate(**model_inputs, **default_generate_kwargs)

decoded_output = tokenizer.batch_decode(
model_output, skip_special_tokens=True
)
for line in decoded_output:
out_f.write(line + "\n")

end_time = time.time()
print(f"Finished {output_path} in {end_time - start_time:.2f} seconds")


if __name__ == "__main__":
main()
252 changes: 252 additions & 0 deletions src/greynirseq/utils/score_byt5_uncertainty.py
@@ -0,0 +1,252 @@
"""
Script that performs Monte-Carlo dropout uncertainty estimation with a ByT5 model loaded from a checkpoint file.

The output is a JSONL file with the following (line-wise) format:
{
"mean": [float, float, ...],
"std": [float, float, ...],
"source": "string",
"output": "string",
}
"""
import itertools
import json
from pathlib import Path

import click
import torch
import tqdm
from transformers import AutoTokenizer, T5ForConditionalGeneration

MAX_LENGTH = 512


def batch_by_n(iterable, batch_size):
Collaborator comment:
This function is also defined in the other file. Import it instead?
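A minimal version of that, as a sketch (assuming the module path matches the file location in this PR):

from greynirseq.utils.infer_file_byt5 import batch_by_n

Alternatively, both helpers could move to a small shared utilities module.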

# collect into batches of n items, yielding after each batch
_iter = iter(iterable)
while True:
# the list call is necessary since islice is a lazily evaluated generator
batch = list(itertools.islice(_iter, batch_size))
if not batch:
break
yield batch


def set_t5_dropout(model: torch.nn.Module, dropout: float):
"""
Adjust the internal state of dropout modules in a T5 huggingface model.

Storing the computation graph for backpropagation (i.e. full "training mode") more
than quadruples the CUDA memory needed for a given batch!
The torch dropout modules have no "do-inference-at-test-time" hyperparameter,
but we can produce that behavior by changing their internal state ourselves, i.e.
only partially activating the "training mode" of the dropout modules.
This does not mean that the memory usage is the same as in training, since generation:
- is guaranteed to discard all of the computation graph
- discards the decoder hidden states (it only caches the k and v values for the attention)
- prunes short sequences during decoding when the batch is unbalanced in terms of padding,
  so they don't cause the same memory overhead as the longest sequences do during training.
"""
assert 0 <= dropout <= 1
# flattened recursive traversal
for mod in model.modules():
if not isinstance(mod, torch.nn.Dropout):
continue
mod.p = dropout
if dropout > 0:
# Save some memory
mod.inplace = True
# Calling mod.train() would cause the computation graph to be stored, which we don't want.
# Setting the flag directly may not be guaranteed by the module "interface" (in Java terminology),
# but at least in the current version of torch (tested on CUDA) it makes dropout active
# without storing the computation graph needed for backpropagation.
mod.training = True
assert mod.training
else:
# mod.inplace = False
mod.eval()
assert not mod.training
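
# A quick sanity check of the helper above, as a sketch (not part of the script;
# "model" stands for any loaded T5ForConditionalGeneration):
#
#   set_t5_dropout(model, dropout=0.1)
#   dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
#   assert all(m.p == 0.1 and m.training for m in dropout_modules)
#   set_t5_dropout(model, dropout=0.0)
#   assert all(not m.training for m in dropout_modules)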


# fmt: off
@click.command()
@click.option( "--checkpoint-path", type=click.Path(exists=True, path_type=Path), required=True)
@click.option( "--input-path", type=click.Path(exists=True, path_type=Path), required=True)
@click.option( "--output-path", type=click.Path(exists=False, path_type=Path), required=True)
@click.option("--use-cpu", is_flag=True, default=False)
@click.option("--dropout", type=float, default=0.1)
@click.option("--seed", type=int, default=1)
@click.option("--num-iter", type=int, default=5)
@click.option("--batch-size", type=int, default=64)
def main(
checkpoint_path,
input_path,
output_path,
use_cpu,
dropout,
seed,
num_iter,
batch_size,
):
# fmt: on
# Batch size notes (assuming max_seqlen=512 tokens on a 40GB A100):
# a batch size of 768 works for generation, but 14-16 is the max for scoring;
# however, using torch.no_grad allows scoring with a batch size of
# at least 500 (but less than 768).
# Using such a high batch size is extremely suboptimal unless
# the sequence lengths in a batch are homogeneous. We can make
# the batches more uniform by sorting the input by length prior to batching.
# Timings for 1000 samples total, with sorting:
#   bsz  32: 83.6 sec
#   bsz  64: 69.3 sec
#   bsz 128: 67.4 sec
#   bsz 256: 79.4 sec

print(f"Writing to {output_path} with dropout={dropout} and num_iter={num_iter}")
use_gpu = not use_cpu
if use_gpu:
assert torch.cuda.is_available()

tokenizer = AutoTokenizer.from_pretrained("google/byt5-base")

print("Loading model...")
# We are using the byte-level version, ByT5 (which is implemented using the same class)
model = T5ForConditionalGeneration.from_pretrained(str(checkpoint_path))
model = model.to("cuda").half() if use_gpu else model

def prepare_model_inputs(data_lines):
Collaborator comment:
This is also defined in the other file, along with a lot of this code. Could it be shared?
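One possible shape for the shared code, as a sketch (byt5_utils.py is a hypothetical new module; each script would pass in its own tokenizer and device):

# src/greynirseq/utils/byt5_utils.py (hypothetical)
def prepare_model_inputs(tokenizer, data_lines, device, max_length=512):
    model_inputs = tokenizer(
        data_lines,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return model_inputs.to(device)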

model_inputs = tokenizer(
data_lines,
truncation=True,
padding=True,
max_length=MAX_LENGTH,
return_tensors="pt",
)
if use_gpu:
model_inputs = model_inputs.to("cuda")
return model_inputs

with open(input_path) as f:
# strip is important
src_lines = [line.strip() for line in f]

# NOTE: this sorts the whole file; it would be premature optimization to do this
# in a cleverer manner, such as incremental partial sorting where we separately sort
# the first 10k lines, then the next 10k, etc.
seqlens = torch.tensor([len(line.encode("utf-8")) for line in src_lines]).long()
sorted_indices = seqlens.argsort(descending=True)
src_lines = [src_lines[i] for i in sorted_indices.tolist()]

# parameters for model.generate
default_generate_kwargs = {
"max_length": MAX_LENGTH,
"num_beams": 1,
"output_scores": True,
"return_dict_in_generate": False,
}

print("Generating...")
model = model.eval()
permuted_means = []
permuted_stds = []
decoded_outputs = []
with torch.no_grad():
for batch_idx, batch_lines in enumerate(
batch_by_n(tqdm.tqdm(src_lines), batch_size=batch_size)
):
# this call is very cheap
set_t5_dropout(model, dropout=0.0)
model_inputs = prepare_model_inputs(batch_lines)

# generate hypothesis in eval mode
model_outputs = model.generate(**model_inputs, **default_generate_kwargs)
decoded_output = tokenizer.batch_decode(
model_outputs, skip_special_tokens=True
)
decoded_outputs.extend(decoded_output)

set_t5_dropout(model, dropout=dropout)

# score the hypothesis with inference-time dropout
iter_scores = []
src_ids = model_inputs["input_ids"]
# we need to pass the attention mask to the model (encoder attention mask)
encoder_attention_mask = model_inputs["attention_mask"]
tgt_ids = prepare_model_inputs(decoded_output)["input_ids"]
cpu = torch.device("cpu")
Collaborator comment:
I hope this isn't an expensive call.

tgt_ids_cpu = tgt_ids.to(cpu)
# after we get scores we move the rest of the work to the cpu
decoder_target_mask = tgt_ids_cpu.eq(tokenizer.pad_token_id).logical_not()
# we need to know how many tokens are in each sequence (to filter out padding tokens)
lens = decoder_target_mask.sum(dim=1)
# note: the trailing batch may be smaller than batch_size, so per-batch sizes come from the tensors above
for iteration_idx in range(num_iter):
# for reproducibility
batch_iter_seed = hash((seed, batch_idx, iteration_idx))
_rng_gen = torch.manual_seed(batch_iter_seed)

# get one monte-carlo iteration of scores with dropout
output = model(
input_ids=src_ids,
labels=tgt_ids,
attention_mask=encoder_attention_mask,
use_cache=True,
)
unnormalized_score = output.logits.detach()
del output
# normalize the scores to get a (log) probability distribution
# and move to cpu (so the reference counting can free it sooner)
score = (
unnormalized_score.log_softmax(-1).detach().clone().float().to(cpu)
)

# Select out the scores of the tokens in the sequence (we get scores for the whole vocab inventory)
# We have:
# - array of indices I_ij
# - array of floats S_ijk
# We want output matrix O_ij after using I as an index into S as follows:
# O_ij = S_{i}{j}{I_ij}
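# Tiny example for intuition (an added illustration): with score of shape [1, 2, 4] and
# tgt_ids_cpu == [[1, 3]], gather(dim=2, index=tgt_ids_cpu.unsqueeze(-1)).squeeze(-1)
# picks score[0, 0, 1] and score[0, 1, 3], yielding a [1, 2] matrix of per-token log-probs.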
scores_without_padding = score.gather(
dim=2, index=tgt_ids_cpu.unsqueeze(-1)
).squeeze(-1)
# filter out padding tokens
scores_without_padding_tuples = scores_without_padding[
decoder_target_mask
].split(lens.tolist())
iter_scores.append(scores_without_padding_tuples)

# collect, reduce and store the scores
actual_batch_size, _ = tgt_ids_cpu.shape
for seq_index in range(actual_batch_size):
seq_scores = [
iter_scores[iter_index][seq_index] for iter_index in range(num_iter)
]
seq_scores = torch.stack(seq_scores, dim=0)
seq_means = seq_scores.mean(dim=0)
seq_stds = seq_scores.std(dim=0)
permuted_means.append(seq_means)
permuted_stds.append(seq_stds)

# recover the original order by inverting the length-sorted indices
inverse_indices = sorted_indices.argsort().tolist()
Collaborator comment:
Are you absolutely sure?
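For the record, a tiny check of the trick (argsort of a permutation gives its inverse):

perm = torch.tensor([3, 1, 2]).argsort(descending=True)  # tensor([0, 2, 1])
inv = perm.argsort()                                     # tensor([0, 2, 1])
assert torch.equal(perm[inv], torch.arange(3))           # sorted order maps back to the original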

permuted_means = [permuted_means[i] for i in inverse_indices]
permuted_stds = [permuted_stds[i] for i in inverse_indices]
decoded_outputs = [decoded_outputs[i] for i in inverse_indices]
# src_lines was sorted by length above, so it must be restored to the original order as well
src_lines = [src_lines[i] for i in inverse_indices]

# write results to disk
with open(output_path, "w") as f:
for mean, std, src_line, decoded_output in zip(
permuted_means, permuted_stds, src_lines, decoded_outputs
):
obj = {
"mean": mean.tolist(),
"std": std.tolist(),
"source": src_line,
"output": decoded_output,
}
f.write(json.dumps(obj, ensure_ascii=False))
f.write("\n")


if __name__ == "__main__":
main()