Model with stack does not work with int8 target type #359
Well, we need to apply the same logic to
Seems to also happen if I turn it into a CatModel:

class CatModel(nn.Module):
    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [N, H, W, C]
        """
        return torch.cat([-x.unsqueeze(-1), x.unsqueeze(-1)], dim=-1)
@spacycoder

class CatModel(nn.Module):
    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [N, H, W, C]
        """
        z = x.unsqueeze(-1)
        return torch.cat([-z, z], dim=-1)
That also fails
Or this?

class CatModel(nn.Module):
    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [N, H, W, C]
        """
        # Note: the reshape needs a tuple; x.shape[:-1] + [-1, 2] would raise a TypeError.
        return torch.cat([-x, x], dim=-1).view(x.shape[:-1] + (-1, 2))
Nope, doesn't work either
Okay, will look into it tomorrow.
@spacycoder It seems that the problem is on
@spacycoder Things should work with #360
This also fails with the same concatenation error:

import torch.nn as nn
import torch

from tinynn.graph.quantization.quantizer import PostQuantizer
from tinynn.converter import TFLiteConverter


class EncoderLayer(nn.Module):
    def __init__(self, d_model: int = 256):
        super().__init__()
        self.mlp0 = nn.Linear(d_model, d_model, bias=False)
        self.mlp1 = nn.Linear(d_model * 2, d_model, bias=False)

    def forward(self, x: torch.Tensor):
        x = x.permute(0, 2, 3, 1)
        m = self.mlp0(x)
        m = torch.cat([x, m], dim=-1)
        m = self.mlp1(m)
        return x + m


class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = EncoderLayer(256)

    def forward(self, x, y):
        x = self.encoder(x)
        y = self.encoder(y)
        return x, y


def _main():
    dummy_input0 = torch.rand(1, 256, 60, 60).float()
    dummy_input1 = torch.rand(1, 256, 60, 60).float()

    model = Dummy()

    ptq_config = {
        "backend": "qnnpack",
        "per_tensor": True,
        "disable_requantization_for_cat": True,
    }

    quantizer = PostQuantizer(
        model, (dummy_input0, dummy_input1), work_dir="cat_model", config=ptq_config
    )
    ptq_model = quantizer.quantize()
    ptq_model(dummy_input0, dummy_input1)

    with torch.no_grad():
        ptq_model.eval()
        ptq_model.cpu()
        ptq_model = quantizer.convert(ptq_model)

        torch.backends.quantized.engine = quantizer.backend
        converter = TFLiteConverter(
            ptq_model,
            (dummy_input0, dummy_input1),
            "cat_model.tflite",
            fuse_quant_dequant=True,
            quantize_target_type="int8",
        )
        converter.convert()


if __name__ == '__main__':
    _main()
FYI having two separate encoders works (but I need them to be the same):

class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder0 = EncoderLayer(256)
        self.encoder1 = EncoderLayer(256)

    def forward(self, x, y):
        x = self.encoder0(x)
        y = self.encoder1(y)
        return x, y
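As an aside (not something suggested in the thread): if the two encoders have to stay separate module instances in the graph, one way to keep them numerically identical at the float level is to copy one encoder's weights into the other before quantization. This is only a sketch, assuming the EncoderLayer class from the script above; note that independent calibration can still give the two copies different quantization parameters afterwards, which is likely why the batched workaround further down was preferred.

```python
# Hypothetical sketch (not from the thread): two distinct EncoderLayer
# instances that start out with identical weights.
import torch.nn as nn

class DummyTied(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder0 = EncoderLayer(256)  # EncoderLayer as defined above
        self.encoder1 = EncoderLayer(256)
        # Copy weights so both encoders compute the same function (float model only;
        # per-module calibration may still produce different quantization params).
        self.encoder1.load_state_dict(self.encoder0.state_dict())

    def forward(self, x, y):
        return self.encoder0(x), self.encoder1(y)
```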
Okay, I guess it is because we refuse to traverse into the same nodes in the computation graph again. We need to refine the constraints a little bit.
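To illustrate the kind of constraint being described (a simplified, hypothetical sketch, not the actual tinynn graph code): if a pass marks nodes as visited and skips them on a second encounter, the second path into a shared submodule is never processed, so constraints such as disable_requantization_for_cat only propagate along one of the two branches.

```python
# Hypothetical illustration of the traversal issue described above
# (simplified; not the actual tinynn implementation).
from dataclasses import dataclass, field

@dataclass(eq=False)
class Node:
    name: str
    next_nodes: list = field(default_factory=list)

def propagate(node: Node, visited: set):
    if node in visited:
        # A node reached again through a second path (e.g. the shared encoder
        # called with both inputs) is skipped, so constraints never reach
        # that path.
        return
    visited.add(node)
    for nxt in node.next_nodes:
        propagate(nxt, visited)
```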
This seems to be a decent workaround for the moment:

class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = EncoderLayer(256)

    def forward(self, x, y):
        x_cat = torch.cat([x, y], dim=0)
        x_cat = self.encoder(x_cat)
        x, y = torch.chunk(x_cat, 2, dim=0)
        return x, y
@spacycoder I'm glad it works and it looks cleaner.
Converting this dummy model with quantize_target_type="int8" and per_tensor=True throws an error in TFLite.

TFLite error:

Note that the model works fine if I remove the "negative x" and instead send the same tensor twice, and it also works with uint8.
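The dummy model and the converter error log from the original report are not shown above. Based on the issue title and the CatModel variants discussed in the thread, the reported setup was presumably along these lines; this is a hedged reconstruction, not the exact code from the report, and the input shape is an assumption.

```python
# Hedged reconstruction of the reported setup (not the original code):
# a model that stacks [-x, x], converted with quantize_target_type="int8"
# and per_tensor=True via PostQuantizer + TFLiteConverter, as in the thread.
import torch
import torch.nn as nn
from tinynn.graph.quantization.quantizer import PostQuantizer
from tinynn.converter import TFLiteConverter

class StackModel(nn.Module):
    def forward(self, x: torch.Tensor):
        # Reportedly fails to convert with int8; replacing -x with x
        # (same tensor twice) or targeting uint8 works.
        return torch.stack([-x, x], dim=-1)

dummy_input = torch.rand(1, 60, 60, 4)  # assumed shape
quantizer = PostQuantizer(StackModel(), dummy_input, work_dir="stack_model",
                          config={"backend": "qnnpack", "per_tensor": True})
ptq_model = quantizer.quantize()
ptq_model(dummy_input)  # calibration pass

with torch.no_grad():
    ptq_model.eval()
    ptq_model = quantizer.convert(ptq_model)
    torch.backends.quantized.engine = quantizer.backend
    TFLiteConverter(ptq_model, dummy_input, "stack_model.tflite",
                    fuse_quant_dequant=True, quantize_target_type="int8").convert()
```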