[dequantizer] restore {add, mul}_scalar (#247)
peterjc123 authored Sep 5, 2023
1 parent aa34daa commit 6008cb6
Showing 2 changed files with 55 additions and 0 deletions.
36 changes: 36 additions & 0 deletions tests/quantizer_test.py
@@ -1592,6 +1592,42 @@ def forward(self, x):

        check_dequantize_rewrite(model, inputs)

    def test_simple_q_add_scalar(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()

                self.q = torch.quantization.QuantStub()
                self.dq = torch.quantization.DeQuantStub()
                self.f = torch.nn.quantized.FloatFunctional()

            def forward(self, x):
                x = self.q(x)
                return self.f.add_scalar(x, 0.5)

        model = Model()
        inputs = torch.randn(1, 3, 224, 224)

        check_dequantize_rewrite(model, inputs)

    def test_simple_q_mul_scalar(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()

                self.q = torch.quantization.QuantStub()
                self.dq = torch.quantization.DeQuantStub()
                self.f = torch.nn.quantized.FloatFunctional()

            def forward(self, x):
                x = self.q(x)
                return self.f.mul_scalar(x, 0.5)

        model = Model()
        inputs = torch.randn(1, 3, 224, 224)

        check_dequantize_rewrite(model, inputs)

    def test_conv_transpose_bn_fusion(self):
        class Model(nn.Module):
            def __init__(self):
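For context, a minimal standalone sketch of the calls the new tests exercise. To my understanding, before quantization conversion runs, FloatFunctional.add_scalar and FloatFunctional.mul_scalar simply wrap torch.add/torch.mul on the float tensor, which is why the rewrite in quantizer.py below can map them back to the plain torch ops. The tensor shape matches the tests:

    import torch

    f = torch.nn.quantized.FloatFunctional()
    x = torch.randn(1, 3, 224, 224)

    y = f.add_scalar(x, 0.5)  # behaves like torch.add(x, 0.5) before convert()
    z = f.mul_scalar(x, 0.5)  # behaves like torch.mul(x, 0.5) before convert()

    assert torch.allclose(y, x + 0.5)
    assert torch.allclose(z, x * 0.5)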
19 changes: 19 additions & 0 deletions tinynn/graph/quantization/quantizer.py
@@ -3948,6 +3948,25 @@ def _is_add_relu_node(node: TraceNode, custom_data):
            next_out = torch.relu(n.next_tensors[0])
            graph.insert_after(n, next_func, [next_out])

        # Replace FloatFunctional.{add_scalar, mul_scalar} with torch.{add, mul}
        def _is_add_mul_scalar_node(node: TraceNode, custom_data):
            cur_module = node.module
            cur_class = type(cur_module)
            if cur_class == TraceFunction:
                return (
                    cur_module.kind in ('add_scalar', 'mul_scalar')
                    and len(node.prev_nodes) > 1
                    and node.prev_nodes[0].type() == nnq.FloatFunctional
                )

        add_mul_scalar_nodes = graph.filter_forward_nodes(_is_add_mul_scalar_node)
        for n in add_mul_scalar_nodes:
            n.module.kind = n.module.kind.split('_')[0]
            n.module.func_type = n.module.kind

            parts = n.module.full_name.split('.')[:-1] + [n.module.func_type]
            n.module.full_name = '.'.join(parts)

        # Replace FloatFunctional.{add, mul, cat} with torch.{add, mul, cat}
        def _is_add_mul_cat_node(node: TraceNode, custom_data):
            cur_module = node.module
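The pass itself is a string-level rename: add_scalar/mul_scalar are collapsed to add/mul in both kind and the last component of full_name, so that the _is_add_mul_cat_node pass that follows can treat the nodes like any other FloatFunctional add/mul. A standalone sketch of that rename, using a hypothetical full_name value ('f_0.add_scalar' is illustrative, not taken from the source):

    # Standalone sketch of the rename step; 'f_0.add_scalar' is hypothetical.
    kind = 'add_scalar'
    full_name = 'f_0.add_scalar'

    kind = kind.split('_')[0]                                      # 'add'
    func_type = kind
    full_name = '.'.join(full_name.split('.')[:-1] + [func_type])  # 'f_0.add'

    print(kind, full_name)  # add f_0.add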
