Commit 5ed1938: update
Jintao-Huang committed Dec 28, 2024
1 parent 246c7a9
Showing 10 changed files with 109 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/source/Instruction/命令行参数.md
@@ -99,6 +99,7 @@
- remove_unused_columns: Default is False.
- logging_first_step: Whether to log the first training step, default is True.
- logging_steps: Logging interval, default is 5.
- average_tokens_across_devices: Whether to average the token count across devices. If set to True, all_reduce is used to synchronize `num_tokens_in_batch` for accurate loss computation. Defaults to None: True for distributed training, otherwise False.
- metric_for_best_model: Default is None, i.e. 'loss' when `predict_with_generate` is False, otherwise 'rouge-l'.
- greater_is_better: Default is None, i.e. False when `metric_for_best_model` contains 'loss', otherwise True.

1 change: 1 addition & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -100,6 +100,7 @@ This parameter list inherits from transformers `Seq2SeqTrainingArguments`, with
- remove_unused_columns: Default is False.
- logging_first_step: Whether to log the first training step, default is True.
- logging_steps: Logging interval, default is 5.
- average_tokens_across_devices: Whether to average the token count across devices. If set to True, it will use all_reduce to synchronize `num_tokens_in_batch` for accurate loss computation. The default is None: True for distributed training, otherwise False. (A sketch of this synchronization follows the list.)
- metric_for_best_model: Default is None. When `predict_with_generate` is set to False, it is 'loss'; otherwise, it is 'rouge-l'.
- greater_is_better: Default is None. When `metric_for_best_model` contains 'loss', set to False; otherwise, set to True.
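
A minimal sketch of the synchronization behind `average_tokens_across_devices` (assuming an initialized PyTorch distributed process group; the names are illustrative, not swift's internals):

import torch
import torch.distributed as dist

def token_averaged_loss(loss_sum: torch.Tensor, num_tokens_in_batch: torch.Tensor) -> torch.Tensor:
    # Sum per-rank token counts and loss sums across all devices, so every rank
    # divides by the same global denominator and the per-token loss is exact
    # even when ranks see different numbers of non-padding tokens.
    dist.all_reduce(num_tokens_in_batch, op=dist.ReduceOp.SUM)
    dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
    return loss_sum / num_tokens_in_batch
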

6 changes: 4 additions & 2 deletions swift/llm/__init__.py
@@ -9,9 +9,10 @@
InferClient, run_deploy, AdapterRequest, prepare_model_template)
from .export import (export_main, merge_lora, quantize_model, export_to_ollama)
from .eval import eval_main
from .app import app_main
from .train import sft_main, pt_main, rlhf_main, get_multimodal_target_regex
from .argument import (EvalArguments, InferArguments, TrainArguments, ExportArguments, DeployArguments,
-                        RLHFArguments, WebUIArguments, BaseArguments)
+                        RLHFArguments, WebUIArguments, BaseArguments, AppArguments)
from .template import (TEMPLATE_MAPPING, Template, Word, get_template, TemplateType, register_template,
TemplateInputs, Messages, TemplateMeta, get_template_meta, InferRequest, load_image,
MaxLengthError, load_file)
@@ -36,11 +37,12 @@
'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template'
],
'export': ['export_main', 'merge_lora', 'quantize_model', 'export_to_ollama'],
'app': ['app_main'],
'eval': ['eval_main'],
'train': ['sft_main', 'pt_main', 'rlhf_main', 'get_multimodal_target_regex'],
'argument': [
'EvalArguments', 'InferArguments', 'TrainArguments', 'ExportArguments', 'WebUIArguments', 'DeployArguments',
-        'RLHFArguments', 'BaseArguments'
+        'RLHFArguments', 'BaseArguments', 'AppArguments'
],
'template': [
'TEMPLATE_MAPPING', 'Template', 'Word', 'get_template', 'TemplateType', 'register_template',
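
After this change, the new app entry point is reachable from the package root as well as from the submodule. A quick import check (nothing assumed beyond the names in the diff):

from swift.llm import AppArguments, app_main
from swift.llm.app import app_main as app_main_direct

# The lazy-import mapping and the eager imports resolve to the same object.
assert app_main is app_main_direct
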
1 change: 1 addition & 0 deletions swift/llm/app/__init__.py
@@ -0,0 +1 @@
from .app import app_main, SwiftApp
9 changes: 8 additions & 1 deletion swift/llm/app/app.py
@@ -7,6 +7,7 @@
from ..argument import AppArguments
from ..base import SwiftPipeline
from ..infer import run_deploy
from .llm_ui import build_llm_ui

logger = get_logger()

@@ -17,4 +18,10 @@ class SwiftApp(SwiftPipeline):

def run(self):
args = self.args
-        deploy_context = nullcontext() if args.eval_url else run_deploy(self.args, return_url=True)
+        deploy_context = nullcontext() if args.api_url else run_deploy(self.args, return_url=True)
+        # Enter the context so a locally deployed server stays alive while the UI runs.
+        with deploy_context:
+            demo = build_llm_ui(args.studio_title)  # studio_title is required by build_llm_ui
+            demo.queue().launch()


+def app_main(args: Union[List[str], AppArguments, None] = None):
+    return SwiftApp(args).main()
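
A sketch of driving the new pipeline programmatically (hypothetical values; with api_url set, run() skips the in-process run_deploy and only serves the UI):

from swift.llm import AppArguments, app_main

# Wrap the UI around an already-running OpenAI-compatible endpoint.
app_main(AppArguments(api_url='http://127.0.0.1:8000/v1', studio_title='My Chat Demo'))
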
81 changes: 81 additions & 0 deletions swift/llm/app/llm_ui.py
@@ -0,0 +1,81 @@
from http import HTTPStatus
from typing import Dict, Iterator, List, Optional, Tuple

import gradio as gr
from dashscope import Generation  # DashScope SDK; model_chat streams completions from it

History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]


def clear_session() -> Tuple[str, History]:
    return '', []


def modify_system_session(system: str, default_system: str) -> Tuple[str, str, History]:
    # Fall back to the default system prompt and clear the chat history.
    system = system or default_system
    return system, system, []


def history_to_messages(history: History, system: str) -> Messages:
    # Convert gradio [query, response] pairs into role/content messages.
    messages = [{'role': 'system', 'content': system}]
    for query, response in history:
        messages.append({'role': 'user', 'content': query})
        if response is not None:
            messages.append({'role': 'assistant', 'content': response})
    return messages


def messages_to_history(messages: Messages) -> Tuple[str, History]:
    # Inverse of history_to_messages: recover the system prompt and query/response pairs.
    system = messages[0]['content']
    history = [[q['content'], r['content']] for q, r in zip(messages[1::2], messages[2::2])]
    return system, history


def model_chat(query: str, history: History, system: str) -> Iterator[Tuple[str, History, str]]:
    query = query or ''
    history = history or []
    history.append([query, None])
    messages = history_to_messages(history, system)
    gen = Generation.call(model='qwen2-72b-instruct', messages=messages, result_format='message', stream=True)
    for response in gen:
        if response.status_code == HTTPStatus.OK:
            role = response.output.choices[0].message.role
            content = response.output.choices[0].message.content
            system, history = messages_to_history(messages + [{'role': role, 'content': content}])
            yield '', history, system
        else:
            raise ValueError('Request id: %s, Status code: %s, error code: %s, error message: %s' %
                             (response.request_id, response.status_code, response.code, response.message))


locale_mapping = {
    'modify_system': {
        'en': '🛠️ Set system and clear history',
        'zh': '🛠️ 设置system并清空历史'
    },
    'clear_history': {
        'en': '🧹 Clear history',
        'zh': '🧹 清空历史'
    },
    'submit': {
        'en': '🚀 Send',
        'zh': '🚀 发送'
    },
}


def build_llm_ui(studio_title: str, *, lang: str = 'zh', default_system: Optional[str] = None):
    with gr.Blocks() as demo:
        gr.Markdown(f'<center><font size=8>{studio_title}</center>')
        with gr.Row():
            with gr.Column(scale=3):
                system_input = gr.Textbox(value=default_system, lines=1, label='System')
            with gr.Column(scale=1):
                modify_system = gr.Button(locale_mapping['modify_system'][lang], scale=2)
        chatbot = gr.Chatbot(label='Chatbot')
        textbox = gr.Textbox(lines=1, label='Input')

        with gr.Row():
            clear_history = gr.Button(locale_mapping['clear_history'][lang])
            submit = gr.Button(locale_mapping['submit'][lang])

        system_state = gr.State(value=default_system)
        # Event inputs must be gradio components, so the default system prompt is
        # wrapped in a State instead of being passed as a plain value.
        default_system_state = gr.State(value=default_system)
        textbox.submit(model_chat, inputs=[textbox, chatbot, system_state], outputs=[textbox, chatbot, system_input])

        submit.click(
            model_chat,
            inputs=[textbox, chatbot, system_state],
            outputs=[textbox, chatbot, system_input],
            concurrency_limit=5)
        clear_history.click(fn=clear_session, inputs=[], outputs=[textbox, chatbot])
        modify_system.click(
            fn=modify_system_session,
            inputs=[system_input, default_system_state],
            outputs=[system_state, system_input, chatbot])
    return demo
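
Standalone usage of the UI module (assumes a DASHSCOPE_API_KEY in the environment, since model_chat streams from DashScope rather than from the server deployed by SwiftApp):

demo = build_llm_ui('Qwen2-72B Chat', lang='en', default_system='You are a helpful assistant.')
demo.queue().launch(server_name='0.0.0.0', server_port=7860)
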
2 changes: 1 addition & 1 deletion swift/llm/argument/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+from .app_args import AppArguments
from .base_args import BaseArguments
from .deploy_args import DeployArguments
from .eval_args import EvalArguments
@@ -8,4 +9,3 @@
from .train_args import TrainArguments
from .tuner_args import TunerArguments
from .webui_args import WebUIArguments
-from .app_args import AppArguments
1 change: 0 additions & 1 deletion swift/llm/argument/app_args.py
@@ -13,7 +13,6 @@ class AppArguments(DeployArguments):
api_url: Optional[str] = None
studio_title: Optional[str] = None


def _init_torch_dtype(self) -> None:
if self.api_url:
return
6 changes: 5 additions & 1 deletion swift/llm/argument/train_args.py
@@ -32,9 +32,11 @@ class Seq2SeqTrainingOverrideArguments(Seq2SeqTrainingArguments):
lr_scheduler_kwargs: Optional[Union[dict, str]] = None
gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
report_to: List[str] = field(default_factory=lambda: ['tensorboard'])
-    eval_strategy: Optional[str] = None  # steps, epoch

remove_unused_columns: bool = False
logging_first_step: bool = True
+    eval_strategy: Optional[str] = None  # steps, epoch
+    average_tokens_across_devices: Optional[bool] = None

def _init_output_dir(self):
if self.output_dir is not None:
@@ -54,6 +56,8 @@ def _init_eval_strategy(self):

def __post_init__(self):
self._init_output_dir()
+        if self.average_tokens_across_devices is None:
+            self.average_tokens_across_devices = self.world_size > 1
if self.metric_for_best_model is None:
self.metric_for_best_model = 'rouge-l' if self.predict_with_generate else 'loss'
if self.greater_is_better is None:
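
Restating the auto-default above as standalone logic (illustrative, not swift code):

def resolve_average_tokens(value, world_size: int) -> bool:
    # None means "auto": enable only for distributed runs.
    return world_size > 1 if value is None else value

assert resolve_average_tokens(None, 8) is True    # e.g. torchrun --nproc_per_node 8
assert resolve_average_tokens(None, 1) is False   # single-process training
assert resolve_average_tokens(False, 8) is False  # an explicit setting always wins
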
15 changes: 7 additions & 8 deletions swift/llm/infer/infer_engine/infer_client.py
@@ -3,11 +3,10 @@
from copy import deepcopy
from dataclasses import asdict
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from urllib.parse import urljoin

import aiohttp
import json

-from urllib.parse import urljoin
from dacite import from_dict
from requests.exceptions import HTTPError

@@ -123,12 +122,12 @@ def _parse_stream_data(data: bytes) -> Optional[str]:
return data[5:].strip()

async def infer_async(
-        self,
-        infer_request: InferRequest,
-        request_config: Optional[RequestConfig] = None,
-        *,
-        model: Optional[str] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
+            self,
+            infer_request: InferRequest,
+            request_config: Optional[RequestConfig] = None,
+            *,
+            model: Optional[str] = None,
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
request_config = deepcopy(request_config or RequestConfig())
if model is None:
if len(self.models) == 1: