Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Improve TP>1 Error Handling + Stack Trace #11721

Merged
merged 44 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
2d857cd
remove shutdown from LLMEngine
robertgshaw2-neuralmagic Dec 31, 2024
9e70c5f
format
robertgshaw2-neuralmagic Dec 31, 2024
f34875c
no need for shutdown in asyncllm
robertgshaw2-neuralmagic Dec 31, 2024
7a777d9
remove from asyncllm
robertgshaw2-neuralmagic Dec 31, 2024
dfc9dee
stash
robertgshaw2-neuralmagic Dec 31, 2024
c72b45a
update
robertgshaw2-neuralmagic Dec 31, 2024
4e2dc00
fix
robertgshaw2-neuralmagic Dec 31, 2024
0b4b6af
added back explicit del
robertgshaw2-neuralmagic Dec 31, 2024
4c445af
stash
robertgshaw2-neuralmagic Dec 31, 2024
567b424
working
robertgshaw2-neuralmagic Jan 1, 2025
7d04b98
fix failing test
robertgshaw2-neuralmagic Jan 3, 2025
62e1022
remove explicit shutdown calls
robertgshaw2-neuralmagic Jan 3, 2025
0b0ca08
updated
robertgshaw2-neuralmagic Jan 3, 2025
729938a
pdated
robertgshaw2-neuralmagic Jan 3, 2025
0259241
update
robertgshaw2-neuralmagic Jan 3, 2025
58e4b36
working
robertgshaw2-neuralmagic Jan 3, 2025
cacf6b0
updated
robertgshaw2-neuralmagic Jan 3, 2025
ccc747d
fixup
robertgshaw2-neuralmagic Jan 3, 2025
ddc2a97
fixup
robertgshaw2-neuralmagic Jan 3, 2025
af0d529
reduce cruft
robertgshaw2-neuralmagic Jan 3, 2025
17e152b
updated
robertgshaw2-neuralmagic Jan 3, 2025
37859d7
finish
robertgshaw2-neuralmagic Jan 3, 2025
c29f329
updated
robertgshaw2-neuralmagic Jan 3, 2025
1c4b92a
updated
robertgshaw2-neuralmagic Jan 3, 2025
eb9b00b
stash
robertgshaw2-neuralmagic Jan 3, 2025
1da99a8
updated
robertgshaw2-neuralmagic Jan 3, 2025
ca7b92d
Merge branch 'main' into tp-shutdown
robertgshaw2-neuralmagic Jan 3, 2025
2743166
updated
robertgshaw2-neuralmagic Jan 3, 2025
8e257c1
stash
robertgshaw2-neuralmagic Jan 3, 2025
b7c50dc
revert spurious change
robertgshaw2-neuralmagic Jan 3, 2025
dcfd3b8
updated
robertgshaw2-neuralmagic Jan 3, 2025
6e0e0d4
stash
robertgshaw2-neuralmagic Jan 3, 2025
55a6195
updated
robertgshaw2-neuralmagic Jan 3, 2025
aa6954f
updated
robertgshaw2-neuralmagic Jan 3, 2025
1d15ae0
remove cruft
robertgshaw2-neuralmagic Jan 3, 2025
0347baa
Update vllm/v1/executor/multiproc_executor.py
robertgshaw2-neuralmagic Jan 3, 2025
20b8fa2
stash
robertgshaw2-neuralmagic Jan 3, 2025
32840f2
Merge branch 'tp-shutdown' of https://github.com/neuralmagic/vllm int…
robertgshaw2-neuralmagic Jan 3, 2025
884879a
switch to SIGUSR1
robertgshaw2-neuralmagic Jan 3, 2025
bb86a03
updated
robertgshaw2-neuralmagic Jan 3, 2025
405bcc1
Update vllm/v1/engine/core_client.py
robertgshaw2-neuralmagic Jan 3, 2025
25e0fea
update message
robertgshaw2-neuralmagic Jan 3, 2025
efd6270
updated
robertgshaw2-neuralmagic Jan 3, 2025
a5a306e
fixed!
robertgshaw2-neuralmagic Jan 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import os
import signal
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union

from vllm.config import ModelConfig, VllmConfig
Expand Down Expand Up @@ -42,21 +41,6 @@ def __init__(
start_engine_loop: bool = True,
) -> None:

# The child processes will send SIGQUIT when unrecoverable
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: moved to CoreClient so that it can be shared across AsyncLLM and LLMEngine

# errors happen. We kill the process tree here so that the
# stack trace is very evident.
# TODO: rather than killing the main process, we should
# figure out how to raise an AsyncEngineDeadError and
# handle at the API server level so we can return a better
# error code to the clients calling VLLM.
def sigquit_handler(signum, frame):
logger.fatal(
"AsyncLLM got SIGQUIT from worker processes, shutting "
"down. See stack trace above for root cause issue.")
kill_process_tree(os.getpid())

signal.signal(signal.SIGQUIT, sigquit_handler)

assert start_engine_loop

self.log_requests = log_requests
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def signal_handler(signum, frame):
except Exception:
traceback = get_exception_traceback()
logger.error("EngineCore hit an exception: %s", traceback)
parent_process.send_signal(signal.SIGQUIT)
parent_process.send_signal(signal.SIGUSR1)

finally:
if engine_core is not None:
Expand Down
19 changes: 18 additions & 1 deletion vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import signal
import weakref
from abc import ABC, abstractmethod
from typing import List, Type
Expand All @@ -8,7 +10,8 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_open_zmq_ipc_path, make_zmq_socket
from vllm.utils import (get_open_zmq_ipc_path, make_zmq_socket,
kill_process_tree)
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestUnion)
Expand Down Expand Up @@ -134,6 +137,20 @@ def __init__(
executor_class: Type[Executor],
log_stats: bool = False,
):
# The child processes will send SIGUSR1 when unrecoverable
# errors happen. We kill the process tree here so that the
# stack trace is very evident.
# TODO(rob): rather than killing the main process, we should
# figure out how to raise an AsyncEngineDeadError and
# handle at the API server level so we can return a better
# error code to the clients calling VLLM.
def sigusr1_handler(signum, frame):
logger.fatal(
"Got SIGUSR1 from worker processes, shutting "
"down. See stack trace above for root cause issue.")
robertgshaw2-neuralmagic marked this conversation as resolved.
Show resolved Hide resolved
kill_process_tree(os.getpid())
signal.signal(signal.SIGUSR1, sigusr1_handler)

# Serialization setup.
self.encoder = PickleEncoder()
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
Expand Down
24 changes: 21 additions & 3 deletions vllm/v1/executor/multiproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from multiprocessing.process import BaseProcess
from typing import Any, Dict, List, Optional, Tuple

import psutil
import zmq

from vllm.config import VllmConfig
Expand Down Expand Up @@ -38,6 +39,19 @@ def __init__(self, vllm_config: VllmConfig) -> None:
# and ensure workers will be terminated.
self._finalizer = weakref.finalize(self, self.shutdown)

# The child processes will send SIGUSR1 when unrecoverable
# errors happen.
def sigusr1_handler(signum, frame):
logger.fatal(
                "MultiprocExecutor got SIGUSR1 from worker processes, shutting "
"down. See stack trace above for root cause issue.")
# Propagate error up to parent process.
parent_process = psutil.Process().parent()
parent_process.send_signal(signal.SIGUSR1)
self.shutdown()

signal.signal(signal.SIGUSR1, sigusr1_handler)

self.vllm_config = vllm_config
self.parallel_config = vllm_config.parallel_config

Expand Down Expand Up @@ -335,8 +349,11 @@ def signal_handler(signum, frame):
except SystemExit:
logger.debug("Worker interrupted.")

except BaseException as e:
logger.exception(e)
except Exception:
            # worker_busy_loop sends exceptions to the Executor
robertgshaw2-neuralmagic marked this conversation as resolved.
Show resolved Hide resolved
# for shutdown, but if there is an error in startup or an
# error with IPC itself, we need to alert the parent.
psutil.Process().parent().send_signal(signal.SIGUSR1)
raise

finally:
Expand Down Expand Up @@ -377,9 +394,10 @@ def worker_busy_loop(self):

try:
output = getattr(self.worker, method)(*args, **kwargs)
except BaseException as e:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: we should not catch BaseException since it is too broad, per professor Gemini

image

except Exception as e:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, e))
logger.exception("WorkerProc hit an exception: %s", exc_info=e)
continue

self.worker_response_mq.enqueue(
Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def __init__(
distributed_init_method: str,
):

self.i = 0
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE FOR REVIEWER: this is just a simple POC to show an example. Will remove this before landing.


# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
Expand Down Expand Up @@ -201,6 +203,10 @@ def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
if self.rank == 0 and self.i == 10:
raise ValueError("ERROR FROM HERE :)")
self.i += 1
robertgshaw2-neuralmagic marked this conversation as resolved.
Show resolved Hide resolved

output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None

Expand Down
Loading