Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Improve TP>1 Error Handling + Stack Trace #11721

Merged
merged 44 commits into from
Jan 3, 2025

Conversation

robertgshaw2-neuralmagic
Copy link
Collaborator

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic commented Jan 3, 2025

SUMMARY:

  • handle startup error (VLLM currently hangs if there is an error initializing the model with TP>1)
  • handle runtime error in LLM when TP>1 (VLLM currently does not clean up all resources in this case)
  • improve WORKER runtime error stack trace for TP>1 (we currently do not see the root cause)

Resulting Stack Trace For "CUDA ERROR" in Worker (Runtime Error)

  • We can see that the exception in the Worker is logged in each process (the Worker (gpu_worker.py), the EngineCore (core.py) and the AsyncLLM, providing good visibility. This will help to understand the root cause error in cases like an illegal memory access.
INFO:     127.0.0.1:40178 - "GET /v1/models HTTP/1.1" 200 OK
INFO 01-03 18:37:00 logger.py:37] Received request cmpl-8304a46139974ffaa6b29bd39ab8dbf6-0: prompt: 'Hello my name is', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=100, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None), prompt_token_ids: [128000, 9906, 856, 836, 374], lora_request: None, prompt_adapter_request: None.
INFO:     127.0.0.1:40178 - "POST /v1/completions HTTP/1.1" 200 OK
INFO 01-03 18:37:01 async_llm.py:191] Added request cmpl-8304a46139974ffaa6b29bd39ab8dbf6-0.
INFO 01-03 18:37:02 core.py:247] RUNNING: 1 | WAITING: 0
(VllmWorker rank=0 pid=7290) ERROR 01-03 18:37:02 multiproc_executor.py:402] WorkerProc hit an exception: %s
(VllmWorker rank=0 pid=7290) ERROR 01-03 18:37:02 multiproc_executor.py:402] Traceback (most recent call last):
(VllmWorker rank=0 pid=7290) ERROR 01-03 18:37:02 multiproc_executor.py:402]   File "/home/rshaw/vllm/vllm/v1/executor/multiproc_executor.py", line 398, in worker_busy_loop
(VllmWorker rank=0 pid=7290) ERROR 01-03 18:37:02 multiproc_executor.py:402]     output = getattr(self.worker, method)(*args, **kwargs)
(VllmWorker rank=0 pid=7290) ERROR 01-03 18:37:02 multiproc_executor.py:402]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=7290) ERROR 01-03 18:37:02 multiproc_executor.py:402]   File "/home/rshaw/vllm/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=0 pid=7290) ERROR 01-03 18:37:02 multiproc_executor.py:402]     return func(*args, **kwargs)
(VllmWorker rank=0 pid=7290) ERROR 01-03 18:37:02 multiproc_executor.py:402]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=7290) ERROR 01-03 18:37:02 multiproc_executor.py:402]   File "/home/rshaw/vllm/vllm/v1/worker/gpu_worker.py", line 207, in execute_model
(VllmWorker rank=0 pid=7290) ERROR 01-03 18:37:02 multiproc_executor.py:402]     raise ValueError("ERROR FROM HERE :)")
(VllmWorker rank=0 pid=7290) ERROR 01-03 18:37:02 multiproc_executor.py:402] ValueError: ERROR FROM HERE :)
ERROR 01-03 18:37:02 core.py:200] EngineCore hit an exception: Traceback (most recent call last):
ERROR 01-03 18:37:02 core.py:200]   File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 193, in run_engine_core
ERROR 01-03 18:37:02 core.py:200]     engine_core.run_busy_loop()
ERROR 01-03 18:37:02 core.py:200]   File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 231, in run_busy_loop
ERROR 01-03 18:37:02 core.py:200]     outputs = self.step()
ERROR 01-03 18:37:02 core.py:200]               ^^^^^^^^^^^
ERROR 01-03 18:37:02 core.py:200]   File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 124, in step
ERROR 01-03 18:37:02 core.py:200]     output = self.model_executor.execute_model(scheduler_output)
ERROR 01-03 18:37:02 core.py:200]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-03 18:37:02 core.py:200]   File "/home/rshaw/vllm/vllm/v1/executor/multiproc_executor.py", line 166, in execute_model
ERROR 01-03 18:37:02 core.py:200]     model_output = self.collective_rpc("execute_model",
ERROR 01-03 18:37:02 core.py:200]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-03 18:37:02 core.py:200]   File "/home/rshaw/vllm/vllm/v1/executor/multiproc_executor.py", line 160, in collective_rpc
ERROR 01-03 18:37:02 core.py:200]     raise e
ERROR 01-03 18:37:02 core.py:200]   File "/home/rshaw/vllm/vllm/v1/executor/multiproc_executor.py", line 149, in collective_rpc
ERROR 01-03 18:37:02 core.py:200]     raise result
ERROR 01-03 18:37:02 core.py:200] ValueError: ERROR FROM HERE :)
ERROR 01-03 18:37:02 core.py:200] 
CRITICAL 01-03 18:37:02 async_llm.py:53] AsyncLLM got SIGQUIT from worker processes, shutting down. See stack trace above for root cause issue.
Killed

Resulting Stack Trace For "OOM" in Worker (Startup Error)

  • We can see that the root cause error is very clear
(venv) rshaw@beaker:~/vllm$ vllm serve $MODEL --tensor-parallel-size 2 --enforce-eager --port 8001 --trust-remote-code
INFO 01-03 18:47:17 api_server.py:764] vLLM API server version 0.6.4.post2.dev364+gea7bd68d1.d20241214
INFO 01-03 18:47:17 api_server.py:765] args: Namespace(subparser='serve', model_tag='deepseek-ai/DeepSeek-V3', config='', host=None, port=8001, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='deepseek-ai/DeepSeek-V3', task='auto', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=True, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=None, guided_decoding_backend='xgrammar', logits_processor_pattern=None, distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=2, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=0, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=None, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=True, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', generation_config=None, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, dispatch_function=<function serve at 0x7b6980249a80>)
INFO 01-03 18:47:18 __init__.py:179] Automatically detected platform cuda.
WARNING 01-03 18:47:18 arg_utils.py:1277] Setting max_num_batched_tokens to 2048 for OPENAI_API_SERVER usage context.
configuration_deepseek.py: 100%|████████████████████████████████████████████████| 10.6k/10.6k [00:00<00:00, 34.3MB/s]
A new version of the following files was downloaded from https://huggingface.co/deepseek-ai/DeepSeek-V3:
- configuration_deepseek.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
INFO 01-03 18:47:19 config.py:132] Replacing legacy 'type' key with 'rope_type'
INFO 01-03 18:47:26 config.py:517] This model supports multiple tasks: {'score', 'generate', 'classify', 'embed', 'reward'}. Defaulting to 'generate'.
INFO 01-03 18:47:27 config.py:1321] Defaulting to use mp for distributed inference
INFO 01-03 18:47:27 config.py:1469] Chunked prefill is enabled with max_num_batched_tokens=2048.
WARNING 01-03 18:47:27 cuda.py:98] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
WARNING 01-03 18:47:27 config.py:651] Async output processing is not supported on the current platform type cuda.
WARNING 01-03 18:47:27 fp8.py:50] Detected fp8 checkpoint. Please note that the format is experimental and subject to change.
tokenizer_config.json: 100%|████████████████████████████████████████████████████| 3.13k/3.13k [00:00<00:00, 33.0MB/s]
tokenizer.json: 100%|███████████████████████████████████████████████████████████| 7.85M/7.85M [00:00<00:00, 27.1MB/s]
INFO 01-03 18:47:29 config.py:132] Replacing legacy 'type' key with 'rope_type'
INFO 01-03 18:47:34 __init__.py:179] Automatically detected platform cuda.
INFO 01-03 18:47:35 core.py:48] Initializing an LLM engine (v0.6.4.post2.dev364+gea7bd68d1.d20241214) with config: model='deepseek-ai/DeepSeek-V3', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-V3', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=163840, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=deepseek-ai/DeepSeek-V3, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"candidate_compile_sizes":[],"compile_sizes":[],"capture_sizes":[],"max_capture_size":0}
WARNING 01-03 18:47:35 multiproc_worker_utils.py:292] Reducing Torch parallelism from 64 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 01-03 18:47:35 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
INFO 01-03 18:47:35 shm_broadcast.py:255] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[0, 1], buffer_handle=(2, 10485760, 10, 'psm_d7092a28'), local_subscribe_port=58351, remote_subscribe_port=None)
INFO 01-03 18:47:40 __init__.py:179] Automatically detected platform cuda.
(VllmWorker rank=0 pid=18959) INFO 01-03 18:47:41 shm_broadcast.py:255] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_4dda49de'), local_subscribe_port=33873, remote_subscribe_port=None)
INFO 01-03 18:47:47 __init__.py:179] Automatically detected platform cuda.
(VllmWorker rank=1 pid=19191) INFO 01-03 18:47:48 shm_broadcast.py:255] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_838aabdb'), local_subscribe_port=52735, remote_subscribe_port=None)
(VllmWorker rank=1 pid=19191) INFO 01-03 18:47:51 utils.py:948] Found nccl from library libnccl.so.2
(VllmWorker rank=1 pid=19191) INFO 01-03 18:47:51 pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=0 pid=18959) INFO 01-03 18:47:51 utils.py:948] Found nccl from library libnccl.so.2
(VllmWorker rank=0 pid=18959) INFO 01-03 18:47:51 pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=0 pid=18959) INFO 01-03 18:47:56 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /home/rshaw/.cache/vllm/gpu_p2p_access_cache_for_6,7.json
(VllmWorker rank=1 pid=19191) INFO 01-03 18:47:56 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /home/rshaw/.cache/vllm/gpu_p2p_access_cache_for_6,7.json
(VllmWorker rank=0 pid=18959) INFO 01-03 18:47:56 shm_broadcast.py:255] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_802c0622'), local_subscribe_port=50955, remote_subscribe_port=None)
(VllmWorker rank=0 pid=18959) INFO 01-03 18:47:57 gpu_model_runner.py:682] Starting to load model deepseek-ai/DeepSeek-V3...
(VllmWorker rank=1 pid=19191) INFO 01-03 18:47:57 gpu_model_runner.py:682] Starting to load model deepseek-ai/DeepSeek-V3...
(VllmWorker rank=0 pid=18959) Process SpawnProcess-1:1:
CRITICAL 01-03 18:47:58 multiproc_executor.py:45] MulitprocExecutor got SIGQUIT from worker processes, shutting down. See stack trace above for root cause issue.
(VllmWorker rank=0 pid=18959) Traceback (most recent call last):
CRITICAL 01-03 18:47:58 async_llm.py:53] AsyncLLM got SIGQUIT from worker processes, shutting down. See stack trace above for root cause issue.
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(VllmWorker rank=0 pid=18959)     self.run()
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/process.py", line 108, in run
(VllmWorker rank=0 pid=18959)     self._target(*self._args, **self._kwargs)
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/v1/executor/multiproc_executor.py", line 340, in worker_main
(VllmWorker rank=0 pid=18959)     worker = WorkerProc(*args, **kwargs)
(VllmWorker rank=0 pid=18959)              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/v1/executor/multiproc_executor.py", line 273, in __init__
(VllmWorker rank=0 pid=18959)     self.worker.load_model()
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/v1/worker/gpu_worker.py", line 113, in load_model
(VllmWorker rank=0 pid=18959)     self.model_runner.load_model()
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/v1/worker/gpu_model_runner.py", line 684, in load_model
(VllmWorker rank=0 pid=18959)     self.model = get_model(vllm_config=self.vllm_config)
(VllmWorker rank=0 pid=18959)                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/model_executor/model_loader/__init__.py", line 12, in get_model
(VllmWorker rank=0 pid=18959)     return loader.load_model(vllm_config=vllm_config)
(VllmWorker rank=0 pid=18959)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/model_executor/model_loader/loader.py", line 364, in load_model
(VllmWorker rank=0 pid=18959)     model = _initialize_model(vllm_config=vllm_config)
(VllmWorker rank=0 pid=18959)             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/model_executor/model_loader/loader.py", line 117, in _initialize_model
(VllmWorker rank=0 pid=18959)     return model_class(vllm_config=vllm_config, prefix=prefix)
(VllmWorker rank=0 pid=18959)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/model_executor/models/deepseek_v3.py", line 505, in __init__
(VllmWorker rank=0 pid=18959)     self.model = DeepseekV3Model(vllm_config=vllm_config,
(VllmWorker rank=0 pid=18959)                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/model_executor/models/deepseek_v3.py", line 440, in __init__
(VllmWorker rank=0 pid=18959)     self.start_layer, self.end_layer, self.layers = make_layers(
(VllmWorker rank=0 pid=18959)                                                     ^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/model_executor/models/utils.py", line 551, in make_layers
(VllmWorker rank=0 pid=18959)     maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
(VllmWorker rank=0 pid=18959)                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/model_executor/models/deepseek_v3.py", line 442, in <lambda>
(VllmWorker rank=0 pid=18959)     lambda prefix: DeepseekV3DecoderLayer(
(VllmWorker rank=0 pid=18959)                    ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/model_executor/models/deepseek_v3.py", line 369, in __init__
(VllmWorker rank=0 pid=18959)     self.mlp = DeepseekV3MoE(
(VllmWorker rank=0 pid=18959)                ^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/model_executor/models/deepseek_v3.py", line 125, in __init__
(VllmWorker rank=0 pid=18959)     self.experts = FusedMoE(
(VllmWorker rank=0 pid=18959)                    ^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 256, in __init__
(VllmWorker rank=0 pid=18959)     self.quant_method.create_weights(
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/vllm/model_executor/layers/quantization/fp8.py", line 408, in create_weights
(VllmWorker rank=0 pid=18959)     w13_weight = torch.nn.Parameter(torch.empty(num_experts,
(VllmWorker rank=0 pid=18959)                                     ^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959)   File "/home/rshaw/vllm/venv/lib/python3.12/site-packages/torch/utils/_device.py", line 106, in __torch_function__
(VllmWorker rank=0 pid=18959)     return func(*args, **kwargs)
(VllmWorker rank=0 pid=18959)            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=18959) torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.50 GiB. GPU 0 has a total capacity of 79.22 GiB of which 735.62 MiB is free. Including non-PyTorch memory, this process has 78.49 GiB memory in use. Of the allocated memory 77.07 GiB is allocated by PyTorch, and 55.93 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Copy link

github-actions bot commented Jan 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic changed the title Tp shutdown [V1] TP Error Handling Jan 3, 2025
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic changed the title [V1] TP Error Handling [V1] TP>1 Stack Trace Improvement Jan 3, 2025
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic changed the title [V1] TP>1 Stack Trace Improvement [V1] Improve TP>1 Stack Trace Jan 3, 2025
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic changed the title [V1] Improve TP>1 Stack Trace [V1] Improve TP>1 Error Handling + Stack Trace Jan 3, 2025
@@ -35,6 +35,8 @@ def __init__(
distributed_init_method: str,
):

self.i = 0
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE FOR REVIEWER: this is just a simple POC to show an example. Will remove this before landing.

while True:
method, args, kwargs = self.rpc_broadcast_mq.dequeue()

try:
output = getattr(self.worker, method)(*args, **kwargs)
except BaseException as e:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: we should not catch BaseException since it is too broad, per professor Gemini

image

@@ -42,21 +41,6 @@ def __init__(
start_engine_loop: bool = True,
) -> None:

# The child processes will send SIGQUIT when unrecoverable
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: moved to CoreClient so that it can be shared across AsyncLLM and LLMEngine

vllm/v1/executor/multiproc_executor.py Show resolved Hide resolved
Comment on lines 42 to 53
# The child processes will send SIGQUIT when unrecoverable
# errors happen.
def sigquit_handler(signum, frame):
logger.fatal(
"MulitprocExecutor got SIGQUIT from worker processes, shutting "
"down. See stack trace above for root cause issue.")
# Propagate error up to parent process.
parent_process = psutil.Process().parent()
parent_process.send_signal(signal.SIGQUIT)
self.shutdown()

signal.signal(signal.SIGQUIT, sigquit_handler)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW: why use SIGQUIT for these? Users might get weird error messages if they hit ctrl-\, so I was thinking it might be better to use SIGUSR1

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly because this was inspired by SGL and they use SIGQUIT. I will switch to SIGUSR1

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

vllm/v1/engine/core_client.py Outdated Show resolved Hide resolved
vllm/v1/worker/gpu_worker.py Outdated Show resolved Hide resolved
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic enabled auto-merge (squash) January 3, 2025 19:32
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 3, 2025
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit 1543914 into vllm-project:main Jan 3, 2025
64 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants