From e4c7a24f94503a8282fd4b1605c5c9c72514d973 Mon Sep 17 00:00:00 2001 From: paulh Date: Fri, 8 Mar 2024 11:31:36 +0000 Subject: [PATCH 1/2] updates --- src/transformers/generation/streamers.py | 22 +- src/transformers/generation/utils.py | 1567 +++++++++++++++++----- 2 files changed, 1239 insertions(+), 350 deletions(-) diff --git a/src/transformers/generation/streamers.py b/src/transformers/generation/streamers.py index c75b43466af7a8..7f03fad1306b10 100644 --- a/src/transformers/generation/streamers.py +++ b/src/transformers/generation/streamers.py @@ -34,6 +34,10 @@ def end(self): """Function that is called by `.generate()` to signal the end of generation""" raise NotImplementedError() + def is_running(self) -> bool: + """Function that is called by `.generate()` to check if the streamer has ended""" + raise NotImplementedError() + class TextStreamer(BaseStreamer): """ @@ -69,7 +73,9 @@ class TextStreamer(BaseStreamer): ``` """ - def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): + def __init__( + self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs + ): self.tokenizer = tokenizer self.skip_prompt = skip_prompt self.decode_kwargs = decode_kwargs @@ -203,12 +209,17 @@ class TextIteratorStreamer(TextStreamer): """ def __init__( - self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs + self, + tokenizer: "AutoTokenizer", + skip_prompt: bool = False, + timeout: Optional[float] = None, + **decode_kwargs ): super().__init__(tokenizer, skip_prompt, **decode_kwargs) self.text_queue = Queue() self.stop_signal = None self.timeout = timeout + self.stopped = False def on_finalized_text(self, text: str, stream_end: bool = False): """Put the new text in the queue. 
If the stream is ending, also put a stop signal in the queue.""" @@ -216,6 +227,13 @@ def on_finalized_text(self, text: str, stream_end: bool = False): if stream_end: self.text_queue.put(self.stop_signal, timeout=self.timeout) + def end(self): + self.stopped = True + self.on_finalized_text("", stream_end=True) + + def is_running(self) -> bool: + return not self.stopped + def __iter__(self): return self diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1d7eef755bf984..74ddd5de529664 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -34,7 +34,12 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, ) -from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging +from ..utils import ( + ModelOutput, + is_accelerate_available, + is_torchdynamo_compiling, + logging, +) from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .candidate_generator import ( @@ -313,15 +318,21 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput -GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] +GreedySearchOutput = Union[ + GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput +] SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] -ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput] +ContrastiveSearchOutput = Union[ + ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput +] # Typing shortcuts GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput] -GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput] +GenerateBeamOutput = Union[ + GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput +] GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput] @@ -376,7 +387,9 @@ def _prepare_model_inputs( else: input_name = self.main_input_name - model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} + model_kwargs = { + k: v for k, v in model_kwargs.items() if v is not None or k != input_name + } # 2. check whether model_input_name is passed as kwarg # if yes and `inputs` is None use kwarg inputs @@ -398,7 +411,9 @@ def _prepare_model_inputs( if input_name == "input_ids" and "inputs_embeds" in model_kwargs: if not self.config.is_encoder_decoder: has_inputs_embeds_forwarding = "inputs_embeds" in set( - inspect.signature(self.prepare_inputs_for_generation).parameters.keys() + inspect.signature( + self.prepare_inputs_for_generation + ).parameters.keys() ) if not has_inputs_embeds_forwarding: raise ValueError( @@ -408,16 +423,22 @@ def _prepare_model_inputs( ) # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of # the attention mask) can rely on the actual model input. 
- model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( - inputs, bos_token_id, model_kwargs=model_kwargs + model_kwargs["input_ids"] = ( + self._maybe_initialize_input_ids_for_generation( + inputs, bos_token_id, model_kwargs=model_kwargs + ) ) else: if inputs is not None: - raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") + raise ValueError( + "You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one." + ) inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" # 4. if `inputs` is still None, try to create `input_ids` from BOS token - inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) + inputs = self._maybe_initialize_input_ids_for_generation( + inputs, bos_token_id, model_kwargs + ) return inputs, input_name, model_kwargs def _maybe_initialize_input_ids_for_generation( @@ -437,7 +458,9 @@ def _maybe_initialize_input_ids_for_generation( return torch.ones(shape, dtype=torch.long, device=self.device) * -100 if bos_token_id is None: - raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + raise ValueError( + "`bos_token_id` has to be defined when no `input_ids` are provided." + ) # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with # soft-prompting or in multimodal implementations built on top of decoder-only language models. @@ -449,7 +472,10 @@ def _maybe_initialize_input_ids_for_generation( if "inputs_embeds" in model_kwargs: return torch.ones((batch_size, 0), dtype=torch.long, device=self.device) - return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + return ( + torch.ones((batch_size, 1), dtype=torch.long, device=self.device) + * bos_token_id + ) def _prepare_attention_mask_for_generation( self, @@ -457,20 +483,32 @@ def _prepare_attention_mask_for_generation( pad_token_id: Optional[int], eos_token_id: Optional[Union[int, List[int]]], ) -> torch.LongTensor: - is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [ + torch.int, + torch.long, + ] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( + pad_token_id not in eos_token_id + ) # Check if input is input_ids and padded -> only then is attention_mask defined - if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: + if ( + is_input_ids + and is_pad_token_in_inputs + and is_pad_token_not_equal_to_eos_token_id + ): return inputs.ne(pad_token_id).long() else: return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) def _prepare_encoder_decoder_kwargs_for_generation( - self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None + self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str] = None, ) -> Dict[str, Any]: # 1. 
get encoder encoder = self.get_encoder() @@ -490,14 +528,20 @@ def _prepare_encoder_decoder_kwargs_for_generation( if not any(argument.startswith(p) for p in irrelevant_prefix) } encoder_signature = set(inspect.signature(encoder.forward).parameters) - encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature + encoder_accepts_wildcard = ( + "kwargs" in encoder_signature or "model_kwargs" in encoder_signature + ) if not encoder_accepts_wildcard: encoder_kwargs = { - argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature + argument: value + for argument, value in encoder_kwargs.items() + if argument in encoder_signature } # 3. make sure that encoder returns `ModelOutput` - model_input_name = model_input_name if model_input_name is not None else self.main_input_name + model_input_name = ( + model_input_name if model_input_name is not None else self.main_input_name + ) encoder_kwargs["return_dict"] = True encoder_kwargs[model_input_name] = inputs_tensor model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) @@ -524,7 +568,9 @@ def _prepare_decoder_input_ids_for_generation( decoder_input_ids = None # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + decoder_start_token_id = self._get_decoder_start_token_id( + decoder_start_token_id, bos_token_id + ) if device is None: device = self.device if isinstance(decoder_start_token_id, list): @@ -532,18 +578,24 @@ def _prepare_decoder_input_ids_for_generation( raise ValueError( f"`decoder_start_token_id` expcted to have length {batch_size} but got {len(decoder_start_token_id)}" ) - decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device) + decoder_input_ids_start = torch.tensor( + decoder_start_token_id, dtype=torch.long, device=device + ) decoder_input_ids_start = decoder_input_ids_start.view(-1, 1) else: decoder_input_ids_start = ( - torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + torch.ones((batch_size, 1), dtype=torch.long, device=device) + * decoder_start_token_id ) # no user input -> use decoder_start_token_id as decoder_input_ids if decoder_input_ids is None: decoder_input_ids = decoder_input_ids_start # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token - elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): + elif ( + self.config.model_type == "vision-encoder-decoder" + and "donut" in self.name_or_path.lower() + ): pass elif self.config.model_type in ["whisper"]: pass @@ -556,11 +608,16 @@ def _prepare_decoder_input_ids_for_generation( isinstance(decoder_start_token_id, torch.Tensor) and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item() ): - decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + decoder_input_ids = torch.cat( + [decoder_input_ids_start, decoder_input_ids], dim=-1 + ) if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] decoder_attention_mask = torch.cat( - (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + ( + torch.ones_like(decoder_attention_mask)[:, :1], + decoder_attention_mask, + ), dim=-1, ) model_kwargs["decoder_attention_mask"] = decoder_attention_mask @@ -568,14 +625,20 @@ def 
_prepare_decoder_input_ids_for_generation( return decoder_input_ids, model_kwargs def _get_decoder_start_token_id( - self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None + self, + decoder_start_token_id: Union[int, List[int]] = None, + bos_token_id: int = None, ) -> int: decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None else self.generation_config.decoder_start_token_id ) - bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id + bos_token_id = ( + bos_token_id + if bos_token_id is not None + else self.generation_config.bos_token_id + ) if decoder_start_token_id is not None: return decoder_start_token_id @@ -596,8 +659,12 @@ def _expand_inputs_for_generation( def _expand_dict_for_generation(dict_to_expand): for key in dict_to_expand: - if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor): - dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + if dict_to_expand[key] is not None and isinstance( + dict_to_expand[key], torch.Tensor + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave( + expand_size, dim=0 + ) return dict_to_expand if input_ids is not None: @@ -607,12 +674,18 @@ def _expand_dict_for_generation(dict_to_expand): if is_encoder_decoder: if model_kwargs.get("encoder_outputs") is None: - raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") - model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + raise ValueError( + "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." + ) + model_kwargs["encoder_outputs"] = _expand_dict_for_generation( + model_kwargs["encoder_outputs"] + ) return input_ids, model_kwargs - def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False): + def _extract_past_from_model_output( + self, outputs: ModelOutput, standardize_cache_format: bool = False + ): past_key_values = None if "past_key_values" in outputs: past_key_values = outputs.past_key_values @@ -624,7 +697,9 @@ def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cach # Bloom fix: standardizes the cache format when requested if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"): batch_size = outputs.logits.shape[0] - past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size) + past_key_values = self._convert_to_standard_cache( + past_key_values, batch_size=batch_size + ) return past_key_values def _update_model_kwargs_for_generation( @@ -645,21 +720,32 @@ def _update_model_kwargs_for_generation( # update token_type_ids with last value if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + model_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 + ) if not is_encoder_decoder: # update attention mask if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + [ + attention_mask, + attention_mask.new_ones((attention_mask.shape[0], 1)), + ], + dim=-1, ) else: # update decoder attention mask if "decoder_attention_mask" in model_kwargs: decoder_attention_mask 
= model_kwargs["decoder_attention_mask"] model_kwargs["decoder_attention_mask"] = torch.cat( - [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], + [ + decoder_attention_mask, + decoder_attention_mask.new_ones( + (decoder_attention_mask.shape[0], 1) + ), + ], dim=-1, ) @@ -725,23 +811,52 @@ def _get_logits_warper( # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` - if generation_config.temperature is not None and generation_config.temperature != 1.0: + if ( + generation_config.temperature is not None + and generation_config.temperature != 1.0 + ): warpers.append(TemperatureLogitsWarper(generation_config.temperature)) if generation_config.top_k is not None and generation_config.top_k != 0: - warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) + warpers.append( + TopKLogitsWarper( + top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep + ) + ) if generation_config.top_p is not None and generation_config.top_p < 1.0: - warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) - if generation_config.typical_p is not None and generation_config.typical_p < 1.0: warpers.append( - TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) + TopPLogitsWarper( + top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep + ) + ) + if ( + generation_config.typical_p is not None + and generation_config.typical_p < 1.0 + ): + warpers.append( + TypicalLogitsWarper( + mass=generation_config.typical_p, + min_tokens_to_keep=min_tokens_to_keep, + ) ) - if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: + if ( + generation_config.epsilon_cutoff is not None + and 0.0 < generation_config.epsilon_cutoff < 1.0 + ): warpers.append( - EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep) + EpsilonLogitsWarper( + epsilon=generation_config.epsilon_cutoff, + min_tokens_to_keep=min_tokens_to_keep, + ) ) - if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: + if ( + generation_config.eta_cutoff is not None + and 0.0 < generation_config.eta_cutoff < 1.0 + ): warpers.append( - EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep) + EtaLogitsWarper( + epsilon=generation_config.eta_cutoff, + min_tokens_to_keep=min_tokens_to_keep, + ) ) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: @@ -766,7 +881,10 @@ def _get_logits_processor( # instantiate processors list processors = LogitsProcessorList() - if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1: + if ( + generation_config.guidance_scale is not None + and generation_config.guidance_scale != 1 + ): processors.append( UnbatchedClassifierFreeGuidanceLogitsProcessor( generation_config.guidance_scale, @@ -777,9 +895,16 @@ def _get_logits_processor( ) ) if generation_config.sequence_bias is not None: - processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias)) + processors.append( + SequenceBiasLogitsProcessor( + sequence_bias=generation_config.sequence_bias + ) + ) - if generation_config.diversity_penalty is not None and 
generation_config.diversity_penalty > 0.0: + if ( + generation_config.diversity_penalty is not None + and generation_config.diversity_penalty > 0.0 + ): processors.append( HammingDiversityLogitsProcessor( diversity_penalty=generation_config.diversity_penalty, @@ -793,30 +918,51 @@ def _get_logits_processor( ): processors.append( EncoderRepetitionPenaltyLogitsProcessor( - penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids + penalty=generation_config.encoder_repetition_penalty, + encoder_input_ids=encoder_input_ids, ) ) - if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: - processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) - if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: - processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) + if ( + generation_config.repetition_penalty is not None + and generation_config.repetition_penalty != 1.0 + ): + processors.append( + RepetitionPenaltyLogitsProcessor( + penalty=generation_config.repetition_penalty + ) + ) + if ( + generation_config.no_repeat_ngram_size is not None + and generation_config.no_repeat_ngram_size > 0 + ): + processors.append( + NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size) + ) if ( generation_config.encoder_no_repeat_ngram_size is not None and generation_config.encoder_no_repeat_ngram_size > 0 ): processors.append( - EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids) + EncoderNoRepeatNGramLogitsProcessor( + generation_config.encoder_no_repeat_ngram_size, encoder_input_ids + ) ) if generation_config.bad_words_ids is not None: processors.append( - NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) + NoBadWordsLogitsProcessor( + generation_config.bad_words_ids, generation_config.eos_token_id + ) ) if ( generation_config.min_length is not None and generation_config.eos_token_id is not None and generation_config.min_length > 0 ): - processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) + processors.append( + MinLengthLogitsProcessor( + generation_config.min_length, generation_config.eos_token_id + ) + ) if ( generation_config.min_new_tokens is not None and generation_config.eos_token_id is not None @@ -824,20 +970,27 @@ def _get_logits_processor( ): processors.append( MinNewTokensLengthLogitsProcessor( - input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id + input_ids_seq_length, + generation_config.min_new_tokens, + generation_config.eos_token_id, ) ) if prefix_allowed_tokens_fn is not None: processors.append( PrefixConstrainedLogitsProcessor( - prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups + prefix_allowed_tokens_fn, + generation_config.num_beams // generation_config.num_beam_groups, ) ) if generation_config.forced_bos_token_id is not None: - processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)) + processors.append( + ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id) + ) if generation_config.forced_eos_token_id is not None: processors.append( - ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) + ForcedEOSTokenLogitsProcessor( + generation_config.max_length, 
generation_config.forced_eos_token_id + ) ) if generation_config.remove_invalid_values is True: processors.append(InfNanRemoveLogitsProcessor()) @@ -850,22 +1003,31 @@ def _get_logits_processor( ) ) if generation_config.suppress_tokens is not None: - processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens)) + processors.append( + SuppressTokensLogitsProcessor(generation_config.suppress_tokens) + ) if generation_config.begin_suppress_tokens is not None: begin_index = input_ids_seq_length begin_index = ( begin_index - if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) + if ( + input_ids_seq_length > 1 + or generation_config.forced_bos_token_id is None + ) else begin_index + 1 ) if generation_config.forced_decoder_ids is not None: # generation starts after the last token that is forced begin_index += generation_config.forced_decoder_ids[-1][0] processors.append( - SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) + SuppressTokensAtBeginLogitsProcessor( + generation_config.begin_suppress_tokens, begin_index + ) ) if generation_config.forced_decoder_ids is not None: - processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)) + processors.append( + ForceTokensLogitsProcessor(generation_config.forced_decoder_ids) + ) processors = self._merge_criteria_processor_list(processors, logits_processor) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: @@ -873,11 +1035,15 @@ def _get_logits_processor( return processors def _get_stopping_criteria( - self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList] + self, + generation_config: GenerationConfig, + stopping_criteria: Optional[StoppingCriteriaList], ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if generation_config.max_length is not None: - max_position_embeddings = getattr(self.config, "max_position_embeddings", None) + max_position_embeddings = getattr( + self.config, "max_position_embeddings", None + ) criteria.append( MaxLengthCriteria( max_length=generation_config.max_length, @@ -899,7 +1065,11 @@ def _merge_criteria_processor_list( for default in default_list: for custom in custom_list: if type(custom) is type(default): - object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" + object_type = ( + "stopping criteria" + if isinstance(custom, StoppingCriteria) + else "logits processor" + ) raise ValueError( f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" f" `.generate()`, but it has already been created with the values {default}. {default} has been" @@ -995,7 +1165,9 @@ def compute_transition_scores( # 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent # to a beam search approach were the first (and only) beam is always selected if beam_indices is None: - beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device) + beam_indices = ( + torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device) + ) beam_indices = beam_indices.expand(-1, len(scores)) # 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being @@ -1061,7 +1233,10 @@ def _validate_model_class(self): def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): """Validates model kwargs for generation. 
Generate argument typos will also be caught here.""" # If a `Cache` instance is passed, checks whether the model is compatible with it - if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: + if ( + isinstance(model_kwargs.get("past_key_values", None), Cache) + and not self._supports_cache_class + ): raise ValueError( f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " "check the model documentation for supported cache formats." @@ -1073,7 +1248,9 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): model_kwargs.pop(key, None) unused_model_args = [] - model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) + model_args = set( + inspect.signature(self.prepare_inputs_for_generation).parameters + ) # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;) if "kwargs" in model_args or "model_kwargs" in model_args: @@ -1118,11 +1295,17 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): " generate arguments will also show up in this list)" ) - def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): + def _validate_generated_length( + self, generation_config, input_ids_length, has_default_max_length + ): """Performs validation related to the resulting generated length""" # 1. Max length warnings related to poor parameterization - if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: + if ( + has_default_max_length + and generation_config.max_new_tokens is None + and generation_config.max_length == 20 + ): # 20 is the default max_length of the generation config warnings.warn( f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " @@ -1131,7 +1314,9 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de UserWarning, ) if input_ids_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + input_ids_string = ( + "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + ) raise ValueError( f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" @@ -1144,13 +1329,15 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de "increase the maximum length." ) if has_default_max_length: - min_length_error_suffix += ( - f" Note that `max_length` is set to {generation_config.max_length}, its default value." - ) - if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: + min_length_error_suffix += f" Note that `max_length` is set to {generation_config.max_length}, its default value." + if ( + generation_config.min_length is not None + and generation_config.min_length > generation_config.max_length + ): warnings.warn( f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than" - f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, + f" the maximum possible length ({generation_config.max_length})." 
+ + min_length_error_suffix, UserWarning, ) if generation_config.min_new_tokens is not None: @@ -1159,7 +1346,8 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de warnings.warn( f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when " f"added to the prompt length ({input_ids_length}), is larger than" - f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, + f" the maximum possible length ({generation_config.max_length})." + + min_length_error_suffix, UserWarning, ) @@ -1184,7 +1372,8 @@ def _prepare_generation_config( if ( not is_torchdynamo_compiling() and self.generation_config._from_model_config - and self.generation_config._original_object_hash == hash(self.generation_config) + and self.generation_config._original_object_hash + == hash(self.generation_config) and self.config._has_non_default_generation_parameters() ): new_generation_config = GenerationConfig.from_model_config(self.config) @@ -1203,7 +1392,9 @@ def _prepare_generation_config( if is_torchdynamo_compiling(): model_kwargs = kwargs generate_attributes_in_kwargs = [ - key for key, value in kwargs.items() if getattr(generation_config, key, None) != value + key + for key, value in kwargs.items() + if getattr(generation_config, key, None) != value ] if len(generate_attributes_in_kwargs) > 0: raise ValueError( @@ -1223,7 +1414,9 @@ def generate( generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, synced_gpus: Optional[bool] = None, assistant_model: Optional["PreTrainedModel"] = None, streamer: Optional["BaseStreamer"] = None, @@ -1316,7 +1509,9 @@ def generate( """ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() - generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) + generation_config, model_kwargs = self._prepare_generation_config( + generation_config, **kwargs + ) self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined @@ -1325,10 +1520,19 @@ def generate( synced_gpus = True else: synced_gpus = False - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: + if ( + generation_config.pad_token_id is None + and generation_config.eos_token_id is not None + ): if model_kwargs.get("attention_mask", None) is None: logger.warning( "The attention mask and the pad token id were not set. 
As a consequence, you may observe " @@ -1337,7 +1541,9 @@ def generate( eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, list): eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + logger.warning( + f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." + ) generation_config.pad_token_id = eos_token_id # 3. Define model inputs @@ -1360,12 +1566,22 @@ def generate( else: model_kwargs["use_cache"] = generation_config.use_cache - accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + accepts_attention_mask = "attention_mask" in set( + inspect.signature(self.forward).parameters.keys() + ) requires_attention_mask = "encoder_outputs" not in model_kwargs - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + if ( + model_kwargs.get("attention_mask", None) is None + and requires_attention_mask + and accepts_attention_mask + ): + model_kwargs["attention_mask"] = ( + self._prepare_attention_mask_for_generation( + inputs_tensor, + generation_config.pad_token_id, + generation_config.eos_token_id, + ) ) # decoder-only models should use left-padding for generation @@ -1375,7 +1591,8 @@ def generate( if ( generation_config.pad_token_id is not None and len(inputs_tensor.shape) == 2 - and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) + > 0 ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " @@ -1400,14 +1617,21 @@ def generate( device=inputs_tensor.device, ) else: - input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + input_ids = ( + inputs_tensor + if model_input_name == "input_ids" + else model_kwargs.pop("input_ids") + ) if streamer is not None: streamer.put(input_ids.cpu()) # 6. Prepare `max_length` depending on other stopping criteria. input_ids_length = input_ids.shape[-1] - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_max_length = ( + kwargs.get("max_length") is None + and generation_config.max_length is not None + ) if generation_config.max_new_tokens is not None: if not has_default_max_length and generation_config.max_length is not None: logger.warning( @@ -1416,7 +1640,9 @@ def generate( "Please refer to the documentation for more information. 
" "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) - generation_config.max_length = generation_config.max_new_tokens + input_ids_length + generation_config.max_length = ( + generation_config.max_new_tokens + input_ids_length + ) # otherwise the total length [inputs-embeds-len + new-tokens-len] will go beyond indicated `max_length`` elif ( @@ -1425,7 +1651,9 @@ def generate( and not self.config.is_encoder_decoder ): generation_config.max_length -= inputs_tensor.shape[1] - generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0) + generation_config.min_length = max( + generation_config.min_length - inputs_tensor.shape[1], 0 + ) if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: if generation_config.cache_implementation == "static": @@ -1439,9 +1667,15 @@ def generate( "The `generation_config` defines a `cache_implementation` that is not compatible with this model." " Make sure it has a `_setup_cache` function." ) - self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) + self._setup_cache( + cache_cls, + max_batch_size=batch_size, + max_cache_len=generation_config.max_length, + ) - self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + self._validate_generated_length( + generation_config, input_ids_length, has_default_max_length + ) # 7. determine generation mode generation_mode = generation_config.get_generation_mode(assistant_model) @@ -1486,7 +1720,9 @@ def generate( f"but is {generation_config.num_return_sequences}." ) if batch_size > 1: - raise ValueError("assisted generate is only supported for batch_size = 1") + raise ValueError( + "assisted generate is only supported for batch_size = 1" + ) if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") @@ -1506,7 +1742,11 @@ def generate( candidate_generator=candidate_generator, do_sample=generation_config.do_sample, logits_processor=prepared_logits_processor, - logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None, + logits_warper=( + self._get_logits_warper(generation_config) + if generation_config.do_sample + else None + ), stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, @@ -1712,10 +1952,15 @@ def typeerror(): if isinstance(word_ids[0], list): if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() - if any(not isinstance(token_ids, list) for token_ids in word_ids): + if any( + not isinstance(token_ids, list) for token_ids in word_ids + ): typeerror() if any( - any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) + any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in token_ids + ) for token_ids in word_ids ): typeerror() @@ -1724,7 +1969,10 @@ def typeerror(): else: if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() - if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): + if any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in word_ids + ): typeerror() constraint = PhrasalConstraint(word_ids) @@ -1888,22 +2136,56 @@ def _contrastive_search( ['DeepMind Company is a company that focuses on the development and commercialization of artificial intelligence (AI). 
DeepMind’s mission is to help people understand and solve problems that are difficult to solve in the world today.\n\nIn this post, we talk about the benefits of deep learning in business and how it'] ```""" # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - sequential = sequential if sequential is not None else self.generation_config.low_memory + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + logits_warper = ( + logits_warper if logits_warper is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) + sequential = ( + sequential if sequential is not None else self.generation_config.low_memory + ) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + eos_token_id_tensor = ( + torch.tensor(eos_token_id).to(input_ids.device) + if eos_token_id is not None + else None + ) + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_logits = ( + output_logits + if output_logits is not None + else self.generation_config.output_logits + ) output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -1914,19 +2196,33 @@ def _contrastive_search( # init attention / hidden states / scores tuples raw_logits = () if (return_dict_in_generate and output_logits) else None scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) # if model is an encoder-decoder, retrieve encoder attention weights
and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None ) # keep track of which sequences are already finished - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + unfinished_sequences = torch.ones( + input_ids.shape[0], dtype=torch.long, device=input_ids.device + ) this_peer_finished = False # used by synced_gpus only batch_size = input_ids.shape[0] @@ -1935,7 +2231,9 @@ def _contrastive_search( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -1947,12 +2245,17 @@ def _contrastive_search( if model_kwargs.get("past_key_values") is None: # prepare inputs model_kwargs["use_cache"] = True - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs + ) # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save # the `encoder_outputs` outputs = self( - **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions + **model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, ) # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with @@ -1975,7 +2278,9 @@ def _contrastive_search( if not sequential: # Expands model inputs top_k times, for batched forward passes (akin to beam search). 
_, model_kwargs = self._expand_inputs_for_generation( - expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + expand_size=top_k, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, ) past_key_values = model_kwargs.get("past_key_values") @@ -1996,8 +2301,12 @@ def _contrastive_search( # contrastive_search main logic start: # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by # degeneration penalty - processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step) - processed_logit_for_next_step = logits_warper(input_ids, processed_logit_for_next_step) + processed_logit_for_next_step = logits_processor( + input_ids, logit_for_next_step + ) + processed_logit_for_next_step = logits_warper( + input_ids, processed_logit_for_next_step + ) next_probs = nn.functional.softmax(processed_logit_for_next_step, dim=-1) top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) @@ -2010,7 +2319,9 @@ def _contrastive_search( scores += (processed_logit_for_next_step,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -2039,7 +2350,9 @@ def _contrastive_search( all_outputs = [] for i in range(top_k): # compute the candidate tokens by the language model and collect their hidden_states - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs) + next_model_inputs = self.prepare_inputs_for_generation( + top_k_ids[:, i].view(-1, 1), **model_kwargs + ) outputs = self( **next_model_inputs, @@ -2053,7 +2366,9 @@ def _contrastive_search( else: # compute the candidate tokens by the language model and collect their hidden_states # assembles top_k_ids into batch of size k - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + next_model_inputs = self.prepare_inputs_for_generation( + top_k_ids.view(-1, 1), **model_kwargs + ) outputs = self( **next_model_inputs, @@ -2076,7 +2391,9 @@ def _contrastive_search( # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't # introduce (noticeable) slowdowns on single-device runs. 
- selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) + selected_idx = _ranking_fast( + context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k + ) selected_idx = selected_idx.to("cpu") # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing @@ -2085,11 +2402,15 @@ def _contrastive_search( next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx] next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) next_hidden = next_hidden[range(batch_size), selected_idx, :] - last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) + last_hidden_states = torch.cat( + [last_hidden_states, next_hidden.unsqueeze(1)], dim=1 + ) next_decoder_hidden_states = () for layer in full_hidden_states: - layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :] + layer = torch.stack(torch.split(layer, top_k))[ + range(batch_size), selected_idx, : + ] next_decoder_hidden_states += (layer,) # generate past_key_values cache of only the selected token @@ -2107,19 +2428,27 @@ def _contrastive_search( next_past_key_values = selected_outputs["past_key_values"] else: - next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) + next_past_key_values = self._extract_past_from_model_output( + outputs, standardize_cache_format=True + ) new_key_values = () for layer in next_past_key_values: items = () # item is either the key or the value matrix for item in layer: - item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz] - item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz] + item = torch.stack( + torch.split(item, top_k, dim=0) + ) # [B, K, num_head, seq_len, esz] + item = item[ + range(batch_size), selected_idx, ... + ] # [B, num_head, seq_len, esz] items += (item,) new_key_values += (items,) next_past_key_values = new_key_values - logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] + logit_for_next_step = torch.stack(torch.split(logits, top_k))[ + range(batch_size), selected_idx, : + ] # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration if self.config.is_encoder_decoder: @@ -2127,10 +2456,14 @@ def _contrastive_search( next_step_decoder_attentions = () if output_attentions: for layer in outputs.cross_attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] + layer = torch.stack(torch.split(layer, top_k, dim=0))[ + range(batch_size), selected_idx, ... + ] next_step_cross_attentions += (layer,) for layer in outputs.decoder_attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] + layer = torch.stack(torch.split(layer, top_k, dim=0))[ + range(batch_size), selected_idx, ... + ] next_step_decoder_attentions += (layer,) outputs = Seq2SeqLMOutput( past_key_values=next_past_key_values, @@ -2142,7 +2475,9 @@ def _contrastive_search( next_step_attentions = () if output_attentions: for layer in outputs.attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] + layer = torch.stack(torch.split(layer, top_k, dim=0))[ + range(batch_size), selected_idx, ... 
+ ] next_step_attentions += (layer,) outputs = CausalLMOutputWithPast( past_key_values=next_past_key_values, @@ -2157,25 +2492,38 @@ def _contrastive_search( # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: streamer.put(next_tokens.cpu()) + if not streamer.is_running(): + break model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + model_inputs=model_inputs, ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + next_tokens.tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)) + .prod(dim=0) ) # stop when each sentence is finished - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + unfinished_sequences = unfinished_sequences & ~stopping_criteria( + input_ids, scores + ) if unfinished_sequences.max() == 0: this_peer_finished = True @@ -2342,26 +2690,54 @@ def _greedy_search( ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + eos_token_id_tensor = ( + torch.tensor(eos_token_id).to(input_ids.device) + if eos_token_id is not 
None + else None + ) + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -2372,26 +2748,42 @@ def _greedy_search( # init attention / hidden states / scores tuples raw_logits = () if (return_dict_in_generate and output_logits) else None scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None ) # keep track of which sequences are already finished - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + unfinished_sequences = torch.ones( + input_ids.shape[0], dtype=torch.long, device=input_ids.device + ) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? 
the reduced sum will be 0.0 then @@ -2425,7 +2817,9 @@ def _greedy_search( raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -2443,13 +2837,20 @@ def _greedy_search( # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: streamer.put(next_tokens.cpu()) + if not streamer.is_running(): + break + model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, @@ -2460,10 +2861,14 @@ def _greedy_search( # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + next_tokens.tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)) + .prod(dim=0) ) - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + unfinished_sequences = unfinished_sequences & ~stopping_criteria( + input_ids, scores + ) # stop when each sentence is finished if unfinished_sequences.max() == 0: @@ -2639,28 +3044,62 @@ def _sample( ['Today is a beautiful day, and we must do everything possible to make it a day of celebration.'] ```""" # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + logits_warper = ( + logits_warper if logits_warper is not None else LogitsProcessorList() + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) if isinstance(eos_token_id, int): eos_token_id = 
[eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + eos_token_id_tensor = ( + torch.tensor(eos_token_id).to(input_ids.device) + if eos_token_id is not None + else None + ) + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_logits = ( + output_logits + if output_logits is not None + else self.generation_config.output_logits + ) output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -2671,19 +3110,33 @@ def _sample( # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None ) # keep track of which sequences are already finished - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + unfinished_sequences = torch.ones( + input_ids.shape[0], dtype=torch.long, device=input_ids.device + ) this_peer_finished = False # used by synced_gpus only # auto-regressive generation @@ -2691,7 +3144,9 @@ def _sample( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? 
the reduced sum will be 0.0 then @@ -2726,7 +3181,9 @@ def _sample( raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -2745,24 +3202,37 @@ def _sample( # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: streamer.put(next_tokens.cpu()) + if not streamer.is_running(): + break model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + model_inputs=model_inputs, ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + next_tokens.tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)) + .prod(dim=0) ) - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + unfinished_sequences = unfinished_sequences & ~stopping_criteria( + input_ids, scores + ) # stop when each sentence is finished if unfinished_sequences.max() == 0: @@ -2964,29 +3434,62 @@ def _beam_search( ['Wie alt bist du?'] ```""" # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - sequential = sequential if sequential is not None else self.generation_config.low_memory + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + sequential = ( + sequential if sequential is not None else self.generation_config.low_memory + ) if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) if len(stopping_criteria) == 0: - warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + warnings.warn( + "You don't have defined any stopping_criteria, this 
will likely loop forever", + UserWarning, + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_logits = ( + output_logits + if output_logits is not None + else self.generation_config.output_logits + ) output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -3008,22 +3511,38 @@ def _beam_search( scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None beam_indices = ( - tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None + tuple(() for _ in range(batch_beam_size)) + if (return_dict_in_generate and output_scores) + else None + ) + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None ) - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None ) # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens # of the first beam are considered to avoid sampling the exact same tokens across all beams. - beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores = torch.zeros( + (batch_size, num_beams), dtype=torch.float, device=input_ids.device + ) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) @@ -3034,7 +3553,9 @@ def _beam_search( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 
# The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -3096,9 +3617,9 @@ def _beam_search( ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( - next_token_scores_processed - ) + next_token_scores = next_token_scores_processed + beam_scores[ + :, None + ].expand_as(next_token_scores_processed) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -3108,7 +3629,9 @@ def _beam_search( raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -3121,12 +3644,18 @@ def _beam_search( # reshape for beam search vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + next_token_scores = next_token_scores.view( + batch_size, num_beams * vocab_size + ) # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. n_eos_tokens = len(eos_token_id) if eos_token_id else 0 next_token_scores, next_tokens = torch.topk( - next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True + next_token_scores, + max(2, 1 + n_eos_tokens) * num_beams, + dim=1, + largest=True, + sorted=True, ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") @@ -3148,10 +3677,15 @@ def _beam_search( beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] - input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + input_ids = torch.cat( + [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1 + ) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + model_inputs=model_inputs, ) if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -3159,7 +3693,12 @@ def _beam_search( ) if return_dict_in_generate and output_scores: - beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) + beam_indices = tuple( + ( + beam_indices[beam_idx[i]] + (beam_idx[i],) + for i in range(len(beam_indices)) + ) + ) # increase cur_len cur_len = cur_len + 1 @@ -3361,26 +3900,54 @@ def _beam_sample( ['Wie alt bist du?'] ```""" # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + 
stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_logits = ( + output_logits + if output_logits is not None + else self.generation_config.output_logits + ) output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -3397,20 +3964,36 @@ def _beam_sample( scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None beam_indices = ( - tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None + tuple(() for _ in range(batch_beam_size)) + if (return_dict_in_generate and output_scores) + else None + ) + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None ) - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None ) - beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, 
device=input_ids.device) + beam_scores = torch.zeros( + (batch_size, num_beams), dtype=torch.float, device=input_ids.device + ) beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only @@ -3420,7 +4003,9 @@ def _beam_sample( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -3447,10 +4032,12 @@ def _beam_sample( ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( - next_token_scores_processed + next_token_scores_processed = logits_warper( + input_ids, next_token_scores_processed ) + next_token_scores = next_token_scores_processed + beam_scores[ + :, None + ].expand_as(next_token_scores_processed) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -3460,7 +4047,9 @@ def _beam_sample( raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -3474,14 +4063,18 @@ def _beam_sample( # reshape for beam search vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + next_token_scores = next_token_scores.view( + batch_size, num_beams * vocab_size + ) probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) next_token_scores = torch.gather(next_token_scores, -1, next_tokens) - next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) + next_token_scores, _indices = torch.sort( + next_token_scores, descending=True, dim=1 + ) next_tokens = torch.gather(next_tokens, -1, _indices) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") @@ -3502,10 +4095,15 @@ def _beam_sample( beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] - input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + input_ids = torch.cat( + [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1 + ) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + model_inputs=model_inputs, ) if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -3513,7 +4111,12 @@ def _beam_sample( ) if return_dict_in_generate and output_scores: - beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in 
range(len(beam_indices)))) + beam_indices = tuple( + ( + beam_indices[beam_idx[i]] + (beam_idx[i],) + for i in range(len(beam_indices)) + ) + ) # increase cur_len cur_len = cur_len + 1 @@ -3708,26 +4311,54 @@ def _group_beam_search( ['Wie alt bist du?'] ```""" # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_logits = ( + output_logits + if output_logits is not None + else self.generation_config.output_logits + ) output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -3744,7 +4375,10 @@ def _group_beam_search( batch_beam_size, cur_len = input_ids.shape if return_dict_in_generate and output_scores: - beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] + beam_indices = [ + tuple(() for _ in range(num_sub_beams * batch_size)) + for _ in range(num_beam_groups) + ] else: beam_indices = None @@ -3756,20 +4390,34 @@ def _group_beam_search( # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and 
output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None ) # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in # the same group don't produce same tokens everytime. - beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + beam_scores = torch.full( + (batch_size, num_beams), -1e9, dtype=torch.float, device=device + ) beam_scores[:, ::num_sub_beams] = 0 beam_scores = beam_scores.view((batch_size * num_beams,)) @@ -3780,7 +4428,9 @@ def _group_beam_search( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -3788,10 +4438,14 @@ def _group_beam_search( break # predicted tokens in cur_len step - current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + current_tokens = torch.zeros( + batch_size * num_beams, dtype=input_ids.dtype, device=device + ) # indices which will form the beams in the next time step - reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + reordering_indices = torch.zeros( + batch_size * num_beams, dtype=torch.long, device=device + ) # do one decoder step on all beams of all sentences in batch model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -3821,7 +4475,10 @@ def _group_beam_search( for batch_idx in range(batch_size): batch_group_indices.extend( - [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + [ + batch_idx * num_beams + idx + for idx in range(group_start_idx, group_end_idx) + ] ) group_input_ids = input_ids[batch_group_indices] @@ -3834,28 +4491,43 @@ def _group_beam_search( vocab_size = next_token_scores.shape[-1] next_token_scores_processed = logits_processor( - group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx + group_input_ids, + next_token_scores, + current_tokens=current_tokens, + beam_group_idx=beam_group_idx, + ) + next_token_scores = next_token_scores_processed + beam_scores[ + batch_group_indices + ].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as( + next_token_scores_processed ) - next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) - next_token_scores = next_token_scores.expand_as(next_token_scores_processed) if output_scores: processed_score[batch_group_indices] = 
next_token_scores_processed # reshape for beam search - next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + next_token_scores = next_token_scores.view( + batch_size, group_size * vocab_size + ) # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. n_eos_tokens = len(eos_token_id) if eos_token_id else 0 next_token_scores, next_tokens = torch.topk( - next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True + next_token_scores, + max(2, 1 + n_eos_tokens) * group_size, + dim=1, + largest=True, + sorted=True, ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size # stateless - process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + process_beam_indices = ( + sum(beam_indices, ()) if beam_indices is not None else None + ) beam_outputs = beam_scorer.process( group_input_ids, next_token_scores, @@ -3873,11 +4545,15 @@ def _group_beam_search( if return_dict_in_generate and output_scores: beam_indices[beam_group_idx] = tuple( - beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) + beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) + for i in range(len(beam_indices[0])) ) input_ids[batch_group_indices] = group_input_ids[beam_idx] - group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + group_input_ids = torch.cat( + [group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], + dim=-1, + ) current_tokens[batch_group_indices] = group_input_ids[:, -1] # (beam_idx // group_size) -> batch_idx @@ -3896,7 +4572,9 @@ def _group_beam_search( raw_logits += (raw_logit_score,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -3911,7 +4589,10 @@ def _group_beam_search( input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + model_inputs=model_inputs, ) if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -4117,28 +4798,59 @@ def _constrained_beam_search( ['Wie alt sind Sie?'] ```""" # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) if len(stopping_criteria) == 0: - warnings.warn("You 
don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + warnings.warn( + "You don't have defined any stopping_criteria, this will likely loop forever", + UserWarning, + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_logits = ( + output_logits + if output_logits is not None + else self.generation_config.output_logits + ) output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -4160,22 +4872,38 @@ def _constrained_beam_search( scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None beam_indices = ( - tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None + tuple(() for _ in range(batch_beam_size)) + if (return_dict_in_generate and output_scores) + else None + ) + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None ) - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None ) # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens # of the first beam are considered to avoid sampling the exact same tokens across all beams. 
- beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores = torch.zeros( + (batch_size, num_beams), dtype=torch.float, device=input_ids.device + ) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) @@ -4186,7 +4914,9 @@ def _constrained_beam_search( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -4213,9 +4943,9 @@ def _constrained_beam_search( next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( - next_token_scores_processed - ) + next_token_scores = next_token_scores_processed + beam_scores[ + :, None + ].expand_as(next_token_scores_processed) scores_for_all_vocab = next_token_scores.clone() @@ -4227,7 +4957,9 @@ def _constrained_beam_search( raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -4241,12 +4973,18 @@ def _constrained_beam_search( # reshape for beam search vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + next_token_scores = next_token_scores.view( + batch_size, num_beams * vocab_size + ) # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
n_eos_tokens = len(eos_token_id) if eos_token_id else 0 next_token_scores, next_tokens = torch.topk( - next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True + next_token_scores, + max(2, 1 + n_eos_tokens) * num_beams, + dim=1, + largest=True, + sorted=True, ) next_indices = (next_tokens / vocab_size).long() @@ -4268,9 +5006,14 @@ def _constrained_beam_search( beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] - input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + input_ids = torch.cat( + [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1 + ) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + model_inputs=model_inputs, ) if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -4278,12 +5021,19 @@ def _constrained_beam_search( ) if return_dict_in_generate and output_scores: - beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) + beam_indices = tuple( + ( + beam_indices[beam_idx[i]] + (beam_idx[i],) + for i in range(len(beam_indices)) + ) + ) # increase cur_len cur_len = cur_len + 1 - if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): + if constrained_beam_scorer.is_done or all( + stopping_criteria(input_ids, scores) + ): if not synced_gpus: break else: @@ -4467,23 +5217,57 @@ def _assisted_decoding( ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + logits_warper = ( + logits_warper if logits_warper is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) if eos_token_id is not None and pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." 
+ ) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + eos_token_id_tensor = ( + torch.tensor(eos_token_id).to(input_ids.device) + if eos_token_id is not None + else None + ) + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_logits = ( + output_logits + if output_logits is not None + else self.generation_config.output_logits + ) output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -4494,15 +5278,27 @@ def _assisted_decoding( # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None ) # keep track of which sequences are already finished @@ -4516,7 +5312,9 @@ def _assisted_decoding( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -4526,7 +5324,9 @@ def _assisted_decoding( cur_len = input_ids.shape[-1] # 1. 
Fetch candidate sequences from a `CandidateGenerator` - candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) + candidate_input_ids, candidate_logits = candidate_generator.get_candidates( + input_ids + ) candidate_input_ids = candidate_input_ids.to(self.device) if candidate_logits is not None: candidate_logits = candidate_logits.to(self.device) @@ -4547,11 +5347,17 @@ def _assisted_decoding( # 2.1. Prepare the model inputs candidate_kwargs = copy.copy(model_kwargs) candidate_kwargs = _prepare_attention_mask( - candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder + candidate_kwargs, + candidate_input_ids.shape[1], + self.config.is_encoder_decoder, + ) + candidate_kwargs = _prepare_token_type_ids( + candidate_kwargs, candidate_input_ids.shape[1] ) - candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) - model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) + model_inputs = self.prepare_inputs_for_generation( + candidate_input_ids, **candidate_kwargs + ) # 2.2. Run a forward pass on the candidate sequence outputs = self( @@ -4561,14 +5367,20 @@ def _assisted_decoding( ) # 2.3. Process the new logits - new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present + new_logits = outputs.logits[ + :, -candidate_length - 1 : + ] # excludes the input prompt if present next_token_logits = new_logits.clone() if len(logits_processor) > 0: for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) + new_logits[:, i, :] = logits_processor( + candidate_input_ids[:, : cur_len + i], new_logits[:, i, :] + ) if len(logits_warper) > 0: for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) + new_logits[:, i, :] = logits_warper( + candidate_input_ids[:, : cur_len + i], new_logits[:, i, :] + ) # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) @@ -4590,12 +5402,17 @@ def _assisted_decoding( else: if do_sample: probs = new_logits.softmax(dim=-1) - selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] + selected_tokens = torch.multinomial( + probs[0, :, :], num_samples=1 + ).squeeze(1)[None, :] else: selected_tokens = new_logits.argmax(dim=-1) candidate_new_tokens = candidate_input_ids[:, cur_len:] - n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() + n_matches = ( + (~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) + < 1 + ).sum() # Ensure we don't generate beyond max_len or an EOS token if last_assistant_token_is_eos and n_matches == candidate_length: @@ -4612,14 +5429,20 @@ def _assisted_decoding( input_ids = torch.cat((input_ids, valid_tokens), dim=-1) if streamer is not None: streamer.put(valid_tokens.cpu()) + if not streamer.is_running(): + break new_cur_len = input_ids.shape[-1] # 4.2. Discard past key values relative to unused assistant tokens new_cache_size = new_cur_len - 1 - outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) + outputs.past_key_values = _crop_past_key_values( + self, outputs.past_key_values, new_cache_size + ) # 5. 
Update the candidate generation strategy if needed - candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) + candidate_generator.update_candidate_strategy( + input_ids, new_logits, n_matches + ) if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need @@ -4640,7 +5463,10 @@ def _assisted_decoding( if output_attentions: if self.config.is_encoder_decoder: cross_attentions = _split_model_outputs( - cross_attentions, outputs.cross_attentions, cur_len, added_len + cross_attentions, + outputs.cross_attentions, + cur_len, + added_len, ) decoder_attentions = _split_model_outputs( decoder_attentions, @@ -4660,15 +5486,24 @@ def _assisted_decoding( if output_hidden_states: if self.config.is_encoder_decoder: decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len + decoder_hidden_states, + outputs.decoder_hidden_states, + cur_len, + added_len, ) else: decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, outputs.hidden_states, cur_len, added_len + decoder_hidden_states, + outputs.hidden_states, + cur_len, + added_len, ) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + model_inputs=model_inputs, ) # if eos_token was found in one sentence, set sentence to finished @@ -4680,7 +5515,9 @@ def _assisted_decoding( .prod(dim=0) ) - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + unfinished_sequences = unfinished_sequences & ~stopping_criteria( + input_ids, scores + ) # stop when each sentence is finished if unfinished_sequences.max() == 0: @@ -4694,7 +5531,8 @@ def _assisted_decoding( if ( hasattr(candidate_generator, "assistant_model") - and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic" + and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule + == "heuristic" ): candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( candidate_generator.num_assistant_tokens @@ -4777,14 +5615,18 @@ def _speculative_sampling( # The selected tokens include the matches (if any) plus the next sampled tokens if n_matches > 0: - valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1) + valid_tokens = torch.cat( + (new_candidate_input_ids[:, :n_matches], t), dim=-1 + ) else: valid_tokens = t return valid_tokens, n_matches -def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): +def _split_model_outputs( + outputs, new_outputs, cur_len, added_len, is_decoder_attention=False +): """ Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple where each member corresponds to a single generated token. 
@@ -4824,11 +5666,17 @@ def _ranking_fast( """ norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) - cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] + cosine_matrix = torch.matmul( + norm_context_hidden, norm_next_hidden.transpose(1, 2) + ).squeeze( + -1 + ) # [B*K, S] degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K] next_top_k_probs = next_top_k_probs.view(-1) # [B*K] contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty - contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] + contrastive_score = torch.stack( + torch.split(contrastive_score, beam_width) + ) # [B, K] _, selected_idx = contrastive_score.max(dim=-1) # [B] return selected_idx @@ -4851,7 +5699,10 @@ def _split(data, full_batch_size: int, split_size: int = None): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0], tuple): return [ - tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data) + tuple( + tuple(tensor[i : i + split_size] for tensor in inner_tuple) + for inner_tuple in data + ) for i in range(0, full_batch_size, split_size) ] @@ -4888,29 +5739,43 @@ def _split_model_inputs( # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them keys = ( - model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys() + model_input.__dataclass_fields__.keys() + if hasattr(model_input, "__dataclass_fields__") + else model_input.keys() ) # We only keep keys that are in the model_input keys = [k for k in keys if k in model_input] # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a # ModelOutput object. 
# bool should not be split but replicated for each split - bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] + bool_keys = [ + k for k in keys if isinstance(model_input[k], bool) or k == "cache_position" + ] keys_to_ignore = ["cache_position", "encoder_outputs"] - non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] + non_bool_keys = [ + k + for k in keys + if not isinstance(model_input[k], bool) and k not in keys_to_ignore + ] # we split the tensors and tuples of tensors data_split_list = [ - {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys} + { + k: _split(model_input[k], full_batch_size, split_size)[i] + for k in non_bool_keys + } for i in range(full_batch_size // split_size) ] # bool values are the same and replicated for each split bool_data = {k: model_input[k] for k in bool_keys} # encoder_outputs is a ModelOutput object and should be split by its own if "encoder_outputs" in model_input: - encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size) + encoder_outputs_split = _split_model_inputs( + model_input["encoder_outputs"], split_size, full_batch_size + ) data_split_list = [ - {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list) + {**data_split, "encoder_outputs": encoder_outputs_split[i]} + for i, data_split in enumerate(data_split_list) ] # Convert each dictionary in the list to an object of the inferred class @@ -4949,11 +5814,17 @@ def _concat(data): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): return tuple( - tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0]))) + tuple( + torch.cat([attr[i][j] for attr in data], dim=0) + for j in range(len(data[0][0])) + ) for i in range(len(data[0])) ) else: - return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0]))) + return tuple( + torch.cat([attr[i] for attr in data], dim=0) + for i in range(len(data[0])) + ) elif isinstance(data[0], (int, float)): # If the elements are integers or floats, return a tensor return torch.tensor(data) From ee07f39a4e77306ce5bbe4dc1f0db1d588f3e731 Mon Sep 17 00:00:00 2001 From: paulh Date: Mon, 11 Mar 2024 09:21:12 +0000 Subject: [PATCH 2/2] format reverse --- src/transformers/generation/utils.py | 1559 ++++++-------------------- 1 file changed, 348 insertions(+), 1211 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 74ddd5de529664..a58fbf70fc28c1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -34,12 +34,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, ) -from ..utils import ( - ModelOutput, - is_accelerate_available, - is_torchdynamo_compiling, - logging, -) +from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .candidate_generator import ( @@ -318,21 +313,15 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput -GreedySearchOutput = Union[ - GreedySearchEncoderDecoderOutput, 
GreedySearchDecoderOnlyOutput -] +GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] -ContrastiveSearchOutput = Union[ - ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput -] +ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput] # Typing shortcuts GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput] -GenerateBeamOutput = Union[ - GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput -] +GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput] GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput] @@ -387,9 +376,7 @@ def _prepare_model_inputs( else: input_name = self.main_input_name - model_kwargs = { - k: v for k, v in model_kwargs.items() if v is not None or k != input_name - } + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} # 2. check whether model_input_name is passed as kwarg # if yes and `inputs` is None use kwarg inputs @@ -411,9 +398,7 @@ def _prepare_model_inputs( if input_name == "input_ids" and "inputs_embeds" in model_kwargs: if not self.config.is_encoder_decoder: has_inputs_embeds_forwarding = "inputs_embeds" in set( - inspect.signature( - self.prepare_inputs_for_generation - ).parameters.keys() + inspect.signature(self.prepare_inputs_for_generation).parameters.keys() ) if not has_inputs_embeds_forwarding: raise ValueError( @@ -423,22 +408,16 @@ def _prepare_model_inputs( ) # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of # the attention mask) can rely on the actual model input. - model_kwargs["input_ids"] = ( - self._maybe_initialize_input_ids_for_generation( - inputs, bos_token_id, model_kwargs=model_kwargs - ) + model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( + inputs, bos_token_id, model_kwargs=model_kwargs ) else: if inputs is not None: - raise ValueError( - "You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one." - ) + raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" # 4. if `inputs` is still None, try to create `input_ids` from BOS token - inputs = self._maybe_initialize_input_ids_for_generation( - inputs, bos_token_id, model_kwargs - ) + inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) return inputs, input_name, model_kwargs def _maybe_initialize_input_ids_for_generation( @@ -458,9 +437,7 @@ def _maybe_initialize_input_ids_for_generation( return torch.ones(shape, dtype=torch.long, device=self.device) * -100 if bos_token_id is None: - raise ValueError( - "`bos_token_id` has to be defined when no `input_ids` are provided." - ) + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with # soft-prompting or in multimodal implementations built on top of decoder-only language models. 
@@ -472,10 +449,7 @@ def _maybe_initialize_input_ids_for_generation( if "inputs_embeds" in model_kwargs: return torch.ones((batch_size, 0), dtype=torch.long, device=self.device) - return ( - torch.ones((batch_size, 1), dtype=torch.long, device=self.device) - * bos_token_id - ) + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id def _prepare_attention_mask_for_generation( self, @@ -483,32 +457,20 @@ def _prepare_attention_mask_for_generation( pad_token_id: Optional[int], eos_token_id: Optional[Union[int, List[int]]], ) -> torch.LongTensor: - is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [ - torch.int, - torch.long, - ] + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( - pad_token_id not in eos_token_id - ) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) # Check if input is input_ids and padded -> only then is attention_mask defined - if ( - is_input_ids - and is_pad_token_in_inputs - and is_pad_token_not_equal_to_eos_token_id - ): + if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: return inputs.ne(pad_token_id).long() else: return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) def _prepare_encoder_decoder_kwargs_for_generation( - self, - inputs_tensor: torch.Tensor, - model_kwargs, - model_input_name: Optional[str] = None, + self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None ) -> Dict[str, Any]: # 1. get encoder encoder = self.get_encoder() @@ -528,20 +490,14 @@ def _prepare_encoder_decoder_kwargs_for_generation( if not any(argument.startswith(p) for p in irrelevant_prefix) } encoder_signature = set(inspect.signature(encoder.forward).parameters) - encoder_accepts_wildcard = ( - "kwargs" in encoder_signature or "model_kwargs" in encoder_signature - ) + encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature if not encoder_accepts_wildcard: encoder_kwargs = { - argument: value - for argument, value in encoder_kwargs.items() - if argument in encoder_signature + argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature } # 3. make sure that encoder returns `ModelOutput` - model_input_name = ( - model_input_name if model_input_name is not None else self.main_input_name - ) + model_input_name = model_input_name if model_input_name is not None else self.main_input_name encoder_kwargs["return_dict"] = True encoder_kwargs[model_input_name] = inputs_tensor model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) @@ -568,9 +524,7 @@ def _prepare_decoder_input_ids_for_generation( decoder_input_ids = None # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. 
- decoder_start_token_id = self._get_decoder_start_token_id( - decoder_start_token_id, bos_token_id - ) + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) if device is None: device = self.device if isinstance(decoder_start_token_id, list): @@ -578,24 +532,18 @@ def _prepare_decoder_input_ids_for_generation( raise ValueError( f"`decoder_start_token_id` expcted to have length {batch_size} but got {len(decoder_start_token_id)}" ) - decoder_input_ids_start = torch.tensor( - decoder_start_token_id, dtype=torch.long, device=device - ) + decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device) decoder_input_ids_start = decoder_input_ids_start.view(-1, 1) else: decoder_input_ids_start = ( - torch.ones((batch_size, 1), dtype=torch.long, device=device) - * decoder_start_token_id + torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id ) # no user input -> use decoder_start_token_id as decoder_input_ids if decoder_input_ids is None: decoder_input_ids = decoder_input_ids_start # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token - elif ( - self.config.model_type == "vision-encoder-decoder" - and "donut" in self.name_or_path.lower() - ): + elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): pass elif self.config.model_type in ["whisper"]: pass @@ -608,16 +556,11 @@ def _prepare_decoder_input_ids_for_generation( isinstance(decoder_start_token_id, torch.Tensor) and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item() ): - decoder_input_ids = torch.cat( - [decoder_input_ids_start, decoder_input_ids], dim=-1 - ) + decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] decoder_attention_mask = torch.cat( - ( - torch.ones_like(decoder_attention_mask)[:, :1], - decoder_attention_mask, - ), + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), dim=-1, ) model_kwargs["decoder_attention_mask"] = decoder_attention_mask @@ -625,20 +568,14 @@ def _prepare_decoder_input_ids_for_generation( return decoder_input_ids, model_kwargs def _get_decoder_start_token_id( - self, - decoder_start_token_id: Union[int, List[int]] = None, - bos_token_id: int = None, + self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None ) -> int: decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None else self.generation_config.decoder_start_token_id ) - bos_token_id = ( - bos_token_id - if bos_token_id is not None - else self.generation_config.bos_token_id - ) + bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id if decoder_start_token_id is not None: return decoder_start_token_id @@ -659,12 +596,8 @@ def _expand_inputs_for_generation( def _expand_dict_for_generation(dict_to_expand): for key in dict_to_expand: - if dict_to_expand[key] is not None and isinstance( - dict_to_expand[key], torch.Tensor - ): - dict_to_expand[key] = dict_to_expand[key].repeat_interleave( - expand_size, dim=0 - ) + if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) return dict_to_expand if input_ids is not None: @@ -674,18 +607,12 @@ def 
_expand_dict_for_generation(dict_to_expand): if is_encoder_decoder: if model_kwargs.get("encoder_outputs") is None: - raise ValueError( - "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." - ) - model_kwargs["encoder_outputs"] = _expand_dict_for_generation( - model_kwargs["encoder_outputs"] - ) + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) return input_ids, model_kwargs - def _extract_past_from_model_output( - self, outputs: ModelOutput, standardize_cache_format: bool = False - ): + def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False): past_key_values = None if "past_key_values" in outputs: past_key_values = outputs.past_key_values @@ -697,9 +624,7 @@ def _extract_past_from_model_output( # Bloom fix: standardizes the cache format when requested if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"): batch_size = outputs.logits.shape[0] - past_key_values = self._convert_to_standard_cache( - past_key_values, batch_size=batch_size - ) + past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size) return past_key_values def _update_model_kwargs_for_generation( @@ -720,32 +645,21 @@ def _update_model_kwargs_for_generation( # update token_type_ids with last value if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 - ) + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) if not is_encoder_decoder: # update attention mask if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( - [ - attention_mask, - attention_mask.new_ones((attention_mask.shape[0], 1)), - ], - dim=-1, + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 ) else: # update decoder attention mask if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] model_kwargs["decoder_attention_mask"] = torch.cat( - [ - decoder_attention_mask, - decoder_attention_mask.new_ones( - (decoder_attention_mask.shape[0], 1) - ), - ], + [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], dim=-1, ) @@ -811,52 +725,23 @@ def _get_logits_warper( # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` - if ( - generation_config.temperature is not None - and generation_config.temperature != 1.0 - ): + if generation_config.temperature is not None and generation_config.temperature != 1.0: warpers.append(TemperatureLogitsWarper(generation_config.temperature)) if generation_config.top_k is not None and generation_config.top_k != 0: - warpers.append( - TopKLogitsWarper( - top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep - ) - ) + warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) if generation_config.top_p is not None and generation_config.top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) + if generation_config.typical_p is not None and 
generation_config.typical_p < 1.0: warpers.append( - TopPLogitsWarper( - top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep - ) + TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) ) - if ( - generation_config.typical_p is not None - and generation_config.typical_p < 1.0 - ): + if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: warpers.append( - TypicalLogitsWarper( - mass=generation_config.typical_p, - min_tokens_to_keep=min_tokens_to_keep, - ) + EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep) ) - if ( - generation_config.epsilon_cutoff is not None - and 0.0 < generation_config.epsilon_cutoff < 1.0 - ): + if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: warpers.append( - EpsilonLogitsWarper( - epsilon=generation_config.epsilon_cutoff, - min_tokens_to_keep=min_tokens_to_keep, - ) - ) - if ( - generation_config.eta_cutoff is not None - and 0.0 < generation_config.eta_cutoff < 1.0 - ): - warpers.append( - EtaLogitsWarper( - epsilon=generation_config.eta_cutoff, - min_tokens_to_keep=min_tokens_to_keep, - ) + EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep) ) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: @@ -881,10 +766,7 @@ def _get_logits_processor( # instantiate processors list processors = LogitsProcessorList() - if ( - generation_config.guidance_scale is not None - and generation_config.guidance_scale != 1 - ): + if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1: processors.append( UnbatchedClassifierFreeGuidanceLogitsProcessor( generation_config.guidance_scale, @@ -895,16 +777,9 @@ def _get_logits_processor( ) ) if generation_config.sequence_bias is not None: - processors.append( - SequenceBiasLogitsProcessor( - sequence_bias=generation_config.sequence_bias - ) - ) + processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias)) - if ( - generation_config.diversity_penalty is not None - and generation_config.diversity_penalty > 0.0 - ): + if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0: processors.append( HammingDiversityLogitsProcessor( diversity_penalty=generation_config.diversity_penalty, @@ -918,51 +793,30 @@ def _get_logits_processor( ): processors.append( EncoderRepetitionPenaltyLogitsProcessor( - penalty=generation_config.encoder_repetition_penalty, - encoder_input_ids=encoder_input_ids, + penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids ) ) - if ( - generation_config.repetition_penalty is not None - and generation_config.repetition_penalty != 1.0 - ): - processors.append( - RepetitionPenaltyLogitsProcessor( - penalty=generation_config.repetition_penalty - ) - ) - if ( - generation_config.no_repeat_ngram_size is not None - and generation_config.no_repeat_ngram_size > 0 - ): - processors.append( - NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size) - ) + if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: + processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) + if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: + 
processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) if ( generation_config.encoder_no_repeat_ngram_size is not None and generation_config.encoder_no_repeat_ngram_size > 0 ): processors.append( - EncoderNoRepeatNGramLogitsProcessor( - generation_config.encoder_no_repeat_ngram_size, encoder_input_ids - ) + EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids) ) if generation_config.bad_words_ids is not None: processors.append( - NoBadWordsLogitsProcessor( - generation_config.bad_words_ids, generation_config.eos_token_id - ) + NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) ) if ( generation_config.min_length is not None and generation_config.eos_token_id is not None and generation_config.min_length > 0 ): - processors.append( - MinLengthLogitsProcessor( - generation_config.min_length, generation_config.eos_token_id - ) - ) + processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) if ( generation_config.min_new_tokens is not None and generation_config.eos_token_id is not None @@ -970,27 +824,20 @@ def _get_logits_processor( ): processors.append( MinNewTokensLengthLogitsProcessor( - input_ids_seq_length, - generation_config.min_new_tokens, - generation_config.eos_token_id, + input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id ) ) if prefix_allowed_tokens_fn is not None: processors.append( PrefixConstrainedLogitsProcessor( - prefix_allowed_tokens_fn, - generation_config.num_beams // generation_config.num_beam_groups, + prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups ) ) if generation_config.forced_bos_token_id is not None: - processors.append( - ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id) - ) + processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)) if generation_config.forced_eos_token_id is not None: processors.append( - ForcedEOSTokenLogitsProcessor( - generation_config.max_length, generation_config.forced_eos_token_id - ) + ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) ) if generation_config.remove_invalid_values is True: processors.append(InfNanRemoveLogitsProcessor()) @@ -1003,31 +850,22 @@ def _get_logits_processor( ) ) if generation_config.suppress_tokens is not None: - processors.append( - SuppressTokensLogitsProcessor(generation_config.suppress_tokens) - ) + processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens)) if generation_config.begin_suppress_tokens is not None: begin_index = input_ids_seq_length begin_index = ( begin_index - if ( - input_ids_seq_length > 1 - or generation_config.forced_bos_token_id is None - ) + if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) else begin_index + 1 ) if generation_config.forced_decoder_ids is not None: # generation starts after the last token that is forced begin_index += generation_config.forced_decoder_ids[-1][0] processors.append( - SuppressTokensAtBeginLogitsProcessor( - generation_config.begin_suppress_tokens, begin_index - ) + SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) ) if generation_config.forced_decoder_ids is not None: - processors.append( - ForceTokensLogitsProcessor(generation_config.forced_decoder_ids) - ) + 
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)) processors = self._merge_criteria_processor_list(processors, logits_processor) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: @@ -1035,15 +873,11 @@ def _get_logits_processor( return processors def _get_stopping_criteria( - self, - generation_config: GenerationConfig, - stopping_criteria: Optional[StoppingCriteriaList], + self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList] ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if generation_config.max_length is not None: - max_position_embeddings = getattr( - self.config, "max_position_embeddings", None - ) + max_position_embeddings = getattr(self.config, "max_position_embeddings", None) criteria.append( MaxLengthCriteria( max_length=generation_config.max_length, @@ -1065,11 +899,7 @@ def _merge_criteria_processor_list( for default in default_list: for custom in custom_list: if type(custom) is type(default): - object_type = ( - "stopping criteria" - if isinstance(custom, StoppingCriteria) - else "logits processor" - ) + object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" raise ValueError( f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" f" `.generate()`, but it has already been created with the values {default}. {default} has been" @@ -1165,9 +995,7 @@ def compute_transition_scores( # 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent # to a beam search approach were the first (and only) beam is always selected if beam_indices is None: - beam_indices = ( - torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device) - ) + beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device) beam_indices = beam_indices.expand(-1, len(scores)) # 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being @@ -1233,10 +1061,7 @@ def _validate_model_class(self): def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" # If a `Cache` instance is passed, checks whether the model is compatible with it - if ( - isinstance(model_kwargs.get("past_key_values", None), Cache) - and not self._supports_cache_class - ): + if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: raise ValueError( f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " "check the model documentation for supported cache formats." @@ -1248,9 +1073,7 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): model_kwargs.pop(key, None) unused_model_args = [] - model_args = set( - inspect.signature(self.prepare_inputs_for_generation).parameters - ) + model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. 
If # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;) if "kwargs" in model_args or "model_kwargs" in model_args: @@ -1295,17 +1118,11 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): " generate arguments will also show up in this list)" ) - def _validate_generated_length( - self, generation_config, input_ids_length, has_default_max_length - ): + def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): """Performs validation related to the resulting generated length""" # 1. Max length warnings related to poor parameterization - if ( - has_default_max_length - and generation_config.max_new_tokens is None - and generation_config.max_length == 20 - ): + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: # 20 is the default max_length of the generation config warnings.warn( f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " @@ -1314,9 +1131,7 @@ def _validate_generated_length( UserWarning, ) if input_ids_length >= generation_config.max_length: - input_ids_string = ( - "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - ) + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" raise ValueError( f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" @@ -1329,15 +1144,13 @@ def _validate_generated_length( "increase the maximum length." ) if has_default_max_length: - min_length_error_suffix += f" Note that `max_length` is set to {generation_config.max_length}, its default value." - if ( - generation_config.min_length is not None - and generation_config.min_length > generation_config.max_length - ): + min_length_error_suffix += ( + f" Note that `max_length` is set to {generation_config.max_length}, its default value." + ) + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: warnings.warn( f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than" - f" the maximum possible length ({generation_config.max_length})." - + min_length_error_suffix, + f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, UserWarning, ) if generation_config.min_new_tokens is not None: @@ -1346,8 +1159,7 @@ def _validate_generated_length( warnings.warn( f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when " f"added to the prompt length ({input_ids_length}), is larger than" - f" the maximum possible length ({generation_config.max_length})." - + min_length_error_suffix, + f" the maximum possible length ({generation_config.max_length})." 
+ min_length_error_suffix, UserWarning, ) @@ -1372,8 +1184,7 @@ def _prepare_generation_config( if ( not is_torchdynamo_compiling() and self.generation_config._from_model_config - and self.generation_config._original_object_hash - == hash(self.generation_config) + and self.generation_config._original_object_hash == hash(self.generation_config) and self.config._has_non_default_generation_parameters() ): new_generation_config = GenerationConfig.from_model_config(self.config) @@ -1392,9 +1203,7 @@ def _prepare_generation_config( if is_torchdynamo_compiling(): model_kwargs = kwargs generate_attributes_in_kwargs = [ - key - for key, value in kwargs.items() - if getattr(generation_config, key, None) != value + key for key, value in kwargs.items() if getattr(generation_config, key, None) != value ] if len(generate_attributes_in_kwargs) > 0: raise ValueError( @@ -1414,9 +1223,7 @@ def generate( generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[ - Callable[[int, torch.Tensor], List[int]] - ] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, synced_gpus: Optional[bool] = None, assistant_model: Optional["PreTrainedModel"] = None, streamer: Optional["BaseStreamer"] = None, @@ -1509,9 +1316,7 @@ def generate( """ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() - generation_config, model_kwargs = self._prepare_generation_config( - generation_config, **kwargs - ) + generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined @@ -1520,19 +1325,10 @@ def generate( synced_gpus = True else: synced_gpus = False - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if ( - generation_config.pad_token_id is None - and generation_config.eos_token_id is not None - ): + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: if model_kwargs.get("attention_mask", None) is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " @@ -1541,9 +1337,7 @@ def generate( eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, list): eos_token_id = eos_token_id[0] - logger.warning( - f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." - ) + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") generation_config.pad_token_id = eos_token_id # 3. 
Define model inputs @@ -1566,22 +1360,12 @@ def generate( else: model_kwargs["use_cache"] = generation_config.use_cache - accepts_attention_mask = "attention_mask" in set( - inspect.signature(self.forward).parameters.keys() - ) + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs - if ( - model_kwargs.get("attention_mask", None) is None - and requires_attention_mask - and accepts_attention_mask - ): - model_kwargs["attention_mask"] = ( - self._prepare_attention_mask_for_generation( - inputs_tensor, - generation_config.pad_token_id, - generation_config.eos_token_id, - ) + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ) # decoder-only models should use left-padding for generation @@ -1591,8 +1375,7 @@ def generate( if ( generation_config.pad_token_id is not None and len(inputs_tensor.shape) == 2 - and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) - > 0 + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " @@ -1617,21 +1400,14 @@ def generate( device=inputs_tensor.device, ) else: - input_ids = ( - inputs_tensor - if model_input_name == "input_ids" - else model_kwargs.pop("input_ids") - ) + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") if streamer is not None: streamer.put(input_ids.cpu()) # 6. Prepare `max_length` depending on other stopping criteria. input_ids_length = input_ids.shape[-1] - has_default_max_length = ( - kwargs.get("max_length") is None - and generation_config.max_length is not None - ) + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if generation_config.max_new_tokens is not None: if not has_default_max_length and generation_config.max_length is not None: logger.warning( @@ -1640,9 +1416,7 @@ def generate( "Please refer to the documentation for more information. " "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) - generation_config.max_length = ( - generation_config.max_new_tokens + input_ids_length - ) + generation_config.max_length = generation_config.max_new_tokens + input_ids_length # otherwise the total length [inputs-embeds-len + new-tokens-len] will go beyond indicated `max_length`` elif ( @@ -1651,9 +1425,7 @@ def generate( and not self.config.is_encoder_decoder ): generation_config.max_length -= inputs_tensor.shape[1] - generation_config.min_length = max( - generation_config.min_length - inputs_tensor.shape[1], 0 - ) + generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0) if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: if generation_config.cache_implementation == "static": @@ -1667,15 +1439,9 @@ def generate( "The `generation_config` defines a `cache_implementation` that is not compatible with this model." " Make sure it has a `_setup_cache` function." 
) - self._setup_cache( - cache_cls, - max_batch_size=batch_size, - max_cache_len=generation_config.max_length, - ) + self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) - self._validate_generated_length( - generation_config, input_ids_length, has_default_max_length - ) + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) # 7. determine generation mode generation_mode = generation_config.get_generation_mode(assistant_model) @@ -1720,9 +1486,7 @@ def generate( f"but is {generation_config.num_return_sequences}." ) if batch_size > 1: - raise ValueError( - "assisted generate is only supported for batch_size = 1" - ) + raise ValueError("assisted generate is only supported for batch_size = 1") if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") @@ -1742,11 +1506,7 @@ def generate( candidate_generator=candidate_generator, do_sample=generation_config.do_sample, logits_processor=prepared_logits_processor, - logits_warper=( - self._get_logits_warper(generation_config) - if generation_config.do_sample - else None - ), + logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None, stopping_criteria=prepared_stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, @@ -1952,15 +1712,10 @@ def typeerror(): if isinstance(word_ids[0], list): if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() - if any( - not isinstance(token_ids, list) for token_ids in word_ids - ): + if any(not isinstance(token_ids, list) for token_ids in word_ids): typeerror() if any( - any( - (not isinstance(token_id, int) or token_id < 0) - for token_id in token_ids - ) + any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) for token_ids in word_ids ): typeerror() @@ -1969,10 +1724,7 @@ def typeerror(): else: if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() - if any( - (not isinstance(token_id, int) or token_id < 0) - for token_id in word_ids - ): + if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): typeerror() constraint = PhrasalConstraint(word_ids) @@ -2136,56 +1888,22 @@ def _contrastive_search( ['DeepMind Company is a company that focuses on the development and commercialization of artificial intelligence (AI). 
DeepMind’s mission is to help people understand and solve problems that are difficult to solve in the world today.\n\nIn this post, we talk about the benefits of deep learning in business and how it']
         ```"""
         # init values
-        logits_processor = (
-            logits_processor if logits_processor is not None else LogitsProcessorList()
-        )
-        logits_warper = (
-            logits_warper if logits_warper is not None else LogitsProcessorList()
-        )
-        stopping_criteria = (
-            stopping_criteria
-            if stopping_criteria is not None
-            else StoppingCriteriaList()
-        )
-        pad_token_id = (
-            pad_token_id
-            if pad_token_id is not None
-            else self.generation_config.pad_token_id
-        )
-        eos_token_id = (
-            eos_token_id
-            if eos_token_id is not None
-            else self.generation_config.eos_token_id
-        )
-        sequential = (
-            sequential if sequential is not None else self.generation_config.low_memory
-        )
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = (
-            torch.tensor(eos_token_id).to(input_ids.device)
-            if eos_token_id is not None
-            else None
-        )
-        output_scores = (
-            output_scores
-            if output_scores is not None
-            else self.generation_config.output_scores
-        )
-        output_logits = (
-            output_logits
-            if output_logits is not None
-            else self.generation_config.output_logits
-        )
+        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+        output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
         output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.generation_config.output_attentions
+            output_attentions if output_attentions is not None else self.generation_config.output_attentions
         )
         output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.generation_config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
         )
         return_dict_in_generate = (
             return_dict_in_generate
@@ -2196,33 +1914,19 @@ def _contrastive_search(
         # init attention / hidden states / scores tuples
         raw_logits = () if (return_dict_in_generate and output_logits) else None
         scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = (
-            () if (return_dict_in_generate and output_attentions) else None
-        )
-        cross_attentions = (
-            () if (return_dict_in_generate and output_attentions) else None
-        )
-        decoder_hidden_states = (
-            () if (return_dict_in_generate and output_hidden_states) else None
-        )
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

         # if model is an encoder-decoder, retrieve encoder attention weights
and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished - unfinished_sequences = torch.ones( - input_ids.shape[0], dtype=torch.long, device=input_ids.device - ) + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False # used by synced_gpus only batch_size = input_ids.shape[0] @@ -2231,9 +1935,7 @@ def _contrastive_search( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -2245,17 +1947,12 @@ def _contrastive_search( if model_kwargs.get("past_key_values") is None: # prepare inputs model_kwargs["use_cache"] = True - model_inputs = self.prepare_inputs_for_generation( - input_ids, **model_kwargs - ) + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save # the `encoder_outputs` outputs = self( - **model_inputs, - return_dict=True, - output_hidden_states=True, - output_attentions=output_attentions, + **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions ) # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with @@ -2278,9 +1975,7 @@ def _contrastive_search( if not sequential: # Expands model inputs top_k times, for batched forward passes (akin to beam search). 
_, model_kwargs = self._expand_inputs_for_generation( - expand_size=top_k, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, + expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs ) past_key_values = model_kwargs.get("past_key_values") @@ -2301,12 +1996,8 @@ def _contrastive_search( # contrastive_search main logic start: # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by # degeneration penalty - processed_logit_for_next_step = logits_processor( - input_ids, logit_for_next_step - ) - processed_logit_for_next_step = logits_warper( - input_ids, processed_logit_for_next_step - ) + processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step) + processed_logit_for_next_step = logits_warper(input_ids, processed_logit_for_next_step) next_probs = nn.functional.softmax(processed_logit_for_next_step, dim=-1) top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) @@ -2319,9 +2010,7 @@ def _contrastive_search( scores += (processed_logit_for_next_step,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -2350,9 +2039,7 @@ def _contrastive_search( all_outputs = [] for i in range(top_k): # compute the candidate tokens by the language model and collect their hidden_states - next_model_inputs = self.prepare_inputs_for_generation( - top_k_ids[:, i].view(-1, 1), **model_kwargs - ) + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs) outputs = self( **next_model_inputs, @@ -2366,9 +2053,7 @@ def _contrastive_search( else: # compute the candidate tokens by the language model and collect their hidden_states # assembles top_k_ids into batch of size k - next_model_inputs = self.prepare_inputs_for_generation( - top_k_ids.view(-1, 1), **model_kwargs - ) + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) outputs = self( **next_model_inputs, @@ -2391,9 +2076,7 @@ def _contrastive_search( # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't # introduce (noticeable) slowdowns on single-device runs. 
- selected_idx = _ranking_fast( - context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k - ) + selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) selected_idx = selected_idx.to("cpu") # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing @@ -2402,15 +2085,11 @@ def _contrastive_search( next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx] next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) next_hidden = next_hidden[range(batch_size), selected_idx, :] - last_hidden_states = torch.cat( - [last_hidden_states, next_hidden.unsqueeze(1)], dim=1 - ) + last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) next_decoder_hidden_states = () for layer in full_hidden_states: - layer = torch.stack(torch.split(layer, top_k))[ - range(batch_size), selected_idx, : - ] + layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :] next_decoder_hidden_states += (layer,) # generate past_key_values cache of only the selected token @@ -2428,27 +2107,19 @@ def _contrastive_search( next_past_key_values = selected_outputs["past_key_values"] else: - next_past_key_values = self._extract_past_from_model_output( - outputs, standardize_cache_format=True - ) + next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) new_key_values = () for layer in next_past_key_values: items = () # item is either the key or the value matrix for item in layer: - item = torch.stack( - torch.split(item, top_k, dim=0) - ) # [B, K, num_head, seq_len, esz] - item = item[ - range(batch_size), selected_idx, ... - ] # [B, num_head, seq_len, esz] + item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz] + item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz] items += (item,) new_key_values += (items,) next_past_key_values = new_key_values - logit_for_next_step = torch.stack(torch.split(logits, top_k))[ - range(batch_size), selected_idx, : - ] + logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration if self.config.is_encoder_decoder: @@ -2456,14 +2127,10 @@ def _contrastive_search( next_step_decoder_attentions = () if output_attentions: for layer in outputs.cross_attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[ - range(batch_size), selected_idx, ... - ] + layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] next_step_cross_attentions += (layer,) for layer in outputs.decoder_attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[ - range(batch_size), selected_idx, ... - ] + layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] next_step_decoder_attentions += (layer,) outputs = Seq2SeqLMOutput( past_key_values=next_past_key_values, @@ -2475,9 +2142,7 @@ def _contrastive_search( next_step_attentions = () if output_attentions: for layer in outputs.attentions: - layer = torch.stack(torch.split(layer, top_k, dim=0))[ - range(batch_size), selected_idx, ... - ] + layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] 
next_step_attentions += (layer,) outputs = CausalLMOutputWithPast( past_key_values=next_past_key_values, @@ -2492,12 +2157,8 @@ def _contrastive_search( # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." - ) - next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( - 1 - unfinished_sequences - ) + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -2506,24 +2167,17 @@ def _contrastive_search( if not streamer.is_running(): break model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - model_inputs=model_inputs, + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1) - .ne(eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) # stop when each sentence is finished - unfinished_sequences = unfinished_sequences & ~stopping_criteria( - input_ids, scores - ) + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) if unfinished_sequences.max() == 0: this_peer_finished = True @@ -2690,54 +2344,26 @@ def _greedy_search( ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - pad_token_id = ( - pad_token_id - if pad_token_id is not None - else self.generation_config.pad_token_id - ) - eos_token_id = ( - eos_token_id - if eos_token_id is not None - else self.generation_config.eos_token_id - ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = ( - torch.tensor(eos_token_id).to(input_ids.device) - if eos_token_id is not None - else None - ) - output_scores = ( - output_scores - if output_scores is not None - else self.generation_config.output_scores - ) + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is 
not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( - output_attentions - if output_attentions is not None - else self.generation_config.output_attentions + output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -2748,42 +2374,26 @@ def _greedy_search( # init attention / hidden states / scores tuples raw_logits = () if (return_dict_in_generate and output_logits) else None scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished - unfinished_sequences = torch.ones( - input_ids.shape[0], dtype=torch.long, device=input_ids.device - ) + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? 
the reduced sum will be 0.0 then @@ -2817,9 +2427,7 @@ def _greedy_search( raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -2837,12 +2445,8 @@ def _greedy_search( # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." - ) - next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( - 1 - unfinished_sequences - ) + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -2850,7 +2454,6 @@ def _greedy_search( streamer.put(next_tokens.cpu()) if not streamer.is_running(): break - model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, @@ -2861,14 +2464,10 @@ def _greedy_search( # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1) - .ne(eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) - unfinished_sequences = unfinished_sequences & ~stopping_criteria( - input_ids, scores - ) + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) # stop when each sentence is finished if unfinished_sequences.max() == 0: @@ -3044,62 +2643,28 @@ def _sample( ['Today is a beautiful day, and we must do everything possible to make it a day of celebration.'] ```""" # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - logits_warper = ( - logits_warper if logits_warper is not None else LogitsProcessorList() - ) - pad_token_id = ( - pad_token_id - if pad_token_id is not None - else self.generation_config.pad_token_id - ) - eos_token_id = ( - eos_token_id - if eos_token_id is not None - else self.generation_config.eos_token_id - ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): 
eos_token_id = [eos_token_id] - eos_token_id_tensor = ( - torch.tensor(eos_token_id).to(input_ids.device) - if eos_token_id is not None - else None - ) - output_scores = ( - output_scores - if output_scores is not None - else self.generation_config.output_scores - ) - output_logits = ( - output_logits - if output_logits is not None - else self.generation_config.output_logits - ) + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( - output_attentions - if output_attentions is not None - else self.generation_config.output_attentions + output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -3110,33 +2675,19 @@ def _sample( # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished - unfinished_sequences = torch.ones( - input_ids.shape[0], dtype=torch.long, device=input_ids.device - ) + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False # used by synced_gpus only # auto-regressive generation @@ -3144,9 +2695,7 @@ def _sample( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? 
the reduced sum will be 0.0 then @@ -3181,9 +2730,7 @@ def _sample( raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -3202,12 +2749,8 @@ def _sample( # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." - ) - next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( - 1 - unfinished_sequences - ) + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -3216,23 +2759,16 @@ def _sample( if not streamer.is_running(): break model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - model_inputs=model_inputs, + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1) - .ne(eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) - unfinished_sequences = unfinished_sequences & ~stopping_criteria( - input_ids, scores - ) + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) # stop when each sentence is finished if unfinished_sequences.max() == 0: @@ -3434,62 +2970,29 @@ def _beam_search( ['Wie alt bist du?'] ```""" # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) - sequential = ( - sequential if sequential is not None else self.generation_config.low_memory - ) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + sequential = sequential if sequential is not None else self.generation_config.low_memory if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: - warnings.warn( - "You don't have defined any stopping_criteria, this will likely loop forever", - UserWarning, - ) - pad_token_id = ( - pad_token_id - if pad_token_id is not None - else self.generation_config.pad_token_id - ) - eos_token_id = ( - eos_token_id - if eos_token_id is not None - else self.generation_config.eos_token_id - ) + warnings.warn("You don't have defined any stopping_criteria, 
this will likely loop forever", UserWarning) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - output_scores = ( - output_scores - if output_scores is not None - else self.generation_config.output_scores - ) - output_logits = ( - output_logits - if output_logits is not None - else self.generation_config.output_logits - ) + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( - output_attentions - if output_attentions is not None - else self.generation_config.output_attentions + output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -3511,38 +3014,22 @@ def _beam_search( scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None beam_indices = ( - tuple(() for _ in range(batch_beam_size)) - if (return_dict_in_generate and output_scores) - else None - ) - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None + tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens # of the first beam are considered to avoid sampling the exact same tokens across all beams. - beam_scores = torch.zeros( - (batch_size, num_beams), dtype=torch.float, device=input_ids.device - ) + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) @@ -3553,9 +3040,7 @@ def _beam_search( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 
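In the sampling loop a few hunks above, two short expressions do all of the per-sequence bookkeeping: rows that are already finished have their freshly sampled token overwritten with `pad_token_id`, and a row is marked finished as soon as it emits any of the EOS ids. A minimal, self-contained sketch with invented token ids (the concrete values below are assumptions for illustration only):

import torch

pad_token_id = 0
eos_token_id_tensor = torch.tensor([2])         # toy EOS id
unfinished_sequences = torch.tensor([1, 0])     # row 0 still generating, row 1 already finished
next_tokens = torch.tensor([2, 7])              # row 0 just sampled EOS, row 1's sample is irrelevant

# finished rows get the pad token instead of the sampled token
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# -> tensor([2, 0])

# a row stays unfinished only if its new token differs from every EOS id
unfinished_sequences = unfinished_sequences.mul(
    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
    .ne(eos_token_id_tensor.unsqueeze(1))
    .prod(dim=0)
)
# -> tensor([0, 0]): every row is done, so the loop exits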
# The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -3617,9 +3102,9 @@ def _beam_search( ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[ - :, None - ].expand_as(next_token_scores_processed) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -3629,9 +3114,7 @@ def _beam_search( raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -3644,18 +3127,12 @@ def _beam_search( # reshape for beam search vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view( - batch_size, num_beams * vocab_size - ) + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. n_eos_tokens = len(eos_token_id) if eos_token_id else 0 next_token_scores, next_tokens = torch.topk( - next_token_scores, - max(2, 1 + n_eos_tokens) * num_beams, - dim=1, - largest=True, - sorted=True, + next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") @@ -3677,15 +3154,10 @@ def _beam_search( beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] - input_ids = torch.cat( - [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1 - ) + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - model_inputs=model_inputs, + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -3693,12 +3165,7 @@ def _beam_search( ) if return_dict_in_generate and output_scores: - beam_indices = tuple( - ( - beam_indices[beam_idx[i]] + (beam_idx[i],) - for i in range(len(beam_indices)) - ) - ) + beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) # increase cur_len cur_len = cur_len + 1 @@ -3900,54 +3367,26 @@ def _beam_sample( ['Wie alt bist du?'] ```""" # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + 
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - pad_token_id = ( - pad_token_id - if pad_token_id is not None - else self.generation_config.pad_token_id - ) - eos_token_id = ( - eos_token_id - if eos_token_id is not None - else self.generation_config.eos_token_id - ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - output_scores = ( - output_scores - if output_scores is not None - else self.generation_config.output_scores - ) - output_logits = ( - output_logits - if output_logits is not None - else self.generation_config.output_logits - ) + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( - output_attentions - if output_attentions is not None - else self.generation_config.output_attentions + output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -3964,36 +3403,20 @@ def _beam_sample( scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None beam_indices = ( - tuple(() for _ in range(batch_beam_size)) - if (return_dict_in_generate and output_scores) - else None - ) - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None + tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) - beam_scores = torch.zeros( - (batch_size, num_beams), dtype=torch.float, 
device=input_ids.device - ) + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only @@ -4003,9 +3426,7 @@ def _beam_sample( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -4032,12 +3453,10 @@ def _beam_sample( ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores_processed = logits_warper( - input_ids, next_token_scores_processed + next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed ) - next_token_scores = next_token_scores_processed + beam_scores[ - :, None - ].expand_as(next_token_scores_processed) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -4047,9 +3466,7 @@ def _beam_sample( raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -4063,18 +3480,14 @@ def _beam_sample( # reshape for beam search vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view( - batch_size, num_beams * vocab_size - ) + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) next_token_scores = torch.gather(next_token_scores, -1, next_tokens) - next_token_scores, _indices = torch.sort( - next_token_scores, descending=True, dim=1 - ) + next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) next_tokens = torch.gather(next_tokens, -1, _indices) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") @@ -4095,15 +3508,10 @@ def _beam_sample( beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] - input_ids = torch.cat( - [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1 - ) + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - model_inputs=model_inputs, + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -4111,12 +3519,7 @@ def _beam_sample( ) if return_dict_in_generate and output_scores: - beam_indices = tuple( - ( - beam_indices[beam_idx[i]] + (beam_idx[i],) - for i 
in range(len(beam_indices)) - ) - ) + beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) # increase cur_len cur_len = cur_len + 1 @@ -4311,54 +3714,26 @@ def _group_beam_search( ['Wie alt bist du?'] ```""" # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - pad_token_id = ( - pad_token_id - if pad_token_id is not None - else self.generation_config.pad_token_id - ) - eos_token_id = ( - eos_token_id - if eos_token_id is not None - else self.generation_config.eos_token_id - ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - output_scores = ( - output_scores - if output_scores is not None - else self.generation_config.output_scores - ) - output_logits = ( - output_logits - if output_logits is not None - else self.generation_config.output_logits - ) + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( - output_attentions - if output_attentions is not None - else self.generation_config.output_attentions + output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -4375,10 +3750,7 @@ def _group_beam_search( batch_beam_size, cur_len = input_ids.shape if return_dict_in_generate and output_scores: - beam_indices = [ - tuple(() for _ in range(num_sub_beams * batch_size)) - for _ in range(num_beam_groups) - ] + beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] else: beam_indices = None @@ -4390,34 +3762,20 @@ def _group_beam_search( # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and 
output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in # the same group don't produce same tokens everytime. - beam_scores = torch.full( - (batch_size, num_beams), -1e9, dtype=torch.float, device=device - ) + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) beam_scores[:, ::num_sub_beams] = 0 beam_scores = beam_scores.view((batch_size * num_beams,)) @@ -4428,9 +3786,7 @@ def _group_beam_search( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -4438,14 +3794,10 @@ def _group_beam_search( break # predicted tokens in cur_len step - current_tokens = torch.zeros( - batch_size * num_beams, dtype=input_ids.dtype, device=device - ) + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) # indices which will form the beams in the next time step - reordering_indices = torch.zeros( - batch_size * num_beams, dtype=torch.long, device=device - ) + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) # do one decoder step on all beams of all sentences in batch model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -4475,10 +3827,7 @@ def _group_beam_search( for batch_idx in range(batch_size): batch_group_indices.extend( - [ - batch_idx * num_beams + idx - for idx in range(group_start_idx, group_end_idx) - ] + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] ) group_input_ids = input_ids[batch_group_indices] @@ -4491,43 +3840,28 @@ def _group_beam_search( vocab_size = next_token_scores.shape[-1] next_token_scores_processed = logits_processor( - group_input_ids, - next_token_scores, - current_tokens=current_tokens, - beam_group_idx=beam_group_idx, - ) - next_token_scores = next_token_scores_processed + beam_scores[ - batch_group_indices - ].unsqueeze(-1) - next_token_scores = next_token_scores.expand_as( - next_token_scores_processed + group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) if output_scores: processed_score[batch_group_indices] = next_token_scores_processed 
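Group (diverse) beam search, shown in the hunks above, gives only the first beam of every group a starting score of 0 and pushes the rest to -1e9, so the beams inside a group cannot all pick the same first token; it then collects the flat row indices that belong to the group currently being expanded. A small sketch with illustrative sizes (a batch of 2 with 4 beams in 2 groups; all numbers are assumptions for demonstration):

import torch

batch_size, num_beams, num_beam_groups = 2, 4, 2
num_sub_beams = num_beams // num_beam_groups

beam_scores = torch.full((batch_size, num_beams), -1e9)
beam_scores[:, ::num_sub_beams] = 0             # first beam of every group starts at 0
beam_scores = beam_scores.view(batch_size * num_beams)
# -> [0, -1e9, 0, -1e9, 0, -1e9, 0, -1e9]

# flat row indices of the beams that belong to group 1
beam_group_idx = 1
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = group_start_idx + num_sub_beams
batch_group_indices = []
for batch_idx in range(batch_size):
    batch_group_indices.extend(
        batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)
    )
# -> [2, 3, 6, 7]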
# reshape for beam search - next_token_scores = next_token_scores.view( - batch_size, group_size * vocab_size - ) + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. n_eos_tokens = len(eos_token_id) if eos_token_id else 0 next_token_scores, next_tokens = torch.topk( - next_token_scores, - max(2, 1 + n_eos_tokens) * group_size, - dim=1, - largest=True, - sorted=True, + next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size # stateless - process_beam_indices = ( - sum(beam_indices, ()) if beam_indices is not None else None - ) + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None beam_outputs = beam_scorer.process( group_input_ids, next_token_scores, @@ -4545,15 +3879,11 @@ def _group_beam_search( if return_dict_in_generate and output_scores: beam_indices[beam_group_idx] = tuple( - beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) - for i in range(len(beam_indices[0])) + beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) ) input_ids[batch_group_indices] = group_input_ids[beam_idx] - group_input_ids = torch.cat( - [group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], - dim=-1, - ) + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) current_tokens[batch_group_indices] = group_input_ids[:, -1] # (beam_idx // group_size) -> batch_idx @@ -4572,9 +3902,7 @@ def _group_beam_search( raw_logits += (raw_logit_score,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -4589,10 +3917,7 @@ def _group_beam_search( input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - model_inputs=model_inputs, + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -4798,59 +4123,28 @@ def _constrained_beam_search( ['Wie alt sind Sie?'] ```""" # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: - warnings.warn( - "You don't have defined any 
stopping_criteria, this will likely loop forever", - UserWarning, - ) - pad_token_id = ( - pad_token_id - if pad_token_id is not None - else self.generation_config.pad_token_id - ) - eos_token_id = ( - eos_token_id - if eos_token_id is not None - else self.generation_config.eos_token_id - ) + warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - output_scores = ( - output_scores - if output_scores is not None - else self.generation_config.output_scores - ) - output_logits = ( - output_logits - if output_logits is not None - else self.generation_config.output_logits - ) + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( - output_attentions - if output_attentions is not None - else self.generation_config.output_attentions + output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -4872,38 +4166,22 @@ def _constrained_beam_search( scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None beam_indices = ( - tuple(() for _ in range(batch_beam_size)) - if (return_dict_in_generate and output_scores) - else None - ) - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None + tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens # of the first beam are considered to avoid sampling the exact same tokens across all beams. 
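The comment that closes the hunk above, together with the `beam_scores` lines that follow it, implements the usual trick of letting only beam 0 contribute on the very first step: every other beam starts at -1e9, so the first top-k cannot select the same token `num_beams` times. A tiny sketch under assumed sizes:

import torch

batch_size, num_beams = 2, 3
beam_scores = torch.zeros((batch_size, num_beams))
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(batch_size * num_beams)
# -> [0, -1e9, -1e9, 0, -1e9, -1e9]
# added to the per-beam token scores, this makes candidates from beams 1..num_beams-1
# hopeless on the first step, until the beams have genuinely diverged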
- beam_scores = torch.zeros( - (batch_size, num_beams), dtype=torch.float, device=input_ids.device - ) + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) @@ -4914,9 +4192,7 @@ def _constrained_beam_search( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -4943,9 +4219,9 @@ def _constrained_beam_search( next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[ - :, None - ].expand_as(next_token_scores_processed) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) scores_for_all_vocab = next_token_scores.clone() @@ -4957,9 +4233,7 @@ def _constrained_beam_search( raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -4973,18 +4247,12 @@ def _constrained_beam_search( # reshape for beam search vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view( - batch_size, num_beams * vocab_size - ) + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
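Each beam-search variant in this patch flattens the `(num_beams, vocab_size)` score grid per batch item, takes a single top-k over the flattened axis (the call right after the comment above), and then recovers the beam index and the vocabulary id of every candidate with a floor division and a modulo. A toy sketch with made-up scores and sizes:

import torch

batch_size, num_beams, vocab_size = 1, 2, 5
next_token_scores = torch.tensor([[[-1.0, -2.0, -0.5, -3.0, -4.0],
                                    [-0.1, -5.0, -6.0, -0.2, -7.0]]])  # (batch, beams, vocab)

flat = next_token_scores.view(batch_size, num_beams * vocab_size)
top_scores, flat_idx = torch.topk(flat, 2 * num_beams, dim=1, largest=True, sorted=True)

beam_idx = torch.div(flat_idx, vocab_size, rounding_mode="floor")  # which beam each candidate came from
token_id = flat_idx % vocab_size                                   # which vocabulary token it is
# flat_idx -> [[5, 8, 2, 0]], beam_idx -> [[1, 1, 0, 0]], token_id -> [[0, 3, 2, 0]]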
n_eos_tokens = len(eos_token_id) if eos_token_id else 0 next_token_scores, next_tokens = torch.topk( - next_token_scores, - max(2, 1 + n_eos_tokens) * num_beams, - dim=1, - largest=True, - sorted=True, + next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True ) next_indices = (next_tokens / vocab_size).long() @@ -5006,14 +4274,9 @@ def _constrained_beam_search( beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] - input_ids = torch.cat( - [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1 - ) + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - model_inputs=model_inputs, + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -5021,19 +4284,12 @@ def _constrained_beam_search( ) if return_dict_in_generate and output_scores: - beam_indices = tuple( - ( - beam_indices[beam_idx[i]] + (beam_idx[i],) - for i in range(len(beam_indices)) - ) - ) + beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) # increase cur_len cur_len = cur_len + 1 - if constrained_beam_scorer.is_done or all( - stopping_criteria(input_ids, scores) - ): + if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): if not synced_gpus: break else: @@ -5217,57 +4473,23 @@ def _assisted_decoding( ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - logits_warper = ( - logits_warper if logits_warper is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) - pad_token_id = ( - pad_token_id - if pad_token_id is not None - else self.generation_config.pad_token_id - ) - eos_token_id = ( - eos_token_id - if eos_token_id is not None - else self.generation_config.eos_token_id - ) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if eos_token_id is not None and pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." 
- ) + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = ( - torch.tensor(eos_token_id).to(input_ids.device) - if eos_token_id is not None - else None - ) - output_scores = ( - output_scores - if output_scores is not None - else self.generation_config.output_scores - ) - output_logits = ( - output_logits - if output_logits is not None - else self.generation_config.output_logits - ) + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( - output_attentions - if output_attentions is not None - else self.generation_config.output_attentions + output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -5278,27 +4500,15 @@ def _assisted_decoding( # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished @@ -5312,9 +4522,7 @@ def _assisted_decoding( if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -5324,9 +4532,7 @@ def _assisted_decoding( cur_len = input_ids.shape[-1] # 1. 
Fetch candidate sequences from a `CandidateGenerator` - candidate_input_ids, candidate_logits = candidate_generator.get_candidates( - input_ids - ) + candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) candidate_input_ids = candidate_input_ids.to(self.device) if candidate_logits is not None: candidate_logits = candidate_logits.to(self.device) @@ -5347,17 +4553,11 @@ def _assisted_decoding( # 2.1. Prepare the model inputs candidate_kwargs = copy.copy(model_kwargs) candidate_kwargs = _prepare_attention_mask( - candidate_kwargs, - candidate_input_ids.shape[1], - self.config.is_encoder_decoder, - ) - candidate_kwargs = _prepare_token_type_ids( - candidate_kwargs, candidate_input_ids.shape[1] + candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder ) + candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) - model_inputs = self.prepare_inputs_for_generation( - candidate_input_ids, **candidate_kwargs - ) + model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) # 2.2. Run a forward pass on the candidate sequence outputs = self( @@ -5367,20 +4567,14 @@ def _assisted_decoding( ) # 2.3. Process the new logits - new_logits = outputs.logits[ - :, -candidate_length - 1 : - ] # excludes the input prompt if present + new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present next_token_logits = new_logits.clone() if len(logits_processor) > 0: for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_processor( - candidate_input_ids[:, : cur_len + i], new_logits[:, i, :] - ) + new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) if len(logits_warper) > 0: for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_warper( - candidate_input_ids[:, : cur_len + i], new_logits[:, i, :] - ) + new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) @@ -5402,17 +4596,12 @@ def _assisted_decoding( else: if do_sample: probs = new_logits.softmax(dim=-1) - selected_tokens = torch.multinomial( - probs[0, :, :], num_samples=1 - ).squeeze(1)[None, :] + selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] else: selected_tokens = new_logits.argmax(dim=-1) candidate_new_tokens = candidate_input_ids[:, cur_len:] - n_matches = ( - (~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) - < 1 - ).sum() + n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() # Ensure we don't generate beyond max_len or an EOS token if last_assistant_token_is_eos and n_matches == candidate_length: @@ -5435,14 +4624,10 @@ def _assisted_decoding( # 4.2. Discard past key values relative to unused assistant tokens new_cache_size = new_cur_len - 1 - outputs.past_key_values = _crop_past_key_values( - self, outputs.past_key_values, new_cache_size - ) + outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) # 5. 
Update the candidate generation strategy if needed - candidate_generator.update_candidate_strategy( - input_ids, new_logits, n_matches - ) + candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need @@ -5463,10 +4648,7 @@ def _assisted_decoding( if output_attentions: if self.config.is_encoder_decoder: cross_attentions = _split_model_outputs( - cross_attentions, - outputs.cross_attentions, - cur_len, - added_len, + cross_attentions, outputs.cross_attentions, cur_len, added_len ) decoder_attentions = _split_model_outputs( decoder_attentions, @@ -5486,24 +4668,15 @@ def _assisted_decoding( if output_hidden_states: if self.config.is_encoder_decoder: decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, - outputs.decoder_hidden_states, - cur_len, - added_len, + decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len ) else: decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, - outputs.hidden_states, - cur_len, - added_len, + decoder_hidden_states, outputs.hidden_states, cur_len, added_len ) model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - model_inputs=model_inputs, + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) # if eos_token was found in one sentence, set sentence to finished @@ -5515,9 +4688,7 @@ def _assisted_decoding( .prod(dim=0) ) - unfinished_sequences = unfinished_sequences & ~stopping_criteria( - input_ids, scores - ) + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) # stop when each sentence is finished if unfinished_sequences.max() == 0: @@ -5531,8 +4702,7 @@ def _assisted_decoding( if ( hasattr(candidate_generator, "assistant_model") - and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule - == "heuristic" + and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic" ): candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( candidate_generator.num_assistant_tokens @@ -5615,18 +4785,14 @@ def _speculative_sampling( # The selected tokens include the matches (if any) plus the next sampled tokens if n_matches > 0: - valid_tokens = torch.cat( - (new_candidate_input_ids[:, :n_matches], t), dim=-1 - ) + valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1) else: valid_tokens = t return valid_tokens, n_matches -def _split_model_outputs( - outputs, new_outputs, cur_len, added_len, is_decoder_attention=False -): +def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): """ Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple where each member corresponds to a single generated token. 
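In the assisted-decoding hunk above, `n_matches` counts how many of the assistant's candidate tokens the main model agrees with, stopping at the first disagreement: the cumulative sum of mismatches stays below 1 only over the leading run of matches. A tiny numeric sketch with invented token ids:

import torch

candidate_new_tokens = torch.tensor([[11, 12, 13, 14]])   # proposed by the assistant model
selected_tokens = torch.tensor([[11, 12, 99, 14, 20]])    # the main model's own picks (plus one bonus token)

mismatches = ~(candidate_new_tokens == selected_tokens[:, :-1])
n_matches = (mismatches.cumsum(dim=-1) < 1).sum()
# -> tensor(2): only the first two candidate tokens are accepted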
@@ -5666,17 +4832,11 @@ def _ranking_fast( """ norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) - cosine_matrix = torch.matmul( - norm_context_hidden, norm_next_hidden.transpose(1, 2) - ).squeeze( - -1 - ) # [B*K, S] + cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K] next_top_k_probs = next_top_k_probs.view(-1) # [B*K] contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty - contrastive_score = torch.stack( - torch.split(contrastive_score, beam_width) - ) # [B, K] + contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] _, selected_idx = contrastive_score.max(dim=-1) # [B] return selected_idx @@ -5699,10 +4859,7 @@ def _split(data, full_batch_size: int, split_size: int = None): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0], tuple): return [ - tuple( - tuple(tensor[i : i + split_size] for tensor in inner_tuple) - for inner_tuple in data - ) + tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data) for i in range(0, full_batch_size, split_size) ] @@ -5739,43 +4896,29 @@ def _split_model_inputs( # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them keys = ( - model_input.__dataclass_fields__.keys() - if hasattr(model_input, "__dataclass_fields__") - else model_input.keys() + model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys() ) # We only keep keys that are in the model_input keys = [k for k in keys if k in model_input] # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a # ModelOutput object. 
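`_ranking_fast`, reformatted just above, scores each top-k contrastive-search candidate as a trade-off between its model probability and a degeneration penalty, i.e. its maximum cosine similarity with every hidden state already in the context. A shape-level sketch with random tensors (the sizes and the alpha value are illustrative assumptions):

import torch

B, K, S, H, alpha = 1, 4, 6, 8, 0.6          # batch, top-k, context length, hidden size, penalty weight
torch.manual_seed(0)
context_hidden = torch.randn(B * K, S, H)    # context hidden states, one copy per candidate
next_hidden = torch.randn(B * K, 1, H)       # hidden state produced by each candidate token
next_top_k_probs = torch.rand(B, K)          # model probability of each candidate

norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)                                       # [B*K]

contrastive_score = (1.0 - alpha) * next_top_k_probs.view(-1) - alpha * degeneration_penalty
contrastive_score = torch.stack(torch.split(contrastive_score, K))                              # [B, K]
_, selected_idx = contrastive_score.max(dim=-1)                                                 # [B]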
# bool should not be split but replicated for each split - bool_keys = [ - k for k in keys if isinstance(model_input[k], bool) or k == "cache_position" - ] + bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] keys_to_ignore = ["cache_position", "encoder_outputs"] - non_bool_keys = [ - k - for k in keys - if not isinstance(model_input[k], bool) and k not in keys_to_ignore - ] + non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] # we split the tensors and tuples of tensors data_split_list = [ - { - k: _split(model_input[k], full_batch_size, split_size)[i] - for k in non_bool_keys - } + {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys} for i in range(full_batch_size // split_size) ] # bool values are the same and replicated for each split bool_data = {k: model_input[k] for k in bool_keys} # encoder_outputs is a ModelOutput object and should be split by its own if "encoder_outputs" in model_input: - encoder_outputs_split = _split_model_inputs( - model_input["encoder_outputs"], split_size, full_batch_size - ) + encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size) data_split_list = [ - {**data_split, "encoder_outputs": encoder_outputs_split[i]} - for i, data_split in enumerate(data_split_list) + {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list) ] # Convert each dictionary in the list to an object of the inferred class @@ -5814,17 +4957,11 @@ def _concat(data): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): return tuple( - tuple( - torch.cat([attr[i][j] for attr in data], dim=0) - for j in range(len(data[0][0])) - ) + tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0]))) for i in range(len(data[0])) ) else: - return tuple( - torch.cat([attr[i] for attr in data], dim=0) - for i in range(len(data[0])) - ) + return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0]))) elif isinstance(data[0], (int, float)): # If the elements are integers or floats, return a tensor return torch.tensor(data)
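The `_split` / `_concat` pair above lets low-memory (`sequential`) beam search run the forward pass on `split_size`-sized micro-batches and then stitch the per-chunk outputs back together. A reduced stand-in sketch of the slicing idea (the `split_batched` helper below is an illustration, not the function from this patch):

import torch


def split_batched(value, full_batch_size: int, split_size: int):
    # plain tensor: slice along the batch dimension
    if isinstance(value, torch.Tensor):
        return [value[i : i + split_size] for i in range(0, full_batch_size, split_size)]
    # nested tuples of tensors (e.g. past_key_values): slice every leaf tensor
    if isinstance(value, tuple) and isinstance(value[0], tuple):
        return [
            tuple(tuple(t[i : i + split_size] for t in inner) for inner in value)
            for i in range(0, full_batch_size, split_size)
        ]
    raise TypeError(f"Unexpected type {type(value)}")


input_ids = torch.arange(12).reshape(6, 2)             # a batch of 6 sequences
chunks = split_batched(input_ids, full_batch_size=6, split_size=2)
# -> three chunks of shape (2, 2)

# concatenating the chunks restores the original batch, mirroring `_concat`
assert torch.equal(torch.cat(chunks, dim=0), input_ids)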