Commit 7dd18e0: nits
robertgshaw2-neuralmagic committed Nov 16, 2024
1 parent d2ae4a5 commit 7dd18e0
Showing 1 changed file with 1 addition and 10 deletions.
11 changes: 1 addition & 10 deletions vllm/worker/tpu_model_runner.py
@@ -581,10 +581,6 @@ def execute_model(
                                           model_input.num_samples,
                                           kv_caches,
                                           is_prompt=True)
-            print(f"{token_ids=}")
-            print(f"{position_ids=}")
-            print(f"{attn_metadata=}")
-            # breakpoint()
             next_token_ids.append(output_token_ids[0])
             start_idx = end_idx

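For orientation, a minimal runnable sketch of the prefill pattern this hunk sits in: each prompt segment is run once with is_prompt=True and contributes exactly one next token via output_token_ids[0]. The toy model, the flat token buffer, and the prompt boundaries are stand-ins; only next_token_ids, output_token_ids, start_idx, end_idx, and is_prompt appear in the diff itself.

    import torch

    def toy_model(token_ids, is_prompt):
        # Stand-in for self.model(...); the real call takes many more arguments.
        return token_ids[-1:] + 1

    flat_token_ids = torch.tensor([1, 2, 3, 4, 5, 6])
    prompt_end_indices = [2, 6]  # two prompts packed back to back (assumed layout)
    next_token_ids = []
    start_idx = 0
    for end_idx in prompt_end_indices:
        output_token_ids = toy_model(flat_token_ids[start_idx:end_idx],
                                     is_prompt=True)
        # Each prompt contributes exactly one next token during prefill.
        next_token_ids.append(output_token_ids[0])
        start_idx = end_idx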
@@ -629,7 +625,6 @@ def execute_model(
         input_lens = model_input.input_lens.to(self.device)
         for i in range(num_steps):
             slot_mapping = attn_metadata.slot_mapping
-
             output_token_ids = self.model(token_ids,
                                           position_ids,
                                           attn_metadata,
@@ -639,10 +634,6 @@ def execute_model(
                                           model_input.num_samples,
                                           kv_caches,
                                           is_prompt=False)
-            print(f"{token_ids=}")
-            print(f"{position_ids=}")
-            print(f"{attn_metadata=}")
-            # breakpoint()
             self.cached_step_outputs.append(output_token_ids)

             if i < num_steps - 1:
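The surrounding context is the multi-step decode loop: the model runs num_steps times with is_prompt=False and each step's tokens are appended to self.cached_step_outputs rather than returned immediately. A self-contained toy of that pattern, with the model itself and the token feedback between steps as assumptions:

    import torch

    def toy_model(token_ids, is_prompt):
        # Stand-in for self.model(...); just "samples" the next token.
        return token_ids + 1

    num_steps = 4
    token_ids = torch.tensor([[7]])
    cached_step_outputs = []
    for i in range(num_steps):
        output_token_ids = toy_model(token_ids, is_prompt=False)
        # Outputs are cached per step instead of being returned right away.
        cached_step_outputs.append(output_token_ids)
        if i < num_steps - 1:
            # Assumed: sampled tokens feed the next step's input.
            token_ids = output_token_ids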
@@ -774,7 +765,7 @@ def forward(
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

         # Argmax sampling.
-        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
         argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

         # Zero temperature means greedy decoding. Avoid division by zero.
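The -/+ pair above is identical once leading whitespace is stripped, so the commit's single addition is most likely a whitespace-only nit on that line. For reference, a runnable sketch of the sampling step itself with toy shapes; the torch.where guard illustrating the zero-temperature comment is an assumption about code outside the shown context:

    import torch

    logits = torch.randn(2, 32)    # toy [batch_size, vocab_size] logits
    num_samples = 4

    # The two lines from the hunk: one greedy token per sequence, repeated
    # across the num_samples sample slots.
    argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
    argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
    print(argmax_token_ids.shape)  # torch.Size([2, 4])

    # Assumed guard for "Zero temperature means greedy decoding. Avoid
    # division by zero.": replace t == 0 with 1.0 before scaling; greedy
    # rows take the argmax result, so their scaled logits go unused.
    t = torch.tensor([0.0, 0.8])
    safe_t = torch.where(t != 0, t, torch.ones_like(t))
    scaled_logits = logits / safe_t.unsqueeze(1)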