return False
+def _is_quantized_tensor(data, prelude):
+ # If a quantized Torch module is saved and loaded back, dtype will be dropped
+ # Since dtypes from Torch tensors are not reliable in such cases, we use
+ # Relay's type inference result to decide if an input tensor is quantized
+ ty = _infer_type_with_prelude(data, prelude)
+ return ty.dtype == "uint8"
+
+
# operator implementation
def _elemwise(name):
def _impl(inputs, input_types):
return _impl
-def _relu():
+def _relu(prelude):
def _impl(inputs, input_types):
data = inputs[0]
- if input_types[0] == "quint8":
+ if _is_quantized_tensor(data, prelude):
assert len(inputs) == 3, "Input quant param not found in op inputs"
input_zero_point = _expr.const(inputs[2], dtype="int32")
return qnn_torch.quantized_relu(data, input_zero_point)
return _op.log(_op.tensor.sigmoid(data))
return _impl
-def _adaptive_avg_pool_2d():
+def _adaptive_avg_pool_2d(prelude):
def _impl(inputs, input_types):
data = inputs[0]
output_size = _infer_shape(inputs[1])
def func(x):
return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
- if input_types[0] == "quint8":
+ if _is_quantized_tensor(data, prelude):
return qnn_torch.apply_with_upcast(data, func)
return func(data)
return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.)) / beta
return _impl
-def _avg_pool2d():
+def _avg_pool2d(prelude):
def _impl(inputs, input_types):
data = inputs[0]
ceil_mode=ceil_mode,
count_include_pad=count_include_pad)
- if input_types[0] == "quint8":
+ if _is_quantized_tensor(data, prelude):
return qnn_torch.apply_with_upcast(data, func)
return func(data)
return _impl
-def _mean():
+def _mean(prelude):
def _impl(inputs, input_types):
data = inputs[0]
def func(x):
return _op.mean(x, axis, keepdims, exclude)
- if input_types[0] == "quint8":
+ if _is_quantized_tensor(data, prelude):
assert len(inputs) == 6, "Input quant param not found in op inputs"
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])
return _impl
-def _upsample(method):
+def _upsample(method, prelude):
def _impl(inputs, input_types):
if isinstance(inputs[1], _expr.Var):
out_size = _infer_shape(inputs[1])
def func(x):
return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
- if input_types[0] == "quint8":
+ if _is_quantized_tensor(data, prelude):
import torch
from packaging import version
"aten::take" : _take(),
"aten::where" : _where(),
"aten::topk" : _topk(),
- "aten::relu" : _relu(),
- "aten::relu_" : _relu(),
+ "aten::relu" : _relu(prelude),
+ "aten::relu_" : _relu(prelude),
"aten::prelu" : _prelu(),
"aten::leaky_relu" : _leaky_relu(),
"aten::elu" : _elu(),
"aten::gelu" : _gelu(),
"aten::selu" : _selu(),
"aten::log_sigmoid" : _log_sigmoid(),
- "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(),
+ "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(prelude),
"aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(),
"aten::max_pool2d" : _maxpool_2d(),
"aten::max_pool2d_with_indices" : _maxpool_2d_with_indices(),
"aten::log_softmax" : _log_softmax(),
"aten::sigmoid" : _sigmoid(),
"aten::softplus" : _softplus(),
- "aten::avg_pool2d" : _avg_pool2d(),
+ "aten::avg_pool2d" : _avg_pool2d(prelude),
"aten::avg_pool3d" : _avg_pool3d(),
"aten::dropout" : _dropout(),
"aten::dropout_" : _dropout(),
"aten::feature_dropout" : _dropout(),
"aten::alpha_dropout" : _dropout(),
- "aten::mean" : _mean(),
+ "aten::mean" : _mean(prelude),
"aten::chunk" : _chunk(prelude),
"aten::matmul" : _matmul(prelude),
"aten::expand" : _expand(),
"aten::isnan" : _unary("isnan"),
"aten::clamp" : _clamp(),
"aten::detach" : _identity(),
- "aten::upsample_bilinear2d" : _upsample("bilinear"),
- "aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
+ "aten::upsample_bilinear2d" : _upsample("bilinear", prelude),
+ "aten::upsample_nearest2d" : _upsample("nearest_neighbor", prelude),
"aten::upsample_trilinear3d" : _upsample3d("trilinear"),
"aten::upsample_nearest3d" : _upsample3d("nearest_neighbor"),
"aten::expand_as" : _expand_as(),
weight=default_weight_observer)
-def quantize_model(model, inp, per_channel=False, dummy=True):
+def quantize_model(model, inp, per_channel=False):
model.fuse_model()
model.qconfig = get_qconfig(per_channel)
torch.quantization.prepare(model, inplace=True)
pass
+class AdaptiveAvgPool2d(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = QuantWrapper(nn.AdaptiveAvgPool2d((1, 1)))
+
+ def forward(self, x):
+ return self.pool(x)
+
+ def fuse_model(self):
+ pass
+
+
def test_quantized_modules():
imagenet_ishape = (1, 3, 224, 224)
raw_module.eval()
inp = torch.rand(ishape)
- quantize_model(raw_module, inp, per_channel=per_channel, dummy=True)
+ quantize_model(raw_module, inp, per_channel=per_channel)
script_module = torch.jit.trace(raw_module, inp).eval()
with torch.no_grad():
inp = get_imagenet_input()
pt_inp = torch.from_numpy(inp)
- quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False)
+ quantize_model(raw_model, pt_inp, per_channel=per_channel)
script_module = torch.jit.trace(raw_model, pt_inp).eval()
with torch.no_grad():
mean abs_diff: 0.054197952
558 in 1000 raw outputs identical.
"""
+
+
+def test_serialized_modules():
+ ishape = (1, 16, 64, 64)
+ raw_module = AdaptiveAvgPool2d().eval()
+ inp = torch.rand(ishape)
+
+ quantize_model(raw_module, inp)
+ script_module = torch.jit.trace(raw_module, inp).eval()
+
+ fname = "tmp.pt"
+ torch.jit.save(script_module, fname)
+ loaded = torch.jit.load(fname)
+ os.remove(fname)
+
+ with torch.no_grad():
+ pt_result = loaded(inp.clone()).numpy()
+
+ input_name = "input"
+ runtime = get_tvm_runtime(loaded, input_name, ishape)
+ runtime.set_input(input_name, inp.numpy().copy())
+ runtime.run()
+ tvm_result = runtime.get_output(0).asnumpy()
+
+ num_identical = np.sum(tvm_result == pt_result)
+ match_ratio = num_identical / float(np.prod(tvm_result.shape))
+ assert match_ratio > 0.2