From 5ed193873c27efa702fe544def85362478c39808 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Sat, 28 Dec 2024 20:20:11 +0800
Subject: [PATCH] update

---
 ...44\350\241\214\345\217\202\346\225\260.md"  |  1 +
 .../Instruction/Command-line-parameters.md     |  1 +
 swift/llm/__init__.py                          |  6 +-
 swift/llm/app/__init__.py                      |  1 +
 swift/llm/app/app.py                           |  9 ++-
 swift/llm/app/llm_ui.py                        | 81 +++++++++++++++++++
 swift/llm/argument/__init__.py                 |  2 +-
 swift/llm/argument/app_args.py                 |  1 -
 swift/llm/argument/train_args.py               |  6 +-
 swift/llm/infer/infer_engine/infer_client.py   | 15 ++--
 10 files changed, 109 insertions(+), 14 deletions(-)
 create mode 100644 swift/llm/app/__init__.py

diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
index db8dd4201..ad14059d5 100644
--- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
+++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
@@ -99,6 +99,7 @@
 - remove_unused_columns: 默认值False
 - logging_first_step: 是否记录第一个step的打印,默认值True
 - logging_steps: 日志打印间隔,默认值5
+- average_tokens_across_devices: 是否在设备之间对token数进行平均。如果设置为True,将使用all_reduce同步`num_tokens_in_batch`以进行精确的损失计算。默认为None,如果为分布式训练则设置为True,否则为False
 - metric_for_best_model: 默认为None. 即当`predict_with_generate`设置为False, 则为'loss', 否则设置为'rouge-l'
 - greater_is_better: 默认为None. 即当`metric_for_best_model`含'loss'时, 设置为False, 否则设置为True.
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index 28be6ae39..1f559d602 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -100,6 +100,7 @@ This parameter list inherits from transformers `Seq2SeqTrainingArguments`, with
 - remove_unused_columns: Default is False.
 - logging_first_step: Whether to log the first step print, default is True.
 - logging_steps: Interval for logging prints, default is 5.
+- average_tokens_across_devices: Whether to average the token count across devices. If set to True, it will use all_reduce to synchronize `num_tokens_in_batch` for accurate loss computation. The default is None; set to True for distributed training, otherwise set to False.
 - metric_for_best_model: Default is None. When `predict_with_generate` is set to False, it is 'loss'; otherwise, it is 'rouge-l'.
 - greater_is_better: Default is None. When `metric_for_best_model` contains 'loss', set to False; otherwise, set to True.
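
The `average_tokens_across_devices` option documented above normalizes the loss by a global rather than a per-device token count. A minimal sketch of the idea, assuming an initialized torch.distributed process group — the helper below is illustrative, not the Trainer's actual internals:

import torch
import torch.distributed as dist


def loss_per_token(loss_sum: torch.Tensor, num_tokens_in_batch: torch.Tensor) -> torch.Tensor:
    """Divide the summed loss by a token count synchronized across ranks."""
    if dist.is_available() and dist.is_initialized():
        # all_reduce so every rank sees the same global denominator; without
        # this, ranks holding batches with fewer non-padding tokens would
        # overweight their tokens in the averaged loss.
        dist.all_reduce(num_tokens_in_batch, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
    return loss_sum / num_tokens_in_batch
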
diff --git a/swift/llm/__init__.py b/swift/llm/__init__.py
index 8e8145b6a..86d90a463 100644
--- a/swift/llm/__init__.py
+++ b/swift/llm/__init__.py
@@ -9,9 +9,10 @@
                             InferClient, run_deploy, AdapterRequest, prepare_model_template)
     from .export import (export_main, merge_lora, quantize_model, export_to_ollama)
     from .eval import eval_main
+    from .app import app_main
     from .train import sft_main, pt_main, rlhf_main, get_multimodal_target_regex
     from .argument import (EvalArguments, InferArguments, TrainArguments, ExportArguments, DeployArguments,
-                           RLHFArguments, WebUIArguments, BaseArguments)
+                           RLHFArguments, WebUIArguments, BaseArguments, AppArguments)
     from .template import (TEMPLATE_MAPPING, Template, Word, get_template, TemplateType, register_template,
                            TemplateInputs, Messages, TemplateMeta, get_template_meta, InferRequest, load_image,
                            MaxLengthError, load_file)
@@ -36,11 +37,12 @@
            'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template'
        ],
        'export': ['export_main', 'merge_lora', 'quantize_model', 'export_to_ollama'],
+       'app': ['app_main'],
        'eval': ['eval_main'],
        'train': ['sft_main', 'pt_main', 'rlhf_main', 'get_multimodal_target_regex'],
        'argument': [
            'EvalArguments', 'InferArguments', 'TrainArguments', 'ExportArguments', 'WebUIArguments', 'DeployArguments',
-           'RLHFArguments', 'BaseArguments'
+           'RLHFArguments', 'BaseArguments', 'AppArguments'
        ],
        'template': [
            'TEMPLATE_MAPPING', 'Template', 'Word', 'get_template', 'TemplateType', 'register_template',
diff --git a/swift/llm/app/__init__.py b/swift/llm/app/__init__.py
new file mode 100644
index 000000000..5f71930ec
--- /dev/null
+++ b/swift/llm/app/__init__.py
@@ -0,0 +1 @@
+from .app import app_main, SwiftApp
diff --git a/swift/llm/app/app.py b/swift/llm/app/app.py
index d6c565f0d..bea1f70fd 100644
--- a/swift/llm/app/app.py
+++ b/swift/llm/app/app.py
@@ -7,6 +7,7 @@
 from ..argument import AppArguments
 from ..base import SwiftPipeline
 from ..infer import run_deploy
+from .llm_ui import build_llm_ui
 
 logger = get_logger()
 
@@ -17,4 +18,10 @@ class SwiftApp(SwiftPipeline):
 
     def run(self):
         args = self.args
-        deploy_context = nullcontext() if args.eval_url else run_deploy(self.args, return_url=True)
+        deploy_context = nullcontext() if args.api_url else run_deploy(self.args, return_url=True)
+        demo = build_llm_ui(args.studio_title)
+        demo.queue().launch()
+
+
+def app_main(args: Union[List[str], AppArguments, None] = None):
+    return SwiftApp(args).main()
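
With the wiring above in place, the new entry point can be driven from Python as well as from the CLI. A hypothetical invocation — the model id, URL, and title are placeholders; `api_url` and `studio_title` are the `AppArguments` fields this patch touches, and `model` is inherited from the base arguments:

from swift.llm import AppArguments, app_main

# Reuse an already-running OpenAI-compatible deployment; if api_url were
# omitted, SwiftApp.run() would spin one up itself via run_deploy().
app_main(AppArguments(
    model='Qwen/Qwen2-7B-Instruct',       # placeholder model id
    api_url='http://127.0.0.1:8000/v1',   # placeholder URL
    studio_title='Qwen2 Demo'))
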
diff --git a/swift/llm/app/llm_ui.py b/swift/llm/app/llm_ui.py
index e69de29bb..6e48e2c9a 100644
--- a/swift/llm/app/llm_ui.py
+++ b/swift/llm/app/llm_ui.py
@@ -0,0 +1,81 @@
+from http import HTTPStatus
+from typing import Dict, List, Optional, Tuple
+
+import gradio as gr
+from dashscope import Generation
+
+History = List[Tuple[str, str]]
+Messages = List[Dict[str, str]]
+
+
+def clear_session():
+    return '', []
+
+
+def modify_system_session(system: str, default_system: str):
+    system = system or default_system
+    return system, system, []
+
+
+def model_chat(query: str, history: History, system: str) -> Tuple[str, History, str]:
+    query = query or ''
+    history = history or []
+    history.append([query, None])
+    messages = history_to_messages(history, system)
+    gen = Generation.call(model='qwen2-72b-instruct', messages=messages, result_format='message', stream=True)
+    for response in gen:
+        if response.status_code == HTTPStatus.OK:
+            role = response.output.choices[0].message.role
+            response = response.output.choices[0].message.content
+            system, history = messages_to_history(messages + [{'role': role, 'content': response}])
+            yield '', history, system
+        else:
+            raise ValueError('Request id: %s, Status code: %s, error code: %s, error message: %s' %
+                             (response.request_id, response.status_code, response.code, response.message))
+
+
+locale_mapping = {
+    'modify_system': {
+        'en': '🛠️ Set system and clear history',
+        'zh': '🛠️ 设置system并清空历史'
+    },
+    'clear_history': {
+        'en': '🧹 Clear history',
+        'zh': '🧹 清空历史'
+    },
+    'submit': {
+        'en': '🚀 Send',
+        'zh': '🚀 发送'
+    },
+}
+
+
+def build_llm_ui(studio_title: str, *, lang: str = 'zh', default_system: Optional[str] = None):
+    with gr.Blocks() as demo:
+        gr.Markdown(f'<center><font size=8>{studio_title}</font></center>')
+        with gr.Row():
+            with gr.Column(scale=3):
+                system_input = gr.Textbox(value=default_system, lines=1, label='System')
+            with gr.Column(scale=1):
+                modify_system = gr.Button(locale_mapping['modify_system'][lang], scale=2)
+        chatbot = gr.Chatbot(label='Chatbot')
+        textbox = gr.Textbox(lines=1, label='Input')
+
+        with gr.Row():
+            clear_history = gr.Button(locale_mapping['clear_history'][lang])
+            submit = gr.Button(locale_mapping['submit'][lang])
+
+        system_state = gr.State(value=default_system)
+        textbox.submit(model_chat, inputs=[textbox, chatbot, system_state], outputs=[textbox, chatbot, system_input])
+
+        submit.click(
+            model_chat,
+            inputs=[textbox, chatbot, system_state],
+            outputs=[textbox, chatbot, system_input],
+            concurrency_limit=5)
+        clear_history.click(fn=clear_session, inputs=[], outputs=[textbox, chatbot])
+        modify_system.click(
+            fn=lambda system: modify_system_session(system, default_system),
+            inputs=[system_input],
+            outputs=[system_state, system_input, chatbot])
+    return demo
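
`model_chat` above relies on `history_to_messages` and `messages_to_history`, which do not appear in the hunk. A sketch of what they would look like, inferred only from the call sites — these are assumptions, not the committed implementation:

from typing import Dict, List, Tuple

History = List[List[str]]  # [query, response] pairs, as the Chatbot uses them
Messages = List[Dict[str, str]]


def history_to_messages(history: History, system: str) -> Messages:
    # Flatten gradio's [query, response] pairs into an OpenAI-style message
    # list, with the system prompt first.
    messages: Messages = [{'role': 'system', 'content': system}]
    for query, response in history:
        messages.append({'role': 'user', 'content': query})
        if response is not None:
            messages.append({'role': 'assistant', 'content': response})
    return messages


def messages_to_history(messages: Messages) -> Tuple[str, History]:
    # Inverse mapping: recover the system prompt and re-pair the
    # user/assistant turns for the Chatbot component.
    assert messages[0]['role'] == 'system'
    system = messages[0]['content']
    history = [[q['content'], r['content']] for q, r in zip(messages[1::2], messages[2::2])]
    return system, history
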
diff --git a/swift/llm/argument/__init__.py b/swift/llm/argument/__init__.py
index d9f30c526..987cf828e 100644
--- a/swift/llm/argument/__init__.py
+++ b/swift/llm/argument/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+from .app_args import AppArguments
 from .base_args import BaseArguments
 from .deploy_args import DeployArguments
 from .eval_args import EvalArguments
@@ -8,4 +9,3 @@
 from .train_args import TrainArguments
 from .tuner_args import TunerArguments
 from .webui_args import WebUIArguments
-from .app_args import AppArguments
diff --git a/swift/llm/argument/app_args.py b/swift/llm/argument/app_args.py
index 2308461dd..11331c8d5 100644
--- a/swift/llm/argument/app_args.py
+++ b/swift/llm/argument/app_args.py
@@ -13,7 +13,6 @@ class AppArguments(DeployArguments):
     api_url: Optional[str] = None
     studio_title: Optional[str] = None
 
-
     def _init_torch_dtype(self) -> None:
         if self.api_url:
             return
diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py
index a2deb20b5..f99fbe062 100644
--- a/swift/llm/argument/train_args.py
+++ b/swift/llm/argument/train_args.py
@@ -32,9 +32,11 @@ class Seq2SeqTrainingOverrideArguments(Seq2SeqTrainingArguments):
     lr_scheduler_kwargs: Optional[Union[dict, str]] = None
     gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
     report_to: List[str] = field(default_factory=lambda: ['tensorboard'])
+    eval_strategy: Optional[str] = None  # steps, epoch
+
     remove_unused_columns: bool = False
     logging_first_step: bool = True
-    eval_strategy: Optional[str] = None  # steps, epoch
+    average_tokens_across_devices: Optional[bool] = None
 
     def _init_output_dir(self):
         if self.output_dir is not None:
@@ -54,6 +56,8 @@ def _init_eval_strategy(self):
 
     def __post_init__(self):
         self._init_output_dir()
+        if self.average_tokens_across_devices is None:
+            self.average_tokens_across_devices = self.world_size > 1
         if self.metric_for_best_model is None:
             self.metric_for_best_model = 'rouge-l' if self.predict_with_generate else 'loss'
         if self.greater_is_better is None:
diff --git a/swift/llm/infer/infer_engine/infer_client.py b/swift/llm/infer/infer_engine/infer_client.py
index 0244bc4c3..958afa347 100644
--- a/swift/llm/infer/infer_engine/infer_client.py
+++ b/swift/llm/infer/infer_engine/infer_client.py
@@ -3,11 +3,10 @@
 from copy import deepcopy
 from dataclasses import asdict
 from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from urllib.parse import urljoin
 
 import aiohttp
 import json
-
-from urllib.parse import urljoin
 from dacite import from_dict
 from requests.exceptions import HTTPError
 
@@ -123,12 +122,12 @@ def _parse_stream_data(data: bytes) -> Optional[str]:
             return data[5:].strip()
 
     async def infer_async(
-            self,
-            infer_request: InferRequest,
-            request_config: Optional[RequestConfig] = None,
-            *,
-            model: Optional[str] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
+        self,
+        infer_request: InferRequest,
+        request_config: Optional[RequestConfig] = None,
+        *,
+        model: Optional[str] = None,
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
         request_config = deepcopy(request_config or RequestConfig())
         if model is None:
             if len(self.models) == 1:
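
The `infer_async` hunk only reflows the continuation indentation; the public signature is unchanged. For reference, a minimal usage sketch against a local deployment — the host and port are illustrative defaults, not values fixed by this patch:

import asyncio

from swift.llm import InferClient, InferRequest, RequestConfig


async def main():
    # Connect to a server started by `swift deploy` / run_deploy.
    client = InferClient(host='127.0.0.1', port=8000)  # illustrative address
    resp = await client.infer_async(
        InferRequest(messages=[{'role': 'user', 'content': 'Hello!'}]),
        RequestConfig(max_tokens=64))
    print(resp.choices[0].message.content)


asyncio.run(main())
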