Skip to content

Commit

Permalink
pytorch/ao/torchao/experimental/ops/mps/test (#1442)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #1442

Reviewed By: avikchaudhuri, ydwu4

Differential Revision: D67388057
  • Loading branch information
gmagogsfm authored and facebook-github-bot committed Dec 19, 2024
1 parent aea2356 commit 316c096
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions torchao/experimental/ops/mps/test/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import copy
import itertools
import os
import sys
import unittest
from typing import Optional

import torch
import unittest

from parameterized import parameterized
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
from torchao.experimental.quant_api import _quantize
from torchao.experimental.quant_api import _quantize, UIntxWeightOnlyLinearQuantizer

libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
Expand Down Expand Up @@ -80,7 +79,7 @@ def test_export(self, nbit):
activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
exported = torch.export.export(quantized_model, (activations,))
exported = torch.export.export(quantized_model, (activations,), strict=True)

for node in exported.graph.nodes:
if node.op == "call_function":
Expand Down

0 comments on commit 316c096

Please sign in to comment.