From fe5759b65b1881850c82f13482a16911f424a415 Mon Sep 17 00:00:00 2001 From: Korbinian Poeppel Date: Sat, 21 Dec 2024 13:40:54 +0000 Subject: [PATCH] Feat: Enable longer context window for inference by chunking. --- .../models/xlstm/configuration_xlstm.py | 4 + .../models/xlstm/modeling_xlstm.py | 86 ++++++++++++++----- 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/xlstm/configuration_xlstm.py b/src/transformers/models/xlstm/configuration_xlstm.py index 130c71ecb2da2b..5037a0bf16a5b7 100644 --- a/src/transformers/models/xlstm/configuration_xlstm.py +++ b/src/transformers/models/xlstm/configuration_xlstm.py @@ -114,6 +114,8 @@ class xLSTMConfig(PretrainedConfig): EOS token id needed for generation. force_bos_token_insert (bool, optional, *optional*, defaults to `True`): Whether to force the insertion of a BOS token for prompting. + max_inference_chunksize (int, optional, *optional*, defaults to 16384): + Maximum number of sequence positions processed at once during inference; longer inputs are split into consecutive chunks of at most this size to limit memory use. 
Example: @@ -172,6 +174,7 @@ def __init__( bos_token_id: int = 0, eos_token_id: int = 2, force_bos_token_insert: bool = True, + max_inference_chunksize: int = 16384, **kwargs, ): self.vocab_size = vocab_size @@ -209,6 +212,7 @@ def __init__( self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.force_bos_token_insert = force_bos_token_insert + self.max_inference_chunksize = max_inference_chunksize super().__init__( bos_token_id=bos_token_id, diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py index de1fe232ddc7e8..a89c9adfeae04c 100644 --- a/src/transformers/models/xlstm/modeling_xlstm.py +++ b/src/transformers/models/xlstm/modeling_xlstm.py @@ -306,28 +306,60 @@ def forward( cache_params = None hidden_states = inputs_embeds - all_hidden_states = () if output_hidden_states else None - for i, xlstm_block in enumerate(self.blocks): - if self.gradient_checkpointing and self.training: - hidden_states, rnn_state = self._gradient_checkpointing_func( - xlstm_block.__call__, - hidden_states, - cache_params.rnn_state[i] if cache_params is not None else None, - ) - else: - hidden_states, rnn_state = xlstm_block( - hidden_states, - state=cache_params.rnn_state[i] if cache_params is not None else None, - ) - if cache_params: - for state_idx in range(len(cache_params.rnn_state[i])): - local_rnn_state = rnn_state[state_idx] - local_rnn_state = rnn_state[state_idx] - cache_params.rnn_state[i][state_idx].copy_(local_rnn_state) - cache_params.rnn_state_initial = False - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + if ( + not self.training + and self.config.max_inference_chunksize < hidden_states.shape[1] + and not output_hidden_states + ): + all_hidden_states = None + offset = 0 + with torch.no_grad(): + if cache_params is None: + cache_params = xLSTMCache(config=self.config, batch_size=hidden_states.shape[0]) + final_state = torch.zeros_like(hidden_states) + while 
offset < hidden_states.shape[1]: + hidden_states_chunk = hidden_states[ + :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1]) + ] + for i, xlstm_block in enumerate(self.blocks): + hidden_states_chunk, rnn_state = xlstm_block( + hidden_states_chunk, + state=cache_params.rnn_state[i], + ) + for state_idx in range(len(cache_params.rnn_state[i])): + local_rnn_state = rnn_state[state_idx] + local_rnn_state = rnn_state[state_idx] + cache_params.rnn_state[i][state_idx].copy_(local_rnn_state) + cache_params.rnn_state_initial = False + final_state[ + :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1]) + ] = hidden_states_chunk + offset += self.config.max_inference_chunksize + hidden_states = final_state + else: + all_hidden_states = () if output_hidden_states else None + for i, xlstm_block in enumerate(self.blocks): + if self.gradient_checkpointing and self.training: + hidden_states, rnn_state = self._gradient_checkpointing_func( + xlstm_block.__call__, + hidden_states, + cache_params.rnn_state[i] if cache_params is not None else None, + ) + else: + hidden_states, rnn_state = xlstm_block( + hidden_states, + state=cache_params.rnn_state[i] if cache_params is not None else None, + ) + if cache_params: + for state_idx in range(len(cache_params.rnn_state[i])): + local_rnn_state = rnn_state[state_idx] + local_rnn_state = rnn_state[state_idx] + cache_params.rnn_state[i][state_idx].copy_(local_rnn_state) + cache_params.rnn_state_initial = False + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) if use_cache: cache_params.seqlen_offset += inputs_embeds.shape[1] @@ -507,7 +539,17 @@ def forward( logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() - logits = soft_cap(logits, self.config.output_logit_soft_cap) + if not self.training and self.config.max_inference_chunksize < logits.shape[1]: + offset = 0 + with torch.no_grad(): + while offset < 
logits.shape[1]: + logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])] = soft_cap( + logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])], + self.config.output_logit_soft_cap, + ) + offset += self.config.max_inference_chunksize + else: + logits = soft_cap(logits, self.config.output_logit_soft_cap) loss = None if labels is not None: