diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/llm/infer/infer_engine/infer_engine.py index adee066c9..bf089b720 100644 --- a/swift/llm/infer/infer_engine/infer_engine.py +++ b/swift/llm/infer/infer_engine/infer_engine.py @@ -244,12 +244,4 @@ def func(target, queue, args, kwargs): @staticmethod def safe_asyncio_run(coro): - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - if loop: - result = InferEngine.thread_run(asyncio.run, args=(coro, )) - else: - result = asyncio.run(coro) - return result + return InferEngine.thread_run(asyncio.run, args=(coro, )) diff --git a/swift/ui/app.py b/swift/ui/app.py index 92d87e5a4..3de542210 100644 --- a/swift/ui/app.py +++ b/swift/ui/app.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +from copy import copy from dataclasses import fields from functools import partial from typing import List, Union @@ -72,6 +73,7 @@ def run(self): for f in fields(self.args): if getattr(self.args, f.name): LLMInfer.default_dict[f.name] = getattr(self.args, f.name) + LLMInfer.is_gradio_app = True LLMInfer.is_multimodal = self.args.model_meta.is_multimodal LLMInfer.build_ui(LLMInfer) @@ -94,9 +96,19 @@ def run(self): if isinstance(value, list): value = ' '.join([v or '' for v in value]) LLMInfer.elements()[f.name].value = value - app.load(LLMInfer.deploy_model, list(LLMInfer.valid_elements().values()), - [LLMInfer.element('runtime_tab'), - LLMInfer.element('running_tasks')]) + + args = copy(self.args) + args.port = find_free_port() + + values = [] + for key in LLMInfer.valid_elements(): + if key in args.__dict__: + value = getattr(args, key) + else: + value = LLMInfer.element(key).value + values.append(value) + _, running_task = LLMInfer.deploy_model(*values) + LLMInfer.element('running_tasks').value = running_task['value'] else: app.load( partial(LLMTrain.update_input_model, arg_cls=RLHFArguments), diff --git a/swift/ui/base.py b/swift/ui/base.py index ed6c389ac..508b61f4b 100644 --- a/swift/ui/base.py +++ b/swift/ui/base.py @@ -4,10 +4,11 @@ import sys import time import typing +from collections import OrderedDict from dataclasses import fields from datetime import datetime from functools import wraps -from typing import Any, Dict, List, OrderedDict, Type +from typing import Any, Dict, List, Type import gradio as gr import json @@ -220,12 +221,12 @@ def elements(cls): @classmethod def valid_elements(cls): + valid_elements = OrderedDict() elements = cls.elements() - return { - key: value - for key, value in elements.items() - if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record' - } + for key, value in elements.items(): + if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record': + valid_elements[key] = value + return valid_elements @classmethod def element_keys(cls): diff --git a/swift/ui/llm_infer/llm_infer.py b/swift/ui/llm_infer/llm_infer.py index 6b6b609fe..c366ddd7e 100644 --- a/swift/ui/llm_infer/llm_infer.py +++ b/swift/ui/llm_infer/llm_infer.py @@ -1,6 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import atexit import os import re +import signal import sys import time from copy import deepcopy @@ -298,6 +300,8 @@ def deploy_model(cls, *args): cnt += 1 if cnt >= 60: logger.warning_once(f'Deploy costing too much time, please check log file: {log_file}') + if cls.is_gradio_app: + cls.register_clean_hook() logger.info('Deploy done.') cls.deployed = True running_task = Runtime.refresh_tasks(log_file) @@ -305,6 +309,28 @@ def deploy_model(cls, *args): cls.running_task = running_task['value'] return gr.update(open=True), running_task + @classmethod + def clean_deployment(cls): + if not cls.is_gradio_app: + return + + logger.info('Killing deployment') + _, args = Runtime.parse_info_from_cmdline(cls.running_task) + os.system(f'pkill -9 -f {args["log_file"]}') + logger.info('Done.') + + @classmethod + def register_clean_hook(cls): + atexit.register(LLMInfer.clean_deployment) + signal.signal(signal.SIGINT, LLMInfer.signal_handler) + if os.name != 'nt': + signal.signal(signal.SIGTERM, LLMInfer.signal_handler) + + @staticmethod + def signal_handler(*args, **kwargs): + LLMInfer.clean_deployment() + sys.exit(0) + @classmethod def clear_session(cls): return '', [], gr.update(value=None), gr.update(value=None), gr.update(value=None), []