Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix app ui #2780

Merged
merged 7 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions swift/llm/infer/infer_engine/infer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,4 @@ def func(target, queue, args, kwargs):

@staticmethod
def safe_asyncio_run(coro):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop:
result = InferEngine.thread_run(asyncio.run, args=(coro, ))
else:
result = asyncio.run(coro)
return result
return InferEngine.thread_run(asyncio.run, args=(coro, ))
18 changes: 15 additions & 3 deletions swift/ui/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import copy
from dataclasses import fields
from functools import partial
from typing import List, Union
Expand Down Expand Up @@ -72,6 +73,7 @@ def run(self):
for f in fields(self.args):
if getattr(self.args, f.name):
LLMInfer.default_dict[f.name] = getattr(self.args, f.name)

LLMInfer.is_gradio_app = True
LLMInfer.is_multimodal = self.args.model_meta.is_multimodal
LLMInfer.build_ui(LLMInfer)
Expand All @@ -94,9 +96,19 @@ def run(self):
if isinstance(value, list):
value = ' '.join([v or '' for v in value])
LLMInfer.elements()[f.name].value = value
app.load(LLMInfer.deploy_model, list(LLMInfer.valid_elements().values()),
[LLMInfer.element('runtime_tab'),
LLMInfer.element('running_tasks')])

args = copy(self.args)
args.port = find_free_port()

values = []
for key in LLMInfer.valid_elements():
if key in args.__dict__:
value = getattr(args, key)
else:
value = LLMInfer.element(key).value
values.append(value)
_, running_task = LLMInfer.deploy_model(*values)
LLMInfer.element('running_tasks').value = running_task['value']
else:
app.load(
partial(LLMTrain.update_input_model, arg_cls=RLHFArguments),
Expand Down
13 changes: 7 additions & 6 deletions swift/ui/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import sys
import time
import typing
from collections import OrderedDict
from dataclasses import fields
from datetime import datetime
from functools import wraps
from typing import Any, Dict, List, OrderedDict, Type
from typing import Any, Dict, List, Type

import gradio as gr
import json
Expand Down Expand Up @@ -220,12 +221,12 @@ def elements(cls):

@classmethod
def valid_elements(cls):
valid_elements = OrderedDict()
elements = cls.elements()
return {
key: value
for key, value in elements.items()
if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record'
}
for key, value in elements.items():
if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record':
valid_elements[key] = value
return valid_elements

@classmethod
def element_keys(cls):
Expand Down
26 changes: 26 additions & 0 deletions swift/ui/llm_infer/llm_infer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import atexit
import os
import re
import signal
import sys
import time
from copy import deepcopy
Expand Down Expand Up @@ -298,13 +300,37 @@ def deploy_model(cls, *args):
cnt += 1
if cnt >= 60:
logger.warning_once(f'Deploy costing too much time, please check log file: {log_file}')
if cls.is_gradio_app:
cls.register_clean_hook()
logger.info('Deploy done.')
cls.deployed = True
running_task = Runtime.refresh_tasks(log_file)
if cls.is_gradio_app:
cls.running_task = running_task['value']
return gr.update(open=True), running_task

@classmethod
def clean_deployment(cls):
if not cls.is_gradio_app:
return

logger.info('Killing deployment')
_, args = Runtime.parse_info_from_cmdline(cls.running_task)
os.system(f'pkill -9 -f {args["log_file"]}')
logger.info('Done.')

@classmethod
def register_clean_hook(cls):
atexit.register(LLMInfer.clean_deployment)
signal.signal(signal.SIGINT, LLMInfer.signal_handler)
if os.name != 'nt':
signal.signal(signal.SIGTERM, LLMInfer.signal_handler)

@staticmethod
def signal_handler(*args, **kwargs):
LLMInfer.clean_deployment()
sys.exit(0)

@classmethod
def clear_session(cls):
return '', [], gr.update(value=None), gr.update(value=None), gr.update(value=None), []
Expand Down
Loading