From 2e7c942c81f4803b75f476d587fa7b12cae2ea2a Mon Sep 17 00:00:00 2001
From: nakranivaibhav <67785830+nakranivaibhav@users.noreply.github.com>
Date: Tue, 6 Feb 2024 08:11:42 +0530
Subject: [PATCH] Adds LlamaForQuestionAnswering class in modeling_llama.py
 along with AutoModel support (#28777)

* This is a test commit

* Testing commit

* Final commit with some changes

* Removed copy statement

* Fixed formatting issues

* Fixed an error: added past_key_values to the forward method

* Fixed a trailing whitespace

* Added the copy statement

---
 docs/source/en/model_doc/llama.md             |   5 +
 docs/source/en/tasks/question_answering.md    |   2 +-
 src/transformers/__init__.py                  |   9 +-
 src/transformers/models/auto/modeling_auto.py |   1 +
 src/transformers/models/llama/__init__.py     |   9 +-
 .../models/llama/modeling_llama.py            | 104 +++++++++++++++++-
 src/transformers/utils/dummy_pt_objects.py    |   7 ++
 tests/models/llama/test_modeling_llama.py     |   8 +-
 8 files changed, 140 insertions(+), 5 deletions(-)

diff --git a/docs/source/en/model_doc/llama.md b/docs/source/en/model_doc/llama.md
index 96f2a2e7eb7cdd..915d5ecc70b554 100644
--- a/docs/source/en/model_doc/llama.md
+++ b/docs/source/en/model_doc/llama.md
@@ -116,6 +116,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
 [[autodoc]] LlamaForSequenceClassification
     - forward
 
+## LlamaForQuestionAnswering
+
+[[autodoc]] LlamaForQuestionAnswering
+    - forward
+
 ## FlaxLlamaModel
 
 [[autodoc]] FlaxLlamaModel
diff --git a/docs/source/en/tasks/question_answering.md b/docs/source/en/tasks/question_answering.md
index 625cf96dc886b6..7c228061ff8e71 100644
--- a/docs/source/en/tasks/question_answering.md
+++ b/docs/source/en/tasks/question_answering.md
@@ -36,7 +36,7 @@ The task illustrated in this tutorial is supported by the following model archit
-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
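With LLaMA now in the supported-architecture list above, the generic question-answering workflow applies unchanged. A minimal sketch of pipeline usage, assuming a Llama checkpoint fine-tuned for extractive QA exists; `my-org/llama-qa` below is a placeholder name, and a base checkpoint would yield meaningless spans because the new head starts out randomly initialized:

```python
from transformers import pipeline

# "my-org/llama-qa" is a hypothetical checkpoint fine-tuned with
# LlamaForQuestionAnswering; substitute a real one.
qa = pipeline("question-answering", model="my-org/llama-qa")

result = qa(
    question="What head does LlamaForQuestionAnswering add?",
    context="LlamaForQuestionAnswering adds a span classification head on top of the Llama transformer.",
)
print(result["answer"], result["score"])
```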
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 5e2b87089aba9e..b233ee2acb09ee 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -2483,6 +2483,7 @@
     _import_structure["models.llama"].extend(
         [
             "LlamaForCausalLM",
+            "LlamaForQuestionAnswering",
             "LlamaForSequenceClassification",
             "LlamaModel",
             "LlamaPreTrainedModel",
@@ -7025,7 +7026,13 @@
         LiltModel,
         LiltPreTrainedModel,
     )
-    from .models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
+    from .models.llama import (
+        LlamaForCausalLM,
+        LlamaForQuestionAnswering,
+        LlamaForSequenceClassification,
+        LlamaModel,
+        LlamaPreTrainedModel,
+    )
     from .models.llava import (
         LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
         LlavaForConditionalGeneration,
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 34a1f9d1d3a9ca..8ef6dc5df5a9ae 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -849,6 +849,7 @@
         ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
         ("led", "LEDForQuestionAnswering"),
         ("lilt", "LiltForQuestionAnswering"),
+        ("llama", "LlamaForQuestionAnswering"),
         ("longformer", "LongformerForQuestionAnswering"),
         ("luke", "LukeForQuestionAnswering"),
         ("lxmert", "LxmertForQuestionAnswering"),
diff --git a/src/transformers/models/llama/__init__.py b/src/transformers/models/llama/__init__.py
index b5e9a60cda6e3c..b5262941cb0e5c 100644
--- a/src/transformers/models/llama/__init__.py
+++ b/src/transformers/models/llama/__init__.py
@@ -54,6 +54,7 @@
         "LlamaModel",
         "LlamaPreTrainedModel",
         "LlamaForSequenceClassification",
+        "LlamaForQuestionAnswering",
     ]
 
 try:
@@ -90,7 +91,13 @@
 except OptionalDependencyNotAvailable:
     pass
 else:
-    from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
+    from .modeling_llama import (
+        LlamaForCausalLM,
+        LlamaForQuestionAnswering,
+        LlamaForSequenceClassification,
+        LlamaModel,
+        LlamaPreTrainedModel,
+    )
 
 try:
     if not is_flax_available():
+ """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2e31313be7b6e8..c766f3f522b124 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4689,6 +4689,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class LlamaForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LlamaForSequenceClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index c1cc479123f0a0..8ee7617a0b742e 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -44,6 +44,7 @@ from transformers import ( CodeLlamaTokenizer, LlamaForCausalLM, + LlamaForQuestionAnswering, LlamaForSequenceClassification, LlamaModel, LlamaTokenizer, @@ -278,7 +279,11 @@ def prepare_config_and_inputs_for_common(self): @require_torch class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification) if is_torch_available() else () + all_model_classes = ( + (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering) + if is_torch_available() + else () + ) all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { @@ -286,6 +291,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi "text-classification": LlamaForSequenceClassification, "text-generation": LlamaForCausalLM, "zero-shot": 
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 2e31313be7b6e8..c766f3f522b124 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -4689,6 +4689,13 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
+class LlamaForQuestionAnswering(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class LlamaForSequenceClassification(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index c1cc479123f0a0..8ee7617a0b742e 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -44,6 +44,7 @@
     from transformers import (
         CodeLlamaTokenizer,
         LlamaForCausalLM,
+        LlamaForQuestionAnswering,
         LlamaForSequenceClassification,
         LlamaModel,
         LlamaTokenizer,
@@ -278,7 +279,11 @@ def prepare_config_and_inputs_for_common(self):
 
 @require_torch
 class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification) if is_torch_available() else ()
+    all_model_classes = (
+        (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering)
+        if is_torch_available()
+        else ()
+    )
     all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
@@ -286,6 +291,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             "text-classification": LlamaForSequenceClassification,
             "text-generation": LlamaForCausalLM,
             "zero-shot": LlamaForSequenceClassification,
+            "question-answering": LlamaForQuestionAnswering,
         }
         if is_torch_available()
         else {}
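The tests above also exercise the training path, which mirrors the other QA heads: when `start_positions` and `end_positions` are supplied, `forward` returns the average of the start and end cross-entropy losses. A minimal sketch on a tiny, untrained model (all sizes illustrative):

```python
import torch

from transformers import LlamaConfig, LlamaForQuestionAnswering

# Deliberately tiny random model, only to exercise the loss computation.
config = LlamaConfig(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = LlamaForQuestionAnswering(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
outputs = model(
    input_ids,
    start_positions=torch.tensor([3, 5]),
    end_positions=torch.tensor([7, 9]),
)
# loss = (start cross-entropy + end cross-entropy) / 2, per forward() above
print(outputs.loss)
```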