diff --git a/swift/ui/llm_train/llm_train.py b/swift/ui/llm_train/llm_train.py index b0d52c36f..d5cf5ec98 100644 --- a/swift/ui/llm_train/llm_train.py +++ b/swift/ui/llm_train/llm_train.py @@ -243,6 +243,14 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): if not isinstance(value, (Tab, Accordion)) ] + [cls.element('log')] + Runtime.all_plots, cancels=Runtime.log_event) + Runtime.element('kill_task').click( + Runtime.kill_task, + [Runtime.element('running_tasks')], + [Runtime.element('running_tasks')] + + [Runtime.element('log')] + Runtime.all_plots, + cancels=[Runtime.log_event], + ).then(Runtime.reset, [], [Runtime.element('logging_dir')] + + [Save.element('output_dir')]) @classmethod def update_runtime(cls): diff --git a/swift/ui/llm_train/runtime.py b/swift/ui/llm_train/runtime.py index a8d17869e..b638c6e5b 100644 --- a/swift/ui/llm_train/runtime.py +++ b/swift/ui/llm_train/runtime.py @@ -248,14 +248,6 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): [base_tab.element('running_tasks')], ) - base_tab.element('kill_task').click( - Runtime.kill_task, - [base_tab.element('running_tasks')], - [base_tab.element('running_tasks')] + [cls.element('log')] - + cls.all_plots, - cancels=[cls.log_event], - ) - @classmethod def update_log(cls): return [gr.update(visible=True)] * (len(Runtime.sft_plot) + 1) @@ -411,6 +403,10 @@ def kill_task(task): return [Runtime.refresh_tasks()] + [gr.update(value=None)] * ( len(Runtime.sft_plot) + 1) + @staticmethod + def reset(): + return None, 'output' + @staticmethod def task_changed(task, base_tab): if task: