
Commit

Sync with [email protected]
dtrifiro committed Jul 23, 2024
2 parents b2dc37f + 38c4b7e commit 02ce364
Showing 6 changed files with 39 additions and 38 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -26,6 +26,7 @@ Register now [here](https://lu.ma/lp0gyjqr) and be part of the event!
---

*Latest News* 🔥
+ - [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) with IBM! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
4 changes: 2 additions & 2 deletions docs/source/models/supported_models.rst
@@ -94,8 +94,8 @@ Decoder-only Language Models
- :code:`ai21labs/Jamba-v0.1`, etc.
- ✅︎
* - :code:`LlamaForCausalLM`
-   - LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
-   - :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
+   - Llama 3.1, Llama 3, Llama 2, LLaMA, Yi
+   - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
- ✅︎
* - :code:`MiniCPMForCausalLM`
- MiniCPM
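The newly listed checkpoints can be exercised directly with vLLM's offline `LLM` API. A minimal sketch, assuming access to the gated `meta-llama/Meta-Llama-3.1-8B-Instruct` repository (the smallest member of the new family, chosen here only for illustration) and a single GPU:

```python
# Minimal smoke test for one of the newly listed Llama 3.1 checkpoints.
# Assumes a Hugging Face token with access to the gated meta-llama repos
# and a single GPU with enough memory for the 8B model.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    max_model_len=8192,  # cap the 128K context so the KV cache fits one GPU
)
params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Briefly explain rotary position embeddings."], params)
print(outputs[0].outputs[0].text)
```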
8 changes: 4 additions & 4 deletions docs/source/serving/distributed_serving.rst
@@ -66,8 +66,8 @@ Pick a node as the head node, and run the following command:
$ bash run_cluster.sh \
$ vllm/vllm-openai \
$ ip_of_head_node \
- $ /path/to/the/huggingface/home/in/this/node \
- $ --head
+ $ --head \
+ $ /path/to/the/huggingface/home/in/this/node
On the rest of the worker nodes, run the following command:

@@ -76,8 +76,8 @@ On the rest of the worker nodes, run the following command:
$ bash run_cluster.sh \
$ vllm/vllm-openai \
$ ip_of_head_node \
- $ /path/to/the/huggingface/home/in/this/node \
- $ --worker
+ $ --worker \
+ $ /path/to/the/huggingface/home/in/this/node
Then you get a ray cluster of containers. Note that you need to keep the shells running these commands alive to hold the cluster. Any shell disconnect will terminate the cluster.
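Before starting vLLM on top of these containers, it can be worth confirming that the Ray cluster really spans all nodes. A small sketch using Ray's public API, run from inside the head-node container; the expected node count of 2 is an assumption for a two-node setup:

```python
# Sanity-check the Ray cluster formed by run_cluster.sh before launching vLLM.
# Run inside the head-node container; EXPECTED_NODES is an assumption
# (one head node plus one worker node) -- adjust it for your setup.
import ray

EXPECTED_NODES = 2

ray.init(address="auto")  # attach to the existing cluster, do not start a new one
alive = [node for node in ray.nodes() if node["Alive"]]
gpus = sum(int(node["Resources"].get("GPU", 0)) for node in alive)
print(f"{len(alive)} nodes alive, {gpus} GPUs visible to Ray")
assert len(alive) >= EXPECTED_NODES, "some worker containers have not joined yet"
```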

55 changes: 27 additions & 28 deletions vllm/config.py
@@ -154,15 +154,6 @@ def __init__(
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

-        if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
-                and getattr(self.hf_config, "rope_scaling", None) is None):
-            # Note(simon): this is a special case for a model that doesn't
-            # supply rope_scaling. We should remove this once the model is
-            # updated.
-            self.hf_config.update({"rope_scaling": {
-                "type": "extended",
-            }})

        if (not self.disable_sliding_window
                and self.hf_text_config.model_type == "gemma2"
                and self.hf_text_config.sliding_window is not None):
@@ -814,7 +805,7 @@ def __init__(self,
        if enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
-                max_num_batched_tokens)
+                self.max_num_batched_tokens)

        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
@@ -1492,24 +1483,32 @@ def _get_and_verify_max_len(
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
-    # The correct one should be "longrope", kept "su" here
-    # to be backward compatible
-    if rope_scaling is not None and rope_scaling["type"] not in {
-            "su", "longrope", "extended"
-    }:
-        if disable_sliding_window:
-            # TODO(robertgshaw): Find a model that supports rope_scaling
-            # with sliding window to see if this case should be allowed.
-            raise NotImplementedError(
-                "Disabling sliding window is not supported for models "
-                "with rope_scaling. Please raise an issue so we can "
-                "investigate.")
-        assert "factor" in rope_scaling
-        scaling_factor = rope_scaling["factor"]
-        if rope_scaling["type"] == "yarn":
-            derived_max_model_len = rope_scaling[
-                "original_max_position_embeddings"]
-        derived_max_model_len *= scaling_factor
+    if rope_scaling is not None:
+        if "type" in rope_scaling:
+            rope_type = rope_scaling["type"]
+        elif "rope_type" in rope_scaling:
+            rope_type = rope_scaling["rope_type"]
+        else:
+            raise ValueError(
+                "rope_scaling must have a 'type' or 'rope_type' key.")
+
+        # The correct one should be "longrope", kept "su" here
+        # to be backward compatible
+        if rope_type not in ("su", "longrope", "llama3"):
+            if disable_sliding_window:
+                # TODO(robertgshaw): Find a model that supports rope_scaling
+                # with sliding window to see if this case should be allowed.
+                raise NotImplementedError(
+                    "Disabling sliding window is not supported for models "
+                    "with rope_scaling. Please raise an issue so we can "
+                    "investigate.")
+
+            assert "factor" in rope_scaling
+            scaling_factor = rope_scaling["factor"]
+            if rope_type == "yarn":
+                derived_max_model_len = rope_scaling[
+                    "original_max_position_embeddings"]
+            derived_max_model_len *= scaling_factor

    # If the user specified a max length, make sure it is smaller than the
    # derived length from the HF model config.
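To make the new branch above easier to follow, here is a standalone sketch (not the vLLM function itself) that applies the same `type`/`rope_type` resolution and YaRN-style length derivation to two hypothetical `rope_scaling` dicts; the sliding-window guard is omitted for brevity:

```python
# Standalone illustration of the rope_scaling handling above (not vLLM code).
# The example dicts are hypothetical; real ones come from a model's HF config.
def derive_max_len(derived_max_model_len, rope_scaling):
    if rope_scaling is None:
        return derived_max_model_len
    rope_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
    if rope_type is None:
        raise ValueError("rope_scaling must have a 'type' or 'rope_type' key.")
    # "su"/"longrope"/"llama3" scaling is applied inside the rotary embedding,
    # so the derived length is left untouched here.
    if rope_type in ("su", "longrope", "llama3"):
        return derived_max_model_len
    factor = rope_scaling["factor"]
    if rope_type == "yarn":
        derived_max_model_len = rope_scaling["original_max_position_embeddings"]
    return derived_max_model_len * factor


# Older configs carry "type"; Llama 3.1-style configs carry "rope_type".
print(derive_max_len(4096, {"type": "yarn", "factor": 4.0,
                            "original_max_position_embeddings": 8192}))  # 32768.0
print(derive_max_len(131072, {"rope_type": "llama3", "factor": 8.0}))    # 131072
```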
7 changes: 4 additions & 3 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -794,12 +794,13 @@ def get_rope(
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style, dtype)
    else:
-        scaling_type = rope_scaling["type"]
+        scaling_type = rope_scaling[
+            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
        # The correct one should be "longrope" but keep "su" here
        # for backward compatible
-        if scaling_type not in {"su", "longrope", "extended"}:
+        if scaling_type not in {"su", "longrope", "llama3"}:
            scaling_factor = rope_scaling["factor"]
-            if scaling_type == "extended":
+            if scaling_type == "llama3":
                rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
                                                     max_position, base,
                                                     is_neox_style, dtype)
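For context on the `llama3` branch that now selects `ExtendedRotaryEmbedding`: this style of scaling rewrites the rotary inverse frequencies rather than the position indices. A rough numerical sketch following Meta's published Llama 3.1 recipe; the constants and helper below are assumptions for illustration, not values taken from this diff:

```python
# Rough sketch of llama3-style rotary frequency scaling (not vLLM's
# ExtendedRotaryEmbedding). Constants follow the published Llama 3.1 recipe
# and are assumptions here, not values read from this diff.
import math

def scale_inv_freqs_llama3(inv_freqs,
                           factor=8.0,
                           low_freq_factor=1.0,
                           high_freq_factor=4.0,
                           original_max_position=8192):
    low_freq_wavelen = original_max_position / low_freq_factor
    high_freq_wavelen = original_max_position / high_freq_factor
    scaled = []
    for inv_freq in inv_freqs:
        wavelen = 2 * math.pi / inv_freq
        if wavelen < high_freq_wavelen:
            scaled.append(inv_freq)            # high-frequency band: unchanged
        elif wavelen > low_freq_wavelen:
            scaled.append(inv_freq / factor)   # low-frequency band: stretched
        else:
            # smooth interpolation between the two regimes
            smooth = (original_max_position / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor)
            scaled.append((1 - smooth) * inv_freq / factor + smooth * inv_freq)
    return scaled

# Example: inverse frequencies for a 128-dim head with the Llama 3 rope base.
base, dim = 500000.0, 128
inv_freqs = [base ** (-2 * i / dim) for i in range(dim // 2)]
print(scale_inv_freqs_llama3(inv_freqs)[:3])
```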
2 changes: 1 addition & 1 deletion vllm/version.py
@@ -9,4 +9,4 @@
stacklevel=2)
__commit__ = "COMMIT_HASH_PLACEHOLDER"

__version__ = "0.5.3"
__version__ = "0.5.3.post1"
