Skip to content

Commit

Permalink
merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed Oct 11, 2023
2 parents 5a82297 + 69873d5 commit 3126d2f
Show file tree
Hide file tree
Showing 115 changed files with 1,277 additions and 295 deletions.
4 changes: 3 additions & 1 deletion docs/source/en/main_classes/callback.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Callbacks are "read only" pieces of code, apart from the [`TrainerControl`] obje
cannot change anything in the training loop. For customizations that require changes in the training loop, you should
subclass [`Trainer`] and override the methods you need (see [trainer](trainer) for examples).

By default a [`Trainer`] will use the following callbacks:
By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] will use the following callbacks.

- [`DefaultFlowCallback`] which handles the default behavior for logging, saving and evaluation.
- [`PrinterCallback`] or [`ProgressCallback`] to display progress and print the
Expand All @@ -45,6 +45,8 @@ By default a [`Trainer`] will use the following callbacks:
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.

If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).

The main class that implements callbacks is [`TrainerCallback`]. It gets the
[`TrainingArguments`] used to instantiate the [`Trainer`], can access that
Trainer's internal state via [`TrainerState`], and can take some actions on the training loop via
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="t

def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
Expand Down
2 changes: 1 addition & 1 deletion examples/flax/language-modeling/run_clm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):

def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
Expand Down
2 changes: 1 addition & 1 deletion examples/flax/question-answering/run_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def cross_entropy_loss(logits, labels):
# region Create learning rate function
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):

def create_learning_rate_fn(
num_train_steps: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
Expand Down
2 changes: 1 addition & 1 deletion examples/flax/summarization/run_summarization_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):

def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
Expand Down
2 changes: 1 addition & 1 deletion examples/flax/text-classification/run_flax_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def cross_entropy_loss(logits, labels):

def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
Expand Down
2 changes: 1 addition & 1 deletion examples/flax/token-classification/run_flax_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def cross_entropy_loss(logits, labels):

def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
Expand Down
2 changes: 1 addition & 1 deletion examples/flax/vision/run_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):

def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
Expand Down
2 changes: 1 addition & 1 deletion examples/pytorch/summarization/run_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class DataTrainingArguments:
},
)
source_prefix: Optional[str] = field(
default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)

forced_bos_token: Optional[str] = field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):

def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):

def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
Expand Down
18 changes: 18 additions & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,20 @@ class GenerationConfig(PushToHubMixin):
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
> Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)
num_assistant_tokens (`int`, *optional*, defaults to 5):
Defines the number of _speculative tokens_ that shall be generated by the assistant model before being
checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation
more _speculative_ : If the assistant model is performant larger speed-ups can be reached, if the assistant
model requires lots of corrections, lower speed-ups are reached.
num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`):
Defines the schedule at which max assistant tokens shall be changed during inference.
- `"_heuristic_`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2 else
reduce by 1
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
> Wild card
generation_kwargs:
Expand Down Expand Up @@ -294,6 +308,10 @@ def __init__(self, **kwargs):
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

# Assistant generation
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})

Expand Down
130 changes: 74 additions & 56 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,10 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
model_args |= {f"decoder_{x}" for x in decoder_model_args}

# allow assistant_encoder_outputs to be passed if we're doing assisted generating
if "assistant_encoder_outputs" in model_kwargs:
model_args |= {"assistant_encoder_outputs"}

for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)
Expand Down Expand Up @@ -1297,6 +1301,43 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
UserWarning,
)

def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]:
if self.config.is_encoder_decoder:
key = "decoder_attention_mask"
else:
key = "attention_mask"

if key not in model_kwargs:
return model_kwargs

mask = model_kwargs[key]
mask_extension_length = new_mask_length - mask.shape[1]

if mask_extension_length < 0:
raise ValueError("Cannot extend attention mask to a length less than it already is")

model_kwargs[key] = torch.cat(
[mask, mask.new_ones((mask.shape[0], mask_extension_length))],
dim=-1,
)

return model_kwargs

def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
return model_kwargs

token_type_ids = model_kwargs["token_type_ids"]
final_token_type = token_type_ids[:, -1].unsqueeze(-1)
extension_length = new_length - token_type_ids.shape[1]
token_type_copies = final_token_type.repeat(1, extension_length)
model_kwargs["token_type_ids"] = torch.cat(
[model_kwargs["token_type_ids"], token_type_copies],
dim=-1,
)

return model_kwargs

@torch.no_grad()
def generate(
self,
Expand Down Expand Up @@ -1575,7 +1616,7 @@ def generate(
raise ValueError("assisted generate requires `use_cache=True`")

# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
if assistant_model.config.is_encoder_decoder:
if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs:
assistant_model_kwargs = copy.deepcopy(model_kwargs)
inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
Expand Down Expand Up @@ -4310,8 +4351,14 @@ def assisted_decoding(
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# Assistant: initialize assistant-related variables
if not hasattr(assistant_model, "max_assistant_tokens"):
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
if hasattr(assistant_model, "num_assistant_tokens"):
warnings.warn(
"Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be removed in v.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
FutureWarning,
)
num_assistant_tokens = assistant_model.num_assistant_tokens
else:
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
Expand Down Expand Up @@ -4384,26 +4431,23 @@ def assisted_decoding(
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
# need access to the assistant cache to secure strong speedups.
candidate_input_ids = input_ids
for _ in range(int(assistant_model.max_assistant_tokens)):
for _ in range(int(num_assistant_tokens)):
# 1.1. use the assistant model to obtain the next candidate logits
if "assistant_past_key_values" in model_kwargs:
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
assist_inputs = candidate_input_ids[:, -new_token_len:]
assist_attn = torch.ones_like(candidate_input_ids)
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=assist_inputs,
decoder_attention_mask=assist_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
assistant_model_outputs = assistant_model(
assist_inputs,
attention_mask=assist_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
)
else:
Expand Down Expand Up @@ -4441,61 +4485,35 @@ def assisted_decoding(
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
# we use this forward pass to also pick the subsequent logits in the original model.

# 2.1. Run a forward pass on the candidate sequence
if "past_key_values" in model_kwargs:
model_attn = torch.ones_like(candidate_input_ids)
model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
if self.config.is_encoder_decoder:
outputs = self(
decoder_input_ids=model_input_ids,
decoder_attention_mask=model_attn,
past_key_values=model_kwargs["past_key_values"],
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
outputs = self(
model_input_ids,
attention_mask=model_attn,
past_key_values=model_kwargs["past_key_values"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
if self.config.is_encoder_decoder:
outputs = self(
decoder_input_ids=candidate_input_ids,
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
else:
outputs = self(
candidate_input_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
# 2.1. Prepare the model inputs
candidate_kwargs = copy.copy(model_kwargs)
candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])

model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)

# 2.2. Process the new logits
# 2.2. Run a forward pass on the candidate sequence
outputs = self(
**model_inputs,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

# 2.3. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
if len(logits_processor) > 0:
for i in range(candidate_length):
for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
if len(logits_warper) > 0:
for i in range(candidate_length):
for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

# 3. Obtain the next tokens from the original model logits.
if do_sample:
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
probs = new_logits.softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
selected_tokens = new_logits.argmax(dim=-1)

# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
Expand Down Expand Up @@ -4529,13 +4547,13 @@ def assisted_decoding(
# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
# cost of forecasting incorrect assistant tokens.
if n_matches == int(assistant_model.max_assistant_tokens):
assistant_model.max_assistant_tokens += 2.0
else:
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
if n_matches == int(num_assistant_tokens):
num_assistant_tokens += 2.0
else:
num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)

# Assistant: main logic end

if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need

Expand Down
Loading

0 comments on commit 3126d2f

Please sign in to comment.