diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 83d62dea2ee..9b2d01e0763 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -792,7 +792,7 @@ def forward( return_dict=return_dict, ) - logits, speculative_logits = self.lm_head(outputs) + logits, speculative_logits = self.lm_head(outputs.last_hidden_state) loss = None diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index c0e1adf2d9d..d1f8f5be150 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -85,5 +85,4 @@ def forward( use_cache=True, ) - logits = outputs.logits - return logits, speculative_logits, outputs.past_key_values + return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 5b84f4ff837..87319ef09a7 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -75,11 +75,11 @@ def __init__( def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=True, ) - return outputs.logits, outputs.past_key_values + return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py index d4764dedee2..50f6ead88fa 100644 --- a/server/text_generation_server/models/rw.py +++ b/server/text_generation_server/models/rw.py @@ -71,11 +71,13 @@ def __init__( def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + ): # Model Forward - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, + use_cache=True, ) - return outputs.logits, outputs.past_key_values + + return outputs.logits, speculative_logits, outputs.past_key_values