Skip to content

Commit

Permalink
fix device check (#1453)
Browse files Browse the repository at this point in the history
Signed-off-by: jiqing-feng <[email protected]>
  • Loading branch information
jiqing-feng authored Dec 17, 2024
1 parent 9948333 commit 7e6f865
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def set_ipex_linear(self, x: torch.Tensor):
and not self.training
and x.requires_grad == False
):
enable_ipex_fusion(self)
enable_ipex_fusion(self, x)

def forward(self, x: torch.Tensor):
# Check if ipex fusion can be used
Expand Down
8 changes: 4 additions & 4 deletions bitsandbytes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,15 @@ def unpack_tensor_to_dict(tensor_data):
return unpacked_dict


def enable_ipex_fusion(linear):
def enable_ipex_fusion(linear, x):
from bitsandbytes.backends.cpu_xpu_common import (
_ipex_cpu_version_prereq,
_ipex_xpu_version_prereq,
ipex_cpu_only,
ipex_cpu,
ipex_xpu,
)

if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5):
if x.device.type == "cpu" and ipex_cpu and _ipex_cpu_version_prereq(2, 5):
quant_state = linear.weight.quant_state
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
Expand All @@ -221,7 +221,7 @@ def enable_ipex_fusion(linear):
quant_state.blocksize,
2,
)
elif ipex_xpu and _ipex_xpu_version_prereq(2, 5):
elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5):
quant_state = linear.weight.quant_state
new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def get_latest_semver_tag():
tags = subprocess.check_output(["git", "tag"], text=True).splitlines()
semver_tags = [tag for tag in tags if tag.count(".") == 2 and all(part.isdigit() for part in tag.split("."))]
if not semver_tags:
print("No valid semantic version tags found, use 0.0.1 defaultly")
semver_tags = ["0.0.1"]
print("No valid semantic version tags found, use 1.0.0 defaultly")
semver_tags = ["1.0.0"]
return sorted(semver_tags, key=lambda s: list(map(int, s.split("."))))[-1]


Expand Down

0 comments on commit 7e6f865

Please sign in to comment.