-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[V1] Improve TP>1 Error Handling + Stack Trace #11721
Merged
robertgshaw2-neuralmagic
merged 44 commits into
vllm-project:main
from
neuralmagic:tp-shutdown
Jan 3, 2025
+40
−21
Merged
Changes from 39 commits
Commits
Show all changes
44 commits
Select commit
Hold shift + click to select a range
2d857cd
remove shutdown from LLMEngine
robertgshaw2-neuralmagic 9e70c5f
format
robertgshaw2-neuralmagic f34875c
no need for shutdown in asyncllm
robertgshaw2-neuralmagic 7a777d9
remove from asyncllm
robertgshaw2-neuralmagic dfc9dee
stash
robertgshaw2-neuralmagic c72b45a
update
robertgshaw2-neuralmagic 4e2dc00
fix
robertgshaw2-neuralmagic 0b4b6af
added back explicit del
robertgshaw2-neuralmagic 4c445af
stash
robertgshaw2-neuralmagic 567b424
working
robertgshaw2-neuralmagic 7d04b98
fix failing test
robertgshaw2-neuralmagic 62e1022
remove explicit shutdown calls
robertgshaw2-neuralmagic 0b0ca08
updated
robertgshaw2-neuralmagic 729938a
pdated
robertgshaw2-neuralmagic 0259241
update
robertgshaw2-neuralmagic 58e4b36
working
robertgshaw2-neuralmagic cacf6b0
updated
robertgshaw2-neuralmagic ccc747d
fixup
robertgshaw2-neuralmagic ddc2a97
fixup
robertgshaw2-neuralmagic af0d529
reduce cruft
robertgshaw2-neuralmagic 17e152b
updated
robertgshaw2-neuralmagic 37859d7
finish
robertgshaw2-neuralmagic c29f329
updated
robertgshaw2-neuralmagic 1c4b92a
updated
robertgshaw2-neuralmagic eb9b00b
stash
robertgshaw2-neuralmagic 1da99a8
updated
robertgshaw2-neuralmagic ca7b92d
Merge branch 'main' into tp-shutdown
robertgshaw2-neuralmagic 2743166
updated
robertgshaw2-neuralmagic 8e257c1
stash
robertgshaw2-neuralmagic b7c50dc
revert spurious change
robertgshaw2-neuralmagic dcfd3b8
updated
robertgshaw2-neuralmagic 6e0e0d4
stash
robertgshaw2-neuralmagic 55a6195
updated
robertgshaw2-neuralmagic aa6954f
updated
robertgshaw2-neuralmagic 1d15ae0
remove cruft
robertgshaw2-neuralmagic 0347baa
Update vllm/v1/executor/multiproc_executor.py
robertgshaw2-neuralmagic 20b8fa2
stash
robertgshaw2-neuralmagic 32840f2
Merge branch 'tp-shutdown' of https://github.com/neuralmagic/vllm int…
robertgshaw2-neuralmagic 884879a
switch to SIGUSR1
robertgshaw2-neuralmagic bb86a03
updated
robertgshaw2-neuralmagic 405bcc1
Update vllm/v1/engine/core_client.py
robertgshaw2-neuralmagic 25e0fea
update message
robertgshaw2-neuralmagic efd6270
updated
robertgshaw2-neuralmagic a5a306e
fixed!
robertgshaw2-neuralmagic File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
from multiprocessing.process import BaseProcess | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
import psutil | ||
import zmq | ||
|
||
from vllm.config import VllmConfig | ||
|
@@ -38,6 +39,19 @@ def __init__(self, vllm_config: VllmConfig) -> None: | |
# and ensure workers will be terminated. | ||
self._finalizer = weakref.finalize(self, self.shutdown) | ||
|
||
# The child processes will send SIGUSR1 when unrecoverable | ||
# errors happen. | ||
def sigusr1_handler(signum, frame): | ||
logger.fatal( | ||
"MulitprocExecutor got SIGUSR1 from worker processes, shutting " | ||
"down. See stack trace above for root cause issue.") | ||
# Propagate error up to parent process. | ||
parent_process = psutil.Process().parent() | ||
parent_process.send_signal(signal.SIGUSR1) | ||
self.shutdown() | ||
|
||
signal.signal(signal.SIGUSR1, sigusr1_handler) | ||
|
||
self.vllm_config = vllm_config | ||
self.parallel_config = vllm_config.parallel_config | ||
|
||
|
@@ -335,8 +349,11 @@ def signal_handler(signum, frame): | |
except SystemExit: | ||
logger.debug("Worker interrupted.") | ||
|
||
except BaseException as e: | ||
logger.exception(e) | ||
except Exception: | ||
# worker_busy_loop sends exceptions exceptons to Executor | ||
robertgshaw2-neuralmagic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# for shutdown, but if there is an error in startup or an | ||
# error with IPC itself, we need to alert the parent. | ||
psutil.Process().parent().send_signal(signal.SIGUSR1) | ||
raise | ||
|
||
finally: | ||
|
@@ -377,9 +394,10 @@ def worker_busy_loop(self): | |
|
||
try: | ||
output = getattr(self.worker, method)(*args, **kwargs) | ||
except BaseException as e: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
except Exception as e: | ||
self.worker_response_mq.enqueue( | ||
(WorkerProc.ResponseStatus.FAILURE, e)) | ||
logger.exception("WorkerProc hit an exception: %s", exc_info=e) | ||
continue | ||
|
||
self.worker_response_mq.enqueue( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,8 @@ def __init__( | |
distributed_init_method: str, | ||
): | ||
|
||
self.i = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NOTE FOR REVIEWER: this is just a simple POC to show an example. Will remove this before landing. |
||
|
||
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) | ||
self.vllm_config = vllm_config | ||
self.model_config = vllm_config.model_config | ||
|
@@ -201,6 +203,10 @@ def execute_model( | |
self, | ||
scheduler_output: "SchedulerOutput", | ||
) -> ModelRunnerOutput: | ||
if self.rank == 0 and self.i == 10: | ||
raise ValueError("ERROR FROM HERE :)") | ||
self.i += 1 | ||
robertgshaw2-neuralmagic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
output = self.model_runner.execute_model(scheduler_output) | ||
return output if self.rank == 0 else None | ||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NOTE: moved to
CoreClient
so that it can be shared acrossAsyncLLM
andLLMEngine