Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Dec 24, 2024
1 parent 5c5d378 commit 62ba00b
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 4 deletions.
4 changes: 3 additions & 1 deletion swift/llm/argument/base_args/generation_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class GenerationArguments:

stream: bool = False
stop_words: List[str] = field(default_factory=list)
logprobs: bool = False

def get_request_config(self):
if getattr(self, 'task_type') != 'causal_lm':
Expand All @@ -48,4 +49,5 @@ def get_request_config(self):
num_beams=self.num_beams,
stop=self.stop_words,
stream=self.stream,
repetition_penalty=self.repetition_penalty)
repetition_penalty=self.repetition_penalty,
logprobs=self.logprobs)
4 changes: 3 additions & 1 deletion swift/llm/infer/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
val_dataset, request_config, template=self.template, use_tqdm=True, **self.infer_kwargs)
for data, resp, labels in zip(val_dataset, resp_list, labels_list):
response = resp.choices[0].message.content
data = {'response': response, 'labels': labels, **data}
if labels:
data['labels'] = labels
data = {'response': response, 'logprobs': resp.choices[0].logprobs, **data}
result_list.append(data)
if is_dist:
total_result_list = [None for _ in range(args.world_size)] if args.rank == 0 else None
Expand Down
12 changes: 10 additions & 2 deletions swift/llm/infer/infer_engine/pt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,13 @@ def _get_adapter_names(self, adapter_request: AdapterRequest) -> List[str]:
self._add_adapter(adapter_request.path, adapter_name)
return [adapter_name]

@staticmethod
def _get_seq_cls_logprobs(logprobs):
res = []
for i, logprob in enumerate(logprobs.tolist()):
res.append({'index': i, 'logprob': logprob})
return {'content': res}

def _infer_seq_cls(self,
template: Template,
inputs: Dict[str, Any],
Expand All @@ -266,16 +273,17 @@ def _infer_seq_cls(self,
num_prompt_tokens = self._get_num_tokens(inputs)
inputs.pop('labels')
logits = self.model(**inputs, **call_kwargs).logits
logprobs = torch.log_softmax(logits, -1)
preds = torch.argmax(logits, dim=-1).tolist()
res = []
for pred in preds:
for i, pred in enumerate(preds):
usage_info = self._get_usage_info(num_prompt_tokens, 1)
choices = [
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role='assistant', content=str(pred), tool_calls=None),
finish_reason='stop',
logprobs=None)
logprobs=self._get_seq_cls_logprobs(logprobs[i]))
]
res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info))
return res
Expand Down
1 change: 1 addition & 0 deletions tests/train/test_cls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

kwargs = {
'per_device_train_batch_size': 2,
'per_device_eval_batch_size': 2,
Expand Down

0 comments on commit 62ba00b

Please sign in to comment.