-
-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add span metrics for model_forward, scheduler and sampler time #21
base: v0.5.3-main
Are you sure you want to change the base?
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
🚀 |
… sampling time + get elapsed event time after sampler
04c2ef5
to
921ef53
Compare
34e10fb
to
37a4bd8
Compare
37a4bd8
to
36f876c
Compare
if seq_group.metrics.scheduler_time is not None: | ||
seq_group.metrics.scheduler_time += scheduler_time | ||
else: | ||
seq_group.metrics.scheduler_time = scheduler_time |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it possible to count time being swapped out separately? I think counting scheduler time for a non-running request is a bit hard to interpret, because we'd normally want to compare it with model forward time and execute time to understand the overhead, and those are only collected for running requests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now, counted scheduler time only over the running ones.
For a future change, we can track a different metric of swapped_out_time for each request.
vllm/worker/model_runner.py
Outdated
@@ -1335,6 +1341,13 @@ def execute_model( | |||
logits=logits, | |||
sampling_metadata=model_input.sampling_metadata, | |||
) | |||
if self.observability_config.collect_model_forward_time: | |||
torch.cuda.synchronize() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it looks like there's a Event.synchronize()
which seems a better fit than this global synchronize() https://pytorch.org/docs/stable/generated/torch.cuda.Event.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Thanks for the tip!
vllm/worker/worker_base.py
Outdated
@@ -268,12 +270,15 @@ def execute_model( | |||
if not get_pp_group().is_first_rank: | |||
intermediate_tensors = IntermediateTensors( | |||
get_pp_group().recv_tensor_dict()) | |||
|
|||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
extraneous change
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
a7db9de
to
055cb66
Compare
055cb66
to
c836bca
Compare
See an example trace log here
s3://ml-dev-sfc-or-dev-misc1-k8s/vllm_bench/spanlogs/mkeralapura_llama-3-70b-0ffc9601_vllm-engine_0.span_trace_log to maheshkm/mkeralapura_llama-3-70b-0ffc9601_vllm-engine_0.span_trace_log
An entry there will look like
Span #0
Trace ID : e3e70682c2094cac629f6fbed82c07cd
Parent ID :
ID : 0a5d2f346baa9455
Name : llm_request
Kind : Server
Start time : 2024-08-02 18:18:39.590171648 +0000 UTC
End time : 2024-08-02 18:18:43.366565895 +0000 UTC
Status code : Unset
Status message :
Attributes:
-> gen_ai.response.model: Str(/data-fast/s3/ml-dev-sfc-or-dev-misc1-k8s/yak/hf_models/meta-llama/Meta-Llama-3-70B)
-> gen_ai.request.id: Str(fcdae8c235f84b618c24b656f3ad4024)
-> gen_ai.request.temperature: Double(0)
-> gen_ai.request.top_p: Double(1)
-> gen_ai.request.max_tokens: Int(256)
-> gen_ai.request.best_of: Int(1)
-> gen_ai.request.n: Int(1)
-> gen_ai.usage.num_sequences: Int(1)
-> gen_ai.usage.prompt_tokens: Int(2106)
-> gen_ai.usage.completion_tokens: Int(256)
-> gen_ai.latency.time_in_queue: Double(0.012117147445678711)
-> gen_ai.latency.time_to_first_token: Double(0.2345738410949707)
-> gen_ai.latency.e2e: Double(3.775825023651123)
-> gen_ai.latency.time_in_scheduler: Double(0.017550230026245117)
-> gen_ai.latency.time_in_model_forward: Double(3.151565277099609)
-> gen_ai.latency.time_in_model_execute: Double(3.6468167304992676)
{"kind": "exporter", "data_type": "traces", "name": "debug"}
Had to put the model_forward time behind a flag since I could not get it working without the synchronize.
The synchronize does not seem to add a delay though. For example the llama3 example benchmark run 9 seconds faster (27s vs 267). Overall, our expectation is that there should be no operations pending by the time the sampler is finished. Maybe if there is no token or in some other corner case there is something pending and elapsed_time() crashes with Cuda runtime error: device is busy.
PR Checklist (Click to Expand)
Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.
PR Title and Classification
Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:
[Bugfix]
for bug fixes.[CI/Build]
for build or continuous integration improvements.[Doc]
for documentation fixes and improvements.[Model]
for adding a new model or improving an existing model. Model name should appear in the title.[Frontend]
For changes on the vLLM frontend (e.g., OpenAI API server,LLM
class, etc.)[Kernel]
for changes affecting CUDA kernels or other compute kernels.[Core]
for changes in the core vLLM logic (e.g.,LLMEngine
,AsyncLLMEngine
,Scheduler
, etc.)[Hardware][Vendor]
for hardware-specific changes. Vendor name should appear in the prefix (e.g.,[Hardware][AMD]
).[Misc]
for PRs that do not fit the above categories. Please use this sparingly.Note: If the PR spans more than one category, please include all relevant prefixes.
Code Quality
The PR need to meet the following code quality standards:
format.sh
to format your code.docs/source/
if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.Notes for Large Changes
Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with
rfc-required
and might not go through the PR.What to Expect for the Reviews
The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:
action-required
label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.Thank You
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!