diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a076f704b8ede2..364e1398383a22 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -598,6 +598,8 @@ title: T5v1.1 - local: model_doc/tapex title: TAPEX + - local: model_doc/telechat2 + title: TeleChat2 - local: model_doc/transfo-xl title: Transformer XL - local: model_doc/ul2 diff --git a/docs/source/en/index.md b/docs/source/en/index.md index dcecfc872d61d0..4326a2ea653920 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -325,6 +325,7 @@ Flax), PyTorch, and/or TensorFlow. | [Table Transformer](model_doc/table-transformer) | ✅ | ❌ | ❌ | | [TAPAS](model_doc/tapas) | ✅ | ✅ | ❌ | | [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ | +| [TeleChat2](model_doc/telechat2) | ✅ | ❌ | ❌ | | [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ | | [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ | | [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/telechat2.md b/docs/source/en/model_doc/telechat2.md new file mode 100644 index 00000000000000..bb7eefc61adf6d --- /dev/null +++ b/docs/source/en/model_doc/telechat2.md @@ -0,0 +1,85 @@ + + +# TeleChat2 + +## Overview + +The TeleChat2 model was proposed in [TELECHAT TECHNICAL REPORT](https://arxiv.org/pdf/2401.03804) by TeleAI. + +### Summary + +The abstract from the paper is the following: + +TeleChat is a series of large language models, offering decoder-based language models in various sizes (3B, 7B, and 12B). For each size, we provide both the base pretrained model and the fine-tuned chat model aligned with human preferences. TeleChat leverages a Transformer architecture with features such as SwiGLU activation, advanced attention mechanisms (QKV bias, group query attention), and support for sliding window attention. The models are optimized for bilingual proficiency (English and Chinese) and include an enhanced tokenizer adaptable to diverse natural languages and coding formats. + +## Usage tips + +The original code for telechat2 can be found [here](https://huggingface.co/Tele-AI/TeleChat2-7B). + +In the following, we demonstrate how to use `TeleChat2-7B` for the inference. Note that we have used the ChatML format for dialog, in this demo we show how to leverage `apply_chat_template` for this purpose. + +```python +>>> from transformers import AutoModelForCausalLM, AutoTokenizer +>>> device = "cuda" # the device to load the model onto + +>>> model = AutoModelForCausalLM.from_pretrained("Tele-AI/TeleChat2-7B", device_map="auto") +>>> tokenizer = AutoTokenizer.from_pretrained("Tele-AI/TeleChat2-7B") + +>>> prompt = "Give me a short introduction to large language model." + +>>> messages = [{"role": "user", "content": prompt}] + +>>> text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + +>>> model_inputs = tokenizer([text], return_tensors="pt").to(device) + +>>> generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512, do_sample=True) + +>>> generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)] + +>>> response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] +``` + +## TeleChat2Config + +[[autodoc]] TeleChat2Config + + +## TeleChat2Model + +[[autodoc]] TeleChat2Model + - forward + +## TeleChat2ForCausalLM + +[[autodoc]] TeleChat2ForCausalLM + - forward + +## TeleChat2ForSequenceClassification + +[[autodoc]] TeleChat2ForSequenceClassification + - forward + +## TeleChat2ForTokenClassification + +[[autodoc]] TeleChat2ForTokenClassification + - forward + +## TeleChat2ForQuestionAnswering + +[[autodoc]] TeleChat2ForQuestionAnswering + - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 364141c8e406b2..b3bfa75d39354e 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -96,6 +96,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel) * [RAG](https://huggingface.co/docs/transformers/model_doc/rag#transformers.RagModel) * [SpeechEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/speech_encoder_decoder#transformers.SpeechEncoderDecoderModel) +* [TeleChat2](https://huggingface.co/docs/transformers/model_doc/telechat2) * [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision_encoder_decoder#transformers.VisionEncoderDecoderModel) * [VisionTextDualEncoder](https://huggingface.co/docs/transformers/model_doc/vision_text_dual_encoder#transformers.VisionTextDualEncoderModel) * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) @@ -303,6 +304,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) * [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron) * [SpeechEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/speech_encoder_decoder#transformers.SpeechEncoderDecoderModel) +* [TeleChat2](https://huggingface.co/docs/transformers/model_doc/telechat2) * [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava) * [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava) * [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision_encoder_decoder#transformers.VisionEncoderDecoderModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7df1af049de626..59e00df09d48ca 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -788,6 +788,7 @@ "TapasConfig", "TapasTokenizer", ], + "models.telechat2": ["TeleChat2Config"], "models.time_series_transformer": ["TimeSeriesTransformerConfig"], "models.timesformer": ["TimesformerConfig"], "models.timm_backbone": ["TimmBackboneConfig"], @@ -3573,6 +3574,16 @@ "load_tf_weights_in_tapas", ] ) + _import_structure["models.telechat2"].extend( + [ + "TeleChat2ForCausalLM", + "TeleChat2ForQuestionAnswering", + "TeleChat2ForSequenceClassification", + "TeleChat2ForTokenClassification", + "TeleChat2Model", + "TeleChat2PreTrainedModel", + ] + ) _import_structure["models.time_series_transformer"].extend( [ "TimeSeriesTransformerForPrediction", @@ -5801,6 +5812,7 @@ TapasConfig, TapasTokenizer, ) + from .models.telechat2 import TeleChat2Config from .models.time_series_transformer import ( TimeSeriesTransformerConfig, ) @@ -8135,6 +8147,14 @@ TapasPreTrainedModel, load_tf_weights_in_tapas, ) + from .models.telechat2 import ( + TeleChat2ForCausalLM, + TeleChat2ForQuestionAnswering, + TeleChat2ForSequenceClassification, + TeleChat2ForTokenClassification, + TeleChat2Model, + TeleChat2PreTrainedModel, + ) from .models.time_series_transformer import ( TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel, diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index f37f589d5d53e0..3eb7eb8252cb8a 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -1603,6 +1603,7 @@ def converted(self) -> Tokenizer: "CodeLlamaTokenizer": LlamaConverter, "GemmaTokenizer": GemmaConvert, "Phi3Tokenizer": LlamaConverter, + "TeleChat2Tokenizer": LlamaConverter, } diff --git a/src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py b/src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py index 0b93e4c53ff891..0e67af91855075 100755 --- a/src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py +++ b/src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py @@ -30,7 +30,9 @@ TOKENIZER_CLASSES = { # Phi3 uses Llama tokenizer - name: getattr(transformers, "LlamaTokenizerFast" if name == "Phi3Tokenizer" else name + "Fast") + name: getattr( + transformers, "LlamaTokenizerFast" if name in ["Phi3Tokenizer", "TeleChat2Tokenizer"] else name + "Fast" + ) for name in SLOW_TO_FAST_CONVERTERS } diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index ff03d09966a4d6..94db4e2051fea5 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -251,6 +251,7 @@ t5, table_transformer, tapas, + telechat2, time_series_transformer, timesformer, timm_backbone, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 6c052aa0eaa0f3..18b1086d245cdc 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -278,6 +278,7 @@ ("t5", "T5Config"), ("table-transformer", "TableTransformerConfig"), ("tapas", "TapasConfig"), + ("telechat2", "TeleChat2Config"), ("time_series_transformer", "TimeSeriesTransformerConfig"), ("timesformer", "TimesformerConfig"), ("timm_backbone", "TimmBackboneConfig"), @@ -608,6 +609,7 @@ ("table-transformer", "Table Transformer"), ("tapas", "TAPAS"), ("tapex", "TAPEX"), + ("telechat2", "TeleChat2"), ("time_series_transformer", "Time Series Transformer"), ("timesformer", "TimeSformer"), ("timm_backbone", "TimmBackbone"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 861754f591769b..86767e13b2c909 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -256,6 +256,7 @@ ("t5", "T5Model"), ("table-transformer", "TableTransformerModel"), ("tapas", "TapasModel"), + ("telechat2", "TeleChat2Model"), ("time_series_transformer", "TimeSeriesTransformerModel"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), @@ -556,6 +557,7 @@ ("speech_to_text_2", "Speech2Text2ForCausalLM"), ("stablelm", "StableLmForCausalLM"), ("starcoder2", "Starcoder2ForCausalLM"), + ("telechat2", "TeleChat2ForCausalLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("trocr", "TrOCRForCausalLM"), ("whisper", "WhisperForCausalLM"), @@ -1029,6 +1031,7 @@ ("starcoder2", "Starcoder2ForSequenceClassification"), ("t5", "T5ForSequenceClassification"), ("tapas", "TapasForSequenceClassification"), + ("telechat2", "TeleChat2ForSequenceClassification"), ("transfo-xl", "TransfoXLForSequenceClassification"), ("umt5", "UMT5ForSequenceClassification"), ("xlm", "XLMForSequenceClassification"), @@ -1105,6 +1108,7 @@ ("splinter", "SplinterForQuestionAnswering"), ("squeezebert", "SqueezeBertForQuestionAnswering"), ("t5", "T5ForQuestionAnswering"), + ("telechat2", "TeleChat2ForQuestionAnswering"), ("umt5", "UMT5ForQuestionAnswering"), ("xlm", "XLMForQuestionAnsweringSimple"), ("xlm-roberta", "XLMRobertaForQuestionAnswering"), @@ -1207,6 +1211,7 @@ ("stablelm", "StableLmForTokenClassification"), ("starcoder2", "Starcoder2ForTokenClassification"), ("t5", "T5ForTokenClassification"), + ("telechat2", "TeleChat2ForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm", "XLMForTokenClassification"), ("xlm-roberta", "XLMRobertaForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 350c230f142c15..fbe747acf0f0e7 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -493,6 +493,7 @@ ), ("tapas", ("TapasTokenizer", None)), ("tapex", ("TapexTokenizer", None)), + ("telechat2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("transfo-xl", ("TransfoXLTokenizer", None)), ("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ( diff --git a/src/transformers/models/telechat2/__init__.py b/src/transformers/models/telechat2/__init__.py new file mode 100644 index 00000000000000..2946b0cec92081 --- /dev/null +++ b/src/transformers/models/telechat2/__init__.py @@ -0,0 +1,65 @@ +# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_telechat2": ["TeleChat2Config"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_telechat2"] = [ + "TeleChat2ForCausalLM", + "TeleChat2ForQuestionAnswering", + "TeleChat2ForSequenceClassification", + "TeleChat2ForTokenClassification", + "TeleChat2Model", + "TeleChat2PreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_telechat2 import TeleChat2Config + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_telechat2 import ( + TeleChat2ForCausalLM, + TeleChat2ForQuestionAnswering, + TeleChat2ForSequenceClassification, + TeleChat2ForTokenClassification, + TeleChat2Model, + TeleChat2PreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/telechat2/configuration_telechat2.py b/src/transformers/models/telechat2/configuration_telechat2.py new file mode 100644 index 00000000000000..0bf264107c1b46 --- /dev/null +++ b/src/transformers/models/telechat2/configuration_telechat2.py @@ -0,0 +1,210 @@ +# coding=utf-8 +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TeleChat2 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class TeleChat2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`TeleChat2Model`]. It is used to instantiate an TeleChat2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of + TeleChat2-7B [Tele-AI/TeleChat2-7B](https://huggingface.co/Tele-AI/TeleChat2-7B). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the TeleChat2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`TeleChat2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + ffn_hidden_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability applied to the hidden layers in the MLP blocks. + n_layer (`int`, *optional*, defaults to 30): + Number of hidden layers in the Transformer decoder. + n_head (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=n_head`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `n_head`. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. TeleChat2 1 supports up to 8192 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'telechat23'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'telechat23'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'telechat23'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'telechat23'. Scaling factor applied to high frequency components of the RoPE + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_heads + + ```python + >>> from transformers import TeleChat2Model, TeleChat2Config + + >>> # Initializing a TeleChat2 telechat2-7b style configuration + >>> configuration = TeleChat2Config() + + >>> # Initializing a model from the telechat2-7b style configuration + >>> model = TeleChat2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "telechat2" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `TeleChat2Model` + base_model_tp_plan = { + "h.*.self_attention.query": "colwise", + "h.*.self_attention.key_value": "colwise", + "h.*.self_attn.dense": "rowwise", + "h.*.mlp.gate_proj": "colwise", + "h.*.mlp.up_proj": "colwise", + "h.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + ffn_hidden_size=11008, + hidden_dropout=0.0, + n_layer=30, + n_head=32, + num_key_value_heads=32, + max_position_embeddings=2048, + initializer_range=0.02, + layer_norm_epsilon=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.num_hidden_layers = n_layer + self.num_attention_heads = n_head + self.hidden_dropout = hidden_dropout + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = n_head + + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.layer_norm_epsilon = layer_norm_epsilon + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/src/transformers/models/telechat2/modeling_telechat2.py b/src/transformers/models/telechat2/modeling_telechat2.py new file mode 100644 index 00000000000000..0a04d5fc2846fe --- /dev/null +++ b/src/transformers/models/telechat2/modeling_telechat2.py @@ -0,0 +1,1163 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. team and BigScience workshop. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# This code is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +"""PyTorch TeleChat2 model implementation, refactored with Transformers-style conventions.""" + +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_telechat2 import TeleChat2Config + + +logger = logging.get_logger(__name__) + + +_CHECKPOINT_FOR_DOC = "TeleAI/TeleChat2-3B" +_CONFIG_FOR_DOC = "TeleChat2Config" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->TeleChat2 +class TeleChat2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + TeleChat2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(TeleChat2RMSNorm) + + +class TeleChat2RotaryEmbedding(nn.Module): + def __init__( + self, + config: TeleChat2Config, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + """ + Apply dropout to `x` and add the result to the residual tensor. + + Args: + x (`torch.Tensor`): Input tensor. + residual (`torch.Tensor`): Residual tensor. + prob (`float`): Dropout probability. + training (`bool`): Whether in training mode. + + Returns: + `torch.Tensor`: Output after dropout and addition. + """ + out = F.dropout(x, p=prob, training=training) + out = residual + out + return out + + +class TeleChat2MLP(nn.Module): + """ + TeleChat2 MLP block with a gated activation function. + """ + + def __init__(self, config: TeleChat2Config): + super().__init__() + hidden_size = config.hidden_size + self.gate_proj = nn.Linear(hidden_size, config.ffn_hidden_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(hidden_size, config.ffn_hidden_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(config.ffn_hidden_size, hidden_size, bias=True) + self.hidden_dropout = config.hidden_dropout + + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + intermediate_output = self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) + output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training) + return output + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class TeleChat2Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + """ + + def __init__(self, config: TeleChat2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.query = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.key_value = nn.Linear(config.hidden_size, self.head_dim * config.num_key_value_heads * 2, bias=False) + self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + + mixed_kv = self.key_value(hidden_states) + mixed_kv = mixed_kv.view(*input_shape, -1, 2 * self.head_dim) + key_states, value_states = torch.split(mixed_kv, self.head_dim, dim=-1) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.dense(attn_output) + return attn_output, attn_weights + + +class TeleChat2DecoderLayer(nn.Module): + def __init__(self, config: TeleChat2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attention = TeleChat2Attention(config=config, layer_idx=layer_idx) + + self.mlp = TeleChat2MLP(config) + self.input_layernorm = TeleChat2RMSNorm(self.hidden_size, eps=config.layer_norm_epsilon) + self.post_attention_layernorm = TeleChat2RMSNorm(self.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, residual) + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +TELECHAT2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`TeleChat2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare TELECHAT2 Model outputting raw hidden-states without any specific head on top.", + TELECHAT2_START_DOCSTRING, +) +class TeleChat2PreTrainedModel(PreTrainedModel): + config_class = TeleChat2Config + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + _no_split_modules = ["TeleChat2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +TELECHAT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare TeleChat2 Model outputting raw hidden-states without any specific head on top.", + TELECHAT2_START_DOCSTRING, +) +class TeleChat2Model(TeleChat2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`TeleChat2DecoderLayer`] + + Args: + config: TeleChat2Config + """ + + def __init__(self, config: TeleChat2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.h = nn.ModuleList([TeleChat2DecoderLayer(config, i) for i in range(config.num_hidden_layers)]) + self.ln_f = TeleChat2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.rotary_emb = TeleChat2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, value): + self.word_embeddings = value + + @add_start_docstrings_to_model_forward(TELECHAT2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.h[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.ln_f(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class TeleChat2ForCausalLM(TeleChat2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config: TeleChat2Config): + super().__init__(config) + self.transformer = TeleChat2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.word_embeddings + + def set_input_embeddings(self, value): + self.transformer.word_embeddings = value + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.transformer = decoder + + def get_decoder(self): + return self.transformer + + @add_start_docstrings_to_model_forward(TELECHAT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, TeleChat2ForCausalLM + + >>> model = TeleChat2ForCausalLM.from_pretrained("TeleAI/TeleChat2-7B") + >>> tokenizer = AutoTokenizer.from_pretrained("TeleAI/TeleChat2-7B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The TeleChat2 Model transformer with a sequence classification head on top (linear layer). + + [`TeleChat2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + TELECHAT2_START_DOCSTRING, +) +class TeleChat2ForSequenceClassification(TeleChat2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = TeleChat2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.word_embeddings + + def set_input_embeddings(self, value): + self.transformer.word_embeddings = value + + @add_start_docstrings_to_model_forward(TELECHAT2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The TeleChat2 Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + TELECHAT2_START_DOCSTRING, +) +class TeleChat2ForQuestionAnswering(TeleChat2PreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->TeleChat2 + def __init__(self, config): + super().__init__(config) + self.transformer = TeleChat2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.word_embeddings + + def set_input_embeddings(self, value): + self.transformer.word_embeddings = value + + @add_start_docstrings_to_model_forward(TELECHAT2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The TeleChat2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + TELECHAT2_START_DOCSTRING, +) +class TeleChat2ForTokenClassification(TeleChat2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = TeleChat2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.word_embeddings + + def set_input_embeddings(self, value): + self.transformer.word_embeddings = value + + @add_start_docstrings_to_model_forward(TELECHAT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 922d67264bb142..46170a423a178a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -9116,6 +9116,48 @@ def load_tf_weights_in_tapas(*args, **kwargs): requires_backends(load_tf_weights_in_tapas, ["torch"]) +class TeleChat2ForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TeleChat2ForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TeleChat2ForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TeleChat2ForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TeleChat2Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TeleChat2PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class TimeSeriesTransformerForPrediction(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/telechat2/__init__.py b/tests/models/telechat2/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/telechat2/test_modeling_telechat2.py b/tests/models/telechat2/test_modeling_telechat2.py new file mode 100644 index 00000000000000..6789a8f1c31e14 --- /dev/null +++ b/tests/models/telechat2/test_modeling_telechat2.py @@ -0,0 +1,616 @@ +# coding=utf-8 +# Copyright 2024 The TeleChat team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch TeleChat2 model.""" + +import gc +import unittest + +import pytest +from packaging import version + +from transformers import AutoTokenizer, TeleChat2Config, is_torch_available, set_seed +from transformers.generation.configuration_utils import GenerationConfig +from transformers.testing_utils import ( + backend_empty_cache, + require_bitsandbytes, + require_flash_attn, + require_torch, + require_torch_gpu, + require_torch_sdpa, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + TeleChat2ForCausalLM, + TeleChat2ForQuestionAnswering, + TeleChat2ForSequenceClassification, + TeleChat2ForTokenClassification, + TeleChat2Model, + ) + + +class TeleChat2ModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + max_window_layers=3, + use_sliding_window=True, + sliding_window=50, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + bos_token_id=1, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.max_window_layers = max_window_layers + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.scope = scope + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return TeleChat2Config( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + n_layer=self.num_hidden_layers, + max_window_layers=self.max_window_layers, + use_sliding_window=self.use_sliding_window, + sliding_window=self.sliding_window, + n_head=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + ffn_hidden_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + ) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->TeleChat2 + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TeleChat2Model(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->TeleChat2 + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = TeleChat2Model(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->TeleChat2 + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = TeleChat2ForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->TeleChat2 + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = TeleChat2ForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class TeleChat2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + TeleChat2Model, + TeleChat2ForCausalLM, + TeleChat2ForSequenceClassification, + TeleChat2ForTokenClassification, + TeleChat2ForQuestionAnswering, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (TeleChat2ForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": TeleChat2Model, + "text-classification": TeleChat2ForSequenceClassification, + "token-classification": TeleChat2ForTokenClassification, + "text-generation": TeleChat2ForCausalLM, + "zero-shot": TeleChat2ForSequenceClassification, + "question-answering": TeleChat2ForQuestionAnswering, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + + # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + return True + + def setUp(self): + self.model_tester = TeleChat2ModelTester(self) + self.config_tester = ConfigTester(self, config_class=TeleChat2Config, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_TeleChat2_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + print(config) + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = TeleChat2ForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_TeleChat2_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = TeleChat2ForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_TeleChat2_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + model = TeleChat2ForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_TeleChat2_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = TeleChat2ForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + + @unittest.skip(reason="TeleChat2 buffers include complex numbers, which breaks this test") + def test_save_load_fast_init_from_base(self): + pass + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + self.skipTest(reason="TeleChat2 flash attention does not support right padding") + + +@require_torch +class TeleChat2IntegrationTest(unittest.TestCase): + @slow + def test_model_450m_logits(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + model = TeleChat2ForCausalLM.from_pretrained("TeleAI/TeleChat2-3B", device_map="auto") + input_ids = torch.tensor([input_ids]).to(model.model.word_embeddings.weight.device) + with torch.no_grad(): + out = model(input_ids).logits.float().cpu() + # Expected mean on dim = -1 + EXPECTED_MEAN = torch.tensor([[-1.9537, -1.6193, -1.4123, -1.4673, -1.8511, -1.9309, -1.9826, -2.1776]]) + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) + # slicing logits[0, 0, 0:30] + EXPECTED_SLICE = torch.tensor([3.2025, 7.1265, 4.6058, 3.6423, 1.6357, 3.9265, 5.1883, 5.8760, 2.7942, 4.4823, 3.2571, 2.1063, 3.4275, 4.2028, 1.9767, 5.2115, 6.6756, 6.3999, 6.0483, 5.7378, 5.6660, 5.2298, 5.4103, 5.1248, 5.4376, 2.4570, 2.6107, 5.4039, 2.8077, 4.7777]) # fmt: skip + print(out[0, 0, :30]) + torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @slow + def test_model_450m_generation(self): + EXPECTED_TEXT_COMPLETION = ( + """My favourite condiment is 100% natural, organic and vegan. I love to use it in my cooking and I""" + ) + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("TeleAI/TeleChat2-3B", use_fast=False) + model = TeleChat2ForCausalLM.from_pretrained("TeleAI/TeleChat2-3B", device_map="auto") + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.word_embeddings.weight.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @require_bitsandbytes + @slow + @require_flash_attn + @pytest.mark.flash_attn_test + def test_model_450m_long_prompt(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = TeleChat2ForCausalLM.from_pretrained( + "TeleAI/TeleChat2-3B", + device_map="auto", + load_in_4bit=True, + attn_implementation="flash_attention_2", + ) + input_ids = torch.tensor([input_ids]).to(model.model.word_embeddings.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + del model + backend_empty_cache(torch_device) + gc.collect() + + @slow + @require_torch_sdpa + def test_model_450m_long_prompt_sdpa(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = TeleChat2ForCausalLM.from_pretrained( + "TeleAI/TeleChat2-3B", device_map="auto", attn_implementation="sdpa" + ) + input_ids = torch.tensor([input_ids]).to(model.model.word_embeddings.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + + backend_empty_cache(torch_device) + gc.collect() + + EXPECTED_TEXT_COMPLETION = ( + """My favourite condiment is 100% natural, organic and vegan. I love to use it in my cooking and I""" + ) + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("TeleAI/TeleChat2-3B", use_fast=False) + + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.word_embeddings.weight.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + def test_speculative_generation(self): + EXPECTED_TEXT_COMPLETION = ( + "My favourite condiment is 100% natural honey, and I always like to use it in my recipes. I love" + ) + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("TeleAIt/TeleChat2-7B", use_fast=False) + model = TeleChat2ForCausalLM.from_pretrained( + "TeleAI/TeleChat2-3B", device_map="auto", torch_dtype=torch.float16 + ) + assistant_model = TeleChat2ForCausalLM.from_pretrained( + "TeleChat/TeleChat2-3B", device_map="auto", torch_dtype=torch.float16 + ) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.word_embeddings.weight.device) + + # greedy generation outputs + set_seed(0) + generated_ids = model.generate( + input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @slow + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + convert_and_export_with_cache, + ) + + telechat_model = "TeleAI/TeleChat2-3B" + + tokenizer = AutoTokenizer.from_pretrained(telechat_model, pad_token="<_pad>", padding_side="right") + EXPECTED_TEXT_COMPLETION = [ + "My favourite condiment is 100% natural, organic, gluten free, vegan, and free from preservatives. I" + ] + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + model = TeleChat2ForCausalLM.from_pretrained( + telechat_model, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + }, + ), + ) + + prompt = ["My favourite condiment is "] + prompt_tokens = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + export + exported_program = convert_and_export_with_cache(model) + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 116e26e7834f26..ea86e228d2494a 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -199,6 +199,7 @@ "giou_cost", "giou_loss_coefficient", ], + "TeleChat2Config": ["n_head"], }