Skip to content

Commit

Permalink
fix(server): fix OPT implementation (#2061)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene authored Jun 12, 2024
1 parent 376a0b7 commit 521de6c
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def forward(
return_dict=return_dict,
)

logits, speculative_logits = self.lm_head(outputs)
logits, speculative_logits = self.lm_head(outputs.last_hidden_state)

loss = None

Expand Down
3 changes: 1 addition & 2 deletions server/text_generation_server/models/gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,4 @@ def forward(
use_cache=True,
)

logits = outputs.logits
return logits, speculative_logits, outputs.past_key_values
return outputs.logits, speculative_logits, outputs.past_key_values
4 changes: 2 additions & 2 deletions server/text_generation_server/models/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ def __init__(
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs = self.model.forward(
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)

return outputs.logits, outputs.past_key_values
return outputs.logits, speculative_logits, outputs.past_key_values
8 changes: 5 additions & 3 deletions server/text_generation_server/models/rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ def __init__(

def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
):
# Model Forward
outputs = self.model.forward(
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values

return outputs.logits, speculative_logits, outputs.past_key_values

0 comments on commit 521de6c

Please sign in to comment.