Skip to content

Commit

Permalink
Use runner_type instead of task in GritLM
Browse files Browse the repository at this point in the history
Signed-off-by: Pooya Davoodi <[email protected]>
  • Loading branch information
pooyadavoodi committed Dec 12, 2024
1 parent d4d5291 commit 7a1b7ab
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions tests/models/embedding/language/test_gritlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_find_array(monkeypatch):
from vllm.model_executor.models.gritlm import GritLMPooler

# Create an LLM object to get the model config.
llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
pooler = GritLMPooler(model_config=llm.llm_engine.model_config)

arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Expand All @@ -55,7 +55,7 @@ def server_embedding():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")

args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)]
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server

Expand Down Expand Up @@ -141,7 +141,7 @@ def test_gritlm_offline_embedding(monkeypatch):

queries, q_instruction, documents, d_instruction = get_test_data()

llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)

d_rep = run_llm_encode(
llm,
Expand Down
8 changes: 4 additions & 4 deletions vllm/model_executor/models/gritlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ def __init__(
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

self.task = vllm_config.model_config.task
self.runner_type = vllm_config.model_config.runner_type

self._pooler = GritLMPooler(vllm_config.model_config)

for layer in self.model.layers:
if self.task == "embedding" and hasattr(layer, "self_attn"):
if self.runner_type == "pooling" and hasattr(layer, "self_attn"):
assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
"GritLM embedding is only supported by XFormers backend, "
"which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
Expand All @@ -222,8 +222,8 @@ def forward(
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:

# Change attention to non-causal for embedding task.
if self.task == "embedding":
# Change attention to non-causal for pooling tasks.
if self.runner_type == "pooling":
assert attn_metadata.prefill_metadata.attn_bias is None
attn_metadata.prefill_metadata.attn_bias = [
BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
Expand Down

0 comments on commit 7a1b7ab

Please sign in to comment.