Skip to content

Commit

Permalink
Fix app ui (#2780)
Browse files Browse the repository at this point in the history
  • Loading branch information
tastelikefeet authored Dec 26, 2024
1 parent 5394585 commit 1e2fda1
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 18 deletions.
10 changes: 1 addition & 9 deletions swift/llm/infer/infer_engine/infer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,4 @@ def func(target, queue, args, kwargs):

@staticmethod
def safe_asyncio_run(coro):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop:
result = InferEngine.thread_run(asyncio.run, args=(coro, ))
else:
result = asyncio.run(coro)
return result
return InferEngine.thread_run(asyncio.run, args=(coro, ))
18 changes: 15 additions & 3 deletions swift/ui/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import copy
from dataclasses import fields
from functools import partial
from typing import List, Union
Expand Down Expand Up @@ -72,6 +73,7 @@ def run(self):
for f in fields(self.args):
if getattr(self.args, f.name):
LLMInfer.default_dict[f.name] = getattr(self.args, f.name)

LLMInfer.is_gradio_app = True
LLMInfer.is_multimodal = self.args.model_meta.is_multimodal
LLMInfer.build_ui(LLMInfer)
Expand All @@ -94,9 +96,19 @@ def run(self):
if isinstance(value, list):
value = ' '.join([v or '' for v in value])
LLMInfer.elements()[f.name].value = value
app.load(LLMInfer.deploy_model, list(LLMInfer.valid_elements().values()),
[LLMInfer.element('runtime_tab'),
LLMInfer.element('running_tasks')])

args = copy(self.args)
args.port = find_free_port()

values = []
for key in LLMInfer.valid_elements():
if key in args.__dict__:
value = getattr(args, key)
else:
value = LLMInfer.element(key).value
values.append(value)
_, running_task = LLMInfer.deploy_model(*values)
LLMInfer.element('running_tasks').value = running_task['value']
else:
app.load(
partial(LLMTrain.update_input_model, arg_cls=RLHFArguments),
Expand Down
13 changes: 7 additions & 6 deletions swift/ui/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import sys
import time
import typing
from collections import OrderedDict
from dataclasses import fields
from datetime import datetime
from functools import wraps
from typing import Any, Dict, List, OrderedDict, Type
from typing import Any, Dict, List, Type

import gradio as gr
import json
Expand Down Expand Up @@ -220,12 +221,12 @@ def elements(cls):

@classmethod
def valid_elements(cls):
valid_elements = OrderedDict()
elements = cls.elements()
return {
key: value
for key, value in elements.items()
if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record'
}
for key, value in elements.items():
if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record':
valid_elements[key] = value
return valid_elements

@classmethod
def element_keys(cls):
Expand Down
26 changes: 26 additions & 0 deletions swift/ui/llm_infer/llm_infer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import atexit
import os
import re
import signal
import sys
import time
from copy import deepcopy
Expand Down Expand Up @@ -298,13 +300,37 @@ def deploy_model(cls, *args):
cnt += 1
if cnt >= 60:
logger.warning_once(f'Deploy costing too much time, please check log file: {log_file}')
if cls.is_gradio_app:
cls.register_clean_hook()
logger.info('Deploy done.')
cls.deployed = True
running_task = Runtime.refresh_tasks(log_file)
if cls.is_gradio_app:
cls.running_task = running_task['value']
return gr.update(open=True), running_task

@classmethod
def clean_deployment(cls):
if not cls.is_gradio_app:
return

logger.info('Killing deployment')
_, args = Runtime.parse_info_from_cmdline(cls.running_task)
os.system(f'pkill -9 -f {args["log_file"]}')
logger.info('Done.')

@classmethod
def register_clean_hook(cls):
atexit.register(LLMInfer.clean_deployment)
signal.signal(signal.SIGINT, LLMInfer.signal_handler)
if os.name != 'nt':
signal.signal(signal.SIGTERM, LLMInfer.signal_handler)

@staticmethod
def signal_handler(*args, **kwargs):
LLMInfer.clean_deployment()
sys.exit(0)

@classmethod
def clear_session(cls):
return '', [], gr.update(value=None), gr.update(value=None), gr.update(value=None), []
Expand Down

0 comments on commit 1e2fda1

Please sign in to comment.