Commit

remove torch warnings
robertgshaw2-neuralmagic committed Jan 7, 2025
1 parent a1d7b4a commit 2b4ecfd
Showing 3 changed files with 7 additions and 8 deletions.
tests/tpu/test_quantization_accuracy.py (4 changes: 2 additions & 2 deletions)
@@ -15,7 +15,7 @@ class GSM8KAccuracyTestConfig:
 
     def get_model_args(self) -> str:
         return (f"pretrained={self.model_name},"
-                "max_model_len=4096,max_num_seqs=128")
+                "max_model_len=4096,max_num_seqs=128,tensor_parallel_size=4")
 
 
 # NOTE: Accuracy scores measured on GPUs.
@@ -28,7 +28,7 @@ def get_model_args(self) -> str:
     # a follow up, move this into the LM-EVAL section of the CI.
     # GSM8KAccuracyTestConfig(
     #     model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
-    #     excepted_value=0.66),  # bias in QKV layers
+    #     excepted_value=0.66),  # bias in QKV layers
 ]
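
For reference, get_model_args() builds the comma-separated model_args string that lm-eval parses into vLLM engine arguments, so the new tensor_parallel_size=4 flows straight into the engine. Below is a minimal sketch of how such a string is typically consumed, assuming lm-eval's simple_evaluate API and its "vllm" backend; the task list and batch size are illustrative, not taken from this diff.

# Sketch only: drive lm-eval's vLLM backend with a model_args string like the
# one produced by get_model_args() above (model name reused from the commented
# config in the diff).
import lm_eval

model_args = ("pretrained=neuralmagic/Qwen2-7B-Instruct-quantized.w8a8,"
              "max_model_len=4096,max_num_seqs=128,tensor_parallel_size=4")

results = lm_eval.simple_evaluate(
    model="vllm",            # lm-eval model type backed by vllm.LLM
    model_args=model_args,   # comma-separated kwargs forwarded to the engine
    tasks=["gsm8k"],         # illustrative; matches GSM8KAccuracyTestConfig
    batch_size="auto",
)
print(results["results"]["gsm8k"])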


@@ -94,7 +94,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 # cutlass_w8a8 requires azp to be folded into azp_adj
                 # in the per-tensor case
                 azp_adj = getattr(layer, self.i_zp_name) * azp_adj
-            setattr(layer, self.azp_adj_name,
+            setattr(layer, self.azp_adj_name,
                     torch.nn.Parameter(azp_adj, requires_grad=False))
         else:
             setattr(layer, self.azp_adj_name, None)
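
The two comment lines above carry the rationale for the fold: with a per-tensor (static) input zero point, the zero-point correction depends only on the weight column sums, so it can be folded into azp_adj once at load time instead of being rebuilt on every forward pass. A self-contained sketch of that arithmetic, with made-up names (W_q, x_q, azp) and plain int64 tensors rather than vLLM's int8 kernels:

# Illustrative only: x_q = x / s_x + azp, so
#   x_q @ W_q = (x / s_x) @ W_q + azp * W_q.sum(dim=0)
# and the correction term azp * W_q.sum(dim=0) is a per-layer constant.
import torch

torch.manual_seed(0)
W_q = torch.randint(-128, 128, (16, 32))        # "quantized" weight
azp = torch.tensor(7)                           # per-tensor input zero point
x_q = torch.randint(-128, 128, (4, 16)) + azp   # zero-point-shifted activations

azp_adj = W_q.sum(dim=0, keepdim=True)          # column sums, shape (1, 32)
folded = azp * azp_adj                          # computed once, like the setattr above

# Subtracting the folded term recovers the zero-point-free product.
assert torch.equal(x_q @ W_q - folded, (x_q - azp) @ W_q)
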
@@ -1,7 +1,7 @@
 from typing import Optional, Tuple
 
-from functorch.experimental.control_flow import cond  # noqa: F401
 import torch
+from functorch.experimental.control_flow import cond  # noqa: F401
 
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -66,13 +66,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             setattr(layer, self.i_zp_name, None)
             setattr(layer, self.azp_adj_name, None)
 
-    def _no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
+    def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
         return x
 
-    def _add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
+    def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
         return x + bias
 
-
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
@@ -90,4 +89,4 @@ def apply_weights(self,
 
         # Explicitly capture control flow to make dynamo happy.
         # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method  # noqa: E501
-        return cond(bias, self._add_bias, self._no_add_bias, [out, bias])
+        return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
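
The rewritten call follows the "cond branch class method" pattern from the exportdb link in the comment: the predicate is now the Python bool bias is None rather than the tensor bias, so dynamo can specialize on it at trace time, which is the likely source of the warnings this commit removes. A self-contained sketch of the same pattern with an assumed toy module (BiasedMatmul and its shapes are made up, not vLLM's layer):

# Illustrative only: cond with bound-method branches and a Python-bool predicate.
from typing import Optional

import torch
from functorch.experimental.control_flow import cond


class BiasedMatmul(torch.nn.Module):

    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]):
        super().__init__()
        self.weight = weight
        self.bias = bias

    def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
        return x

    def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
        return x + bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x @ self.weight
        # Python-bool predicate: dynamo picks one branch at trace time instead
        # of tracing a data-dependent tensor condition.
        return cond(self.bias is None, self.no_add_bias, self.add_bias,
                    [out, self.bias])


layer = BiasedMatmul(torch.randn(8, 4), bias=torch.randn(4))
compiled = torch.compile(layer, backend="eager")  # backend choice is illustrative
print(compiled(torch.randn(2, 8)).shape)          # torch.Size([2, 4])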
