import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
-import torch.nn.intrinsic.quantized._reference as nniqr
import torch.nn.quantized as nnq
-import torch.nn.quantized._reference as nnqr
import torch.nn.quantized.dynamic as nnqd
import torch.quantization
self.assertEqual(rqr, rqr2)
def _test_conv_api_impl(
- self, module_name, qconv_module, conv_module, batch_size,
- in_channels_per_group, input_feature_map_size, out_channels_per_group,
- groups, kernel_size, stride, padding, padding_mode, dilation,
- X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
- use_bias, use_fused, use_channelwise, is_reference
- ):
+ self, module_name, qconv_module, conv_module, batch_size,
+ in_channels_per_group, input_feature_map_size, out_channels_per_group,
+ groups, kernel_size, stride, padding, padding_mode, dilation,
+ X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
+ use_bias, use_fused, use_channelwise):
for i in range(len(kernel_size)):
assume(input_feature_map_size[i] + 2 * padding[i]
>= dilation[i] * (kernel_size[i] - 1) + 1)
# Test members
self.assertTrue(module_name == qconv_module._get_name(), module_name + " " + qconv_module._get_name())
- if not is_reference:
- self.assertTrue(hasattr(qconv_module, '_packed_params'))
+ self.assertTrue(hasattr(qconv_module, '_packed_params'))
self.assertTrue(hasattr(qconv_module, 'scale'))
self.assertTrue(hasattr(qconv_module, 'zero_point'))
# For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is
# 4 assuming the rounding mode is round-to-nearest, ties-to-even.
-        # skip numerics checking for reference module
- if not is_reference:
- np.testing.assert_array_almost_equal(
- Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)
+ np.testing.assert_array_almost_equal(
+ Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)
# Test serialization of quantized Conv Module using state_dict
model_dict = qconv_module.state_dict()
self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module))
self.assertTrue(module_name == loaded_qconv_module._get_name())
- if not is_reference:
- self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
+ self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))
self.assertEqual(qconv_module.weight(), loaded_qconv_module.weight())
self.assertEqual(qconv_module.zero_point,
loaded_qconv_module.zero_point)
Y_loaded = loaded_qconv_module(X_q)
- if not is_reference:
- np.testing.assert_array_almost_equal(
- Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)
+ np.testing.assert_array_almost_equal(
+ Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)
# Test serialization
b = io.BytesIO()
self.assertEqual(copied_conv.zero_point,
qconv_module.zero_point)
Y_copied = copied_conv(X_q)
- if not is_reference:
- np.testing.assert_array_almost_equal(
- Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)
+ np.testing.assert_array_almost_equal(
+ Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)
deepcopied_conv = copy.deepcopy(qconv_module)
self.assertEqual(deepcopied_conv.bias(), qconv_module.bias())
self.assertEqual(deepcopied_conv.zero_point,
qconv_module.zero_point)
Y_deepcopied = deepcopied_conv(X_q)
- if not is_reference:
- np.testing.assert_array_almost_equal(
- Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)
+ np.testing.assert_array_almost_equal(
+ Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)
# JIT testing
self.checkScriptable(
[True, False], # use_bias
[True, False], # use_fused
[True, False], # use_channelwise
- [True, False] # is_reference
)
- for pad_mode, use_bias, use_fused, use_channelwise, is_reference in options:
+ for pad_mode, use_bias, use_fused, use_channelwise in options:
if torch.backends.quantized.engine == "qnnpack":
use_channelwise = False
batch_size = 2
Y_zero_point = 4
if torch.backends.quantized.engine == 'qnnpack':
use_channelwise = False
- # (use_fused, is_reference) -> quantized class
+ # use_fused -> quantized class
class_map = {
- (True, True): (nniqr.ConvReLU1d, "QuantizedConvReLU1d(Reference)"),
- (True, False): (nniq.ConvReLU1d, "QuantizedConvReLU1d"),
- (False, True): (nnqr.Conv1d, "QuantizedConv1d(Reference)"),
- (False, False): (nnq.Conv1d, "QuantizedConv1d")
+ True: (nniq.ConvReLU1d, "QuantizedConvReLU1d"),
+ False: (nnq.Conv1d, "QuantizedConv1d")
}
- qconv_cls, module_name = class_map[(use_fused, is_reference)]
+ qconv_cls, module_name = class_map[use_fused]
qconv_module = qconv_cls(
in_channels, out_channels, kernel, stride, pad,
dilation, groups, use_bias, padding_mode=pad_mode
in_channels_per_group, input_feature_map_size,
out_channels_per_group, groups, kernel_size, stride, pad, pad_mode,
dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
- Y_zero_point, use_bias, use_fused, use_channelwise, is_reference)
+ Y_zero_point, use_bias, use_fused, use_channelwise)
@override_qengines
def test_conv2d_api(self):
[True, False], # use_bias
[True, False], # use_fused
[True, False], # use_channelwise
- [True, False] # is_reference
)
- for pad_mode, use_bias, use_fused, use_channelwise, is_reference in options:
+ for pad_mode, use_bias, use_fused, use_channelwise in options:
if torch.backends.quantized.engine == "qnnpack":
use_channelwise = False
batch_size = 2
W_zero_point = [3]
Y_scale = 5.0
Y_zero_point = 4
- # (use_fused, is_reference) -> quantized class
+ # use_fused -> quantized class
class_map = {
- (True, True): (nniqr.ConvReLU2d, "QuantizedConvReLU2d(Reference)"),
- (True, False): (nniq.ConvReLU2d, "QuantizedConvReLU2d"),
- (False, True): (nnqr.Conv2d, "QuantizedConv2d(Reference)"),
- (False, False): (nnq.Conv2d, "QuantizedConv2d")
+ True: (nniq.ConvReLU2d, "QuantizedConvReLU2d"),
+ False: (nnq.Conv2d, "QuantizedConv2d")
}
- qconv_cls, module_name = class_map[(use_fused, is_reference)]
+ qconv_cls, module_name = class_map[use_fused]
qconv_module = qconv_cls(
in_channels, out_channels, kernel_size, stride, padding,
dilation, groups, use_bias, padding_mode=pad_mode
in_channels_per_group, input_feature_map_size,
out_channels_per_group, groups, kernel_size, stride, padding,
pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
- Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise, is_reference)
+ Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise)
@skipIfNoFBGEMM
def test_conv3d_api(self):
[True, False], # use_bias
[True, False], # use_fused
[True, False], # use_channelwise
- [True, False] # is_reference
)
- for use_bias, use_fused, use_channelwise, is_reference in options:
+ for use_bias, use_fused, use_channelwise in options:
if torch.backends.quantized.engine == "qnnpack":
use_channelwise = False
batch_size = 2
W_zero_point = [3]
Y_scale = 5.0
Y_zero_point = 4
- # (use_fused, is_reference) -> quantized class
+ # use_fused -> quantized class
class_map = {
- (True, True): (nniqr.ConvReLU3d, "QuantizedConvReLU3d(Reference)"),
- (True, False): (nniq.ConvReLU3d, "QuantizedConvReLU3d"),
- (False, True): (nnqr.Conv3d, "QuantizedConv3d(Reference)"),
- (False, False): (nnq.Conv3d, "QuantizedConv3d")
+ True: (nniq.ConvReLU3d, "QuantizedConvReLU3d"),
+ False: (nnq.Conv3d, "QuantizedConv3d")
}
with override_quantized_engine('fbgemm'):
- qconv_cls, module_name = class_map[(use_fused, is_reference)]
+ qconv_cls, module_name = class_map[use_fused]
qconv_module = qconv_cls(
in_channels, out_channels, kernel_size, stride, padding,
dilation, groups, use_bias, padding_mode=pad_mode
out_channels_per_group, groups, kernel_size, stride, padding,
pad_mode, dilation, X_scale, X_zero_point, W_scale,
W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused,
- use_channelwise, is_reference)
+ use_channelwise)
def test_pool_api(self):
"""Tests the correctness of the pool module.
Conv1d,
conv1d_module_args,
(conv1d_input,),
- ns.call_module(nn.Conv1d if is_reference else nnq.Conv1d),
+ ns.call_module(nnqr.Conv1d if is_reference else nnq.Conv1d),
None
),
(
Conv2d,
conv2d_module_args,
(conv2d_input,),
- ns.call_module(nn.Conv2d if is_reference else nnq.Conv2d),
+ ns.call_module(nnqr.Conv2d if is_reference else nnq.Conv2d),
None
),
(
Conv3d,
conv3d_module_args,
(conv3d_input,),
- ns.call_module(nn.Conv3d if is_reference else nnq.Conv3d),
+ ns.call_module(nnqr.Conv3d if is_reference else nnq.Conv3d),
None
),
(
qr = result_dict["quantized_reference"]
def checkWeightQParams(model):
- for module_name in ("conv",):
- if hasattr(model, module_name):
- self.assertTrue(hasattr(qr.get_submodule(module_name), "_weight_qparams"))
- self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())
- for module_name in ("linear",):
+ for module_name in ("linear", "conv"):
if hasattr(model, module_name):
self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme"))
self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale"))
self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())
def checkSerDeser(model, is_dynamic):
- for module_name in ("conv",):
- if hasattr(model, module_name):
- # make sure seralization works
- state_dict = copy.deepcopy(model.state_dict())
- self.assertTrue(module_name + "._weight_qparams" in state_dict)
-
- # check load_state_dict restores states
- module = getattr(model, module_name)
- prev_scale = module._weight_qparams["scale"]
- module._weight_qparams["scale"] = None
- model.load_state_dict(state_dict)
- self.assertTrue(torch.equal(prev_scale, module._weight_qparams["scale"]))
- for module_name in ("linear",):
+ for module_name in ("linear", "conv"):
if hasattr(model, module_name):
# make sure serialization works
state_dict = copy.deepcopy(model.state_dict())
result_ref = m_ref(data)
self.assertTrue(torch.equal(result, result_ref))
+ def test_ref_conv_module(self):
+ """ Make sure the numerics for models with ref conv module
+ matches models with fbgemm/qnnpack module
+ """
+ convs = {
+ 1: nn.Conv1d,
+ 2: nn.Conv2d,
+ 3: nn.Conv3d,
+ }
+
+ class M1(torch.nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.conv = convs[dim](3, 3, 3)
+
+ def forward(self, x):
+ return self.conv(x)
+
+ class M2(torch.nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.conv = convs[dim](3, 3, 3)
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, x):
+ return self.relu(self.conv(x))
+
+ for dim, M in itertools.product([1, 2, 3], [M1, M2]):
+ m = M(dim).eval()
+ m = prepare_fx(m, {"": default_qconfig})
+ m_copy = copy.deepcopy(m)
+ m = convert_fx(m, is_reference=False)
+ m_ref = convert_fx(m_copy, is_reference=True)
+ data = self.img_data_dict[dim][0][0]
+ result = m(data)
+ result_ref = m_ref(data)
+ self.assertTrue(torch.equal(result, result_ref))
+
@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
reference_order_check = [
ns.call_function(torch.quantize_per_tensor),
ns.call_method('dequantize'),
- ns.call_module(nn.Conv2d),
+ ns.call_module(nnqr.Conv2d),
ns.call_function(torch.quantize_per_tensor),
ns.call_method('dequantize'),
ns.call_module(nn.Sigmoid),
ns.call_function(torch.quantize_per_tensor),
ns.call_method('dequantize'),
- ns.call_module(nn.Conv2d),
+ ns.call_module(nnqr.Conv2d),
ns.call_function(torch.quantize_per_tensor),
ns.call_method('dequantize'),
]
+++ /dev/null
-from .modules import * # noqa: F403
+++ /dev/null
-import torch
-from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
-
-__all__ = [
- 'ConvReLU1d',
- 'ConvReLU2d',
- 'ConvReLU3d',
-]
+++ /dev/null
-import torch
-import torch.nn.quantized._reference as nnqr
-import torch.nn.functional as F
-
-class ConvReLU1d(nnqr.Conv1d):
- _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_dequant = x.dequantize()
- weight_dequant = self._qweight.dequantize()
- float_result = F.conv1d(
- x_dequant, weight_dequant, self._bias, self._conv1d_stride, # type: ignore[has-type]
- self._conv1d_padding, self._conv1d_dilation, self.groups) # type: ignore[has-type]
- float_result = F.relu(float_result, inplace=True)
- # NEEDFIX: we don't have dtype in the Linear module APIs right now!
- result = torch.quantize_per_tensor(
- float_result, self.scale, self.zero_point, torch.quint8)
- return result
-
- def _get_name(self):
- return "QuantizedConvReLU1d(Reference)"
-
-
-class ConvReLU2d(nnqr.Conv2d):
- _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_dequant = x.dequantize()
- weight_dequant = self._qweight.dequantize()
- float_result = F.conv2d(
- x_dequant, weight_dequant, self._bias, self.stride,
- self.padding, self.dilation, self.groups)
- float_result = F.relu(float_result, inplace=True)
- # NEEDFIX: we don't have dtype in the Linear module APIs right now!
- result = torch.quantize_per_tensor(
- float_result, self.scale, self.zero_point, torch.quint8)
- return result
-
- def _get_name(self):
- return "QuantizedConvReLU2d(Reference)"
-
-class ConvReLU3d(nnqr.Conv3d):
- _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_dequant = x.dequantize()
- weight_dequant = self._qweight.dequantize()
- float_result = F.conv3d(
- x_dequant, weight_dequant, self._bias, self.stride,
- self.padding, self.dilation, self.groups)
- float_result = F.relu(float_result, inplace=True)
- # NEEDFIX: we don't have dtype in the Linear module APIs right now!
- result = torch.quantize_per_tensor(
- float_result, self.scale, self.zero_point, torch.quint8)
- return result
-
- def _get_name(self):
- return "QuantizedConvReLU3d(Reference)"
import torch
-import torch.nn.quantized as nnq
+import torch.nn as nn
import torch.nn.functional as F
-from typing import Optional
+from typing import Optional, Dict, Any
from torch.nn.common_types import _size_1_t
-from torch.nn.modules.utils import _single
+from .utils import _quantize_and_dequantize_weight
+from .utils import _save_weight_qparams
+from .utils import _get_weight_qparam_keys
-class _ConvNd(nnq._ConvNd):
+class _ConvNd(torch.nn.modules.conv._ConvNd):
""" A reference version of nn.quantized.Conv2d
we will not pack the parameters in this module, since weight packing is an
optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
this is useful when user want to use this module in other backends like Glow.
"""
- __annotations__ = {"_bias": Optional[torch.Tensor]}
+ __annotations__ = {"bias": Optional[torch.Tensor]}
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
- destination[prefix + '_qweight'] = self._qweight
- destination[prefix + '_bias'] = self._bias
+ _save_weight_qparams(
+ destination, prefix, self.weight_qscheme, self.weight_dtype,
+ self.weight_scale, self.weight_zero_point, self.weight_axis)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
- self._qweight = state_dict[prefix + '_qweight']
- self._bias = state_dict[prefix + '_bias']
- state_dict.pop(prefix + '_qweight')
- state_dict.pop(prefix + '_bias')
+ for key in _get_weight_qparam_keys(state_dict, prefix):
+ setattr(self, key, state_dict[prefix + key])
+ state_dict.pop(prefix + key)
super()._load_from_state_dict(
state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs)
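# Illustrative note (not part of this diff): with _save_weight_qparams assumed to
# store each qparam under its own key, a reference conv submodule named "conv"
# would presumably serialize entries like
#   "conv.weight", "conv.bias",
#   "conv.weight_qscheme", "conv.weight_dtype",
#   "conv.weight_scale", "conv.weight_zero_point", "conv.weight_axis"
# and _load_from_state_dict pops those qparam keys back onto the module via
# _get_weight_qparam_keys before delegating to the base class implementation.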
- def _weight_bias(self):
- return self._qweight, self._bias
-
- def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
- self._qweight = w
- self._bias = b
-
-class Conv1d(_ConvNd, nnq.Conv1d):
+ def _init_weight_qparams(self, weight_qparams, device):
+ if weight_qparams is None:
+ weight_qparams = {
+ "qscheme": torch.per_tensor_affine,
+ "dtype": torch.quint8,
+ "scale": 1.0,
+ "zero_point": 0
+ }
+ self.weight_qscheme = weight_qparams["qscheme"]
+ self.weight_dtype = weight_qparams["dtype"]
+ assert self.weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
+ Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized linear module")
+ if self.weight_qscheme is not None:
+ self.register_buffer(
+ "weight_scale",
+ torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device))
+ self.register_buffer(
+ "weight_zero_point",
+ torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device))
+ if self.weight_qscheme == torch.per_channel_affine:
+ self.register_buffer(
+ "weight_axis",
+ torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device))
+ else:
+ # added for TorchScriptability, not used
+ self.register_buffer(
+ "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
+
+ def get_weight(self):
+ """
+ Fake quantize (quantize and dequantize) the weight with
+ the quantization parameters for the weight. This is used to
+ simulate the numerics of the quantized weight in a quantized
+ model
+ """
+ # suppress mypy warning
+ assert isinstance(self.weight, torch.Tensor)
+ assert isinstance(self.weight_scale, torch.Tensor)
+ assert isinstance(self.weight_zero_point, torch.Tensor)
+ assert isinstance(self.weight_axis, torch.Tensor)
+ return _quantize_and_dequantize_weight(
+ self.weight, self.weight_qscheme,
+ self.weight_dtype, self.weight_scale, self.weight_zero_point, self.weight_axis)
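# Illustrative sketch (not part of this diff) of what the imported
# _quantize_and_dequantize_weight helper is assumed to do: round-trip the float
# weight through quantize/dequantize so get_weight() returns a fake-quantized
# float tensor.
#
#   def _quantize_and_dequantize_weight(weight, qscheme, dtype, scale, zero_point, axis):
#       if qscheme == torch.per_tensor_affine:
#           qweight = torch.quantize_per_tensor(weight, scale.item(), zero_point.item(), dtype)
#       elif qscheme == torch.per_channel_affine:
#           qweight = torch.quantize_per_channel(weight, scale, zero_point, axis.item(), dtype)
#       else:
#           return weight  # qscheme is None: weight is not quantized
#       return qweight.dequantize()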
+
+ @staticmethod
+ def from_float(cls, float_conv, weight_qparams):
+ qref_conv = cls(
+ float_conv.in_channels,
+ float_conv.out_channels,
+ float_conv.kernel_size, # type: ignore[arg-type]
+ float_conv.stride, # type: ignore[arg-type]
+ float_conv.padding, # type: ignore[arg-type]
+ float_conv.dilation, # type: ignore[arg-type]
+ float_conv.groups,
+ float_conv.bias is not None, # type: ignore[arg-type]
+ float_conv.padding_mode,
+ device=float_conv.weight.device,
+ dtype=float_conv.weight.dtype,
+ weight_qparams=weight_qparams)
+ qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
+ if float_conv.bias is not None:
+ qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
+ return qref_conv
+
+class Conv1d(_ConvNd, nn.Conv1d):
def __init__(self,
in_channels: int,
out_channels: int,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
- padding_mode: str = 'zeros'):
- nnq.Conv1d.__init__(
+ padding_mode: str = "zeros",
+ device=None,
+ dtype=None,
+ weight_qparams: Optional[Dict[str, Any]] = None):
+ nn.Conv1d.__init__(
self, in_channels, out_channels, kernel_size, stride, padding, dilation,
- groups, bias, padding_mode)
- # self.stride, self.padding, self.dilation are 2d tuple since
- # current quantized conv1d is using Conv2dPackedParams
- # TODO: we should fix this if we implemenet Conv1dPackedParams
- self._conv1d_stride = _single(self.stride[0])
- self._conv1d_padding = _single(self.padding[0])
- self._conv1d_dilation = _single(self.dilation[0])
+ groups, bias, padding_mode, device, dtype)
+ self._init_weight_qparams(weight_qparams, device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_dequant = x.dequantize()
- weight_dequant = self._qweight.dequantize()
- float_result = F.conv1d(
- x_dequant, weight_dequant, self._bias, self._conv1d_stride,
- self._conv1d_padding, self._conv1d_dilation, self.groups)
- # NEEDFIX: we don't have dtype in the Linear module APIs right now!
- result = torch.quantize_per_tensor(
- float_result, self.scale, self.zero_point, torch.quint8)
+ """
+ we have:
+ w(float) -- quant - dequant \
+ x(float) ------------- F.conv1d ---
+
+ In the full model, we will see
+ w(float) -- quant - *dequant \
+ x -- quant --- *dequant -- *F.conv1d --- *quant - dequant
+ and the backend should be able to fuse the ops with `*` into a quantized conv1d
+ """
+ weight_dequant = self.get_weight()
+ result = F.conv1d(
+ x, weight_dequant, self.bias, self.stride,
+ self.padding, self.dilation, self.groups)
return result
def _get_name(self):
- return 'QuantizedConv1d(Reference)'
-
- @torch.jit.export
- def __setstate__(self, state):
- self.in_channels = state[0]
- self.out_channels = state[1]
- self.kernel_size = state[2]
- self.stride = state[3]
- self.padding = state[4]
- self.dilation = state[5]
- self.transposed = state[6]
- self.output_padding = state[7]
- self.groups = state[8]
- self.padding_mode = state[9]
- self.set_weight_bias(state[10], state[11])
- self.scale = state[12]
- self.zero_point = state[13]
- self.training = state[14]
- self._conv1d_stride = (self.stride[0],)
- self._conv1d_padding = (self.padding[0],)
- self._conv1d_dilation = (self.dilation[0],)
-
-class Conv2d(_ConvNd, nnq.Conv2d):
+ return "QuantizedConv1d(Reference)"
+
+ @classmethod
+ def from_float(cls, float_conv, weight_qparams):
+ return _ConvNd.from_float(cls, float_conv, weight_qparams)
+
+class Conv2d(_ConvNd, nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
- padding_mode='zeros'):
- nnq.Conv2d.__init__(
+ padding_mode='zeros',
+ device=None,
+ dtype=None,
+ weight_qparams: Optional[Dict[str, Any]] = None):
+ nn.Conv2d.__init__(
self, in_channels, out_channels, kernel_size, stride, padding, dilation,
- groups, bias, padding_mode)
+ groups, bias, padding_mode, device, dtype)
+ self._init_weight_qparams(weight_qparams, device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_dequant = x.dequantize()
- weight_dequant = self._qweight.dequantize()
- float_result = F.conv2d(
- x_dequant, weight_dequant, self._bias, self.stride,
+ """
+ we have:
+ w(float) -- quant - dequant \
+ x(float) ------------- F.conv2d ---
+
+ In the full model, we will see
+ w(float) -- quant - *dequant \
+ x -- quant --- *dequant -- *F.conv2d --- *quant - dequant
+ and the backend should be able to fuse the ops with `*` into a quantized conv2d
+ """
+ weight_dequant = self.get_weight()
+ result = F.conv2d(
+ x, weight_dequant, self.bias, self.stride,
self.padding, self.dilation, self.groups)
- # NEEDFIX: we don't have dtype in the Linear module APIs right now!
- result = torch.quantize_per_tensor(
- float_result, self.scale, self.zero_point, torch.quint8)
return result
def _get_name(self):
- return 'QuantizedConv2d(Reference)'
+ return "QuantizedConv2d(Reference)"
-class Conv3d(_ConvNd, nnq.Conv3d):
+ @classmethod
+ def from_float(cls, float_conv, weight_qparams):
+ return _ConvNd.from_float(cls, float_conv, weight_qparams)
+
+class Conv3d(_ConvNd, nn.Conv3d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True,
- padding_mode='zeros'):
- nnq.Conv3d.__init__(
+ padding_mode="zeros",
+ device=None,
+ dtype=None,
+ weight_qparams: Optional[Dict[str, Any]] = None):
+ nn.Conv3d.__init__(
self, in_channels, out_channels, kernel_size, stride, padding, dilation,
- groups, bias, padding_mode)
+ groups, bias, padding_mode, device, dtype)
+ self._init_weight_qparams(weight_qparams, device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_dequant = x.dequantize()
- weight_dequant = self._qweight.dequantize()
- float_result = F.conv3d(
- x_dequant, weight_dequant, self._bias, self.stride,
+ """
+ we have:
+ w(float) -- quant - dequant \
+ x(float) ------------- F.conv3d ---
+
+ In the full model, we will see
+ w(float) -- quant - *dequant \
+ x -- quant --- *dequant -- *F.conv3d --- *quant - dequant
+ and the backend should be able to fuse the ops with `*` into a quantized conv3d
+ """
+ weight_dequant = self.get_weight()
+ result = F.conv3d(
+ x, weight_dequant, self.bias, self.stride,
self.padding, self.dilation, self.groups)
- # NEEDFIX: we don't have dtype in the Linear module APIs right now!
- result = torch.quantize_per_tensor(
- float_result, self.scale, self.zero_point, torch.quint8)
return result
def _get_name(self):
- return 'QuantizedConv3d(Reference)'
+ return "QuantizedConv3d(Reference)"
+
+ @classmethod
+ def from_float(cls, float_conv, weight_qparams):
+ return _ConvNd.from_float(cls, float_conv, weight_qparams)
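A minimal usage sketch (illustrative, not part of this diff), assuming the module is
exposed as torch.nn.quantized._reference as in the mappings below; the reference
module takes and returns float tensors, and only the weight is fake-quantized:

    import torch
    import torch.nn as nn
    import torch.nn.quantized._reference as nnqr

    float_conv = nn.Conv1d(3, 3, 3)
    weight_qparams = {
        "qscheme": torch.per_tensor_affine,
        "dtype": torch.quint8,
        "scale": 0.1,
        "zero_point": 0,
    }
    ref_conv = nnqr.Conv1d.from_float(float_conv, weight_qparams)
    y = ref_conv(torch.randn(1, 3, 8))  # float output, computed with F.conv1d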
# and qparam is a dictionary of
# {"qscheme": ..., "scale": ..., "zero_point": ...} for per tensor quantization or
# {"qscheme": ..., "scale": ..., "zero_point": ..., "axis": ...} for per channel quantization
+ float_conv = self.conv
+ fused_conv = None
if isinstance(
- self.conv,
+ float_conv,
QAT_CONV_MODULE_CLASSES):
# case 1. converting qat conv module to
# a float conv module, we need to attach
# weight fake_quant to the conv module,
# weight fake_quant is assumed to be run during
# QAT so we don't need to run it again here
- float_conv = self.conv.to_float()
+ float_conv = self.conv.to_float() # type: ignore[operator]
# change qat conv to conv
parent_name, name = _parent_name(self.conv_node.target)
setattr(modules[parent_name], name, float_conv)
if isinstance(float_conv, torch.nn.intrinsic._FusedModule):
+ fused_conv = float_conv
float_conv = float_conv[0]
weight_post_process = self.conv.weight_fake_quant
else:
# to float conv module, we need to attach
# weight observer to the conv module and run it
# with conv weight
- float_conv = self.conv
- if isinstance(self.conv, torch.nn.intrinsic._FusedModule):
- float_conv = self.conv[0]
+ if isinstance(float_conv, torch.nn.intrinsic._FusedModule):
+ fused_conv = float_conv
+ float_conv = float_conv[0] # type: ignore[index]
assert qconfig is not None
weight_post_process = qconfig.weight()
# run weight observer
- weight_post_process(float_conv.weight)
+ weight_post_process(float_conv.weight) # type: ignore[operator]
weight_qparams = get_qparam_dict(weight_post_process)
- _to_reference(float_conv, weight_qparams)
+ # hardcoded for now, TODO: expose the api to the user;
+ # we can have a map from module to reference module
+ # and allow users to register new ones
+ qconv_cls = get_static_quant_module_class(
+ type(float_conv), is_reference=is_reference)
+ ref_conv = qconv_cls.from_float(float_conv, weight_qparams) # type: ignore[attr-defined]
+ # if the parent is a fused conv (Sequential), we can replace the first
+ # item with the reference conv; otherwise we update
+ # the conv instance in the module tree
+ if fused_conv is not None:
+ fused_conv[0] = ref_conv
+ else:
+ parent_name, name = _parent_name(self.conv_node.target)
+ setattr(modules[parent_name], name, ref_conv)
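# Illustrative note (not part of this diff): for a fused conv such as nni.ConvReLU2d
# (a Sequential holding [Conv2d, ReLU]), `fused_conv[0] = ref_conv` leaves
# [nnqr.Conv2d(Reference), ReLU] in the module tree; a standalone conv is instead
# swapped directly on its parent via setattr.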
op_out = quantized_graph.create_node(
'call_module',
self.conv_node.target,
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.quantized.dynamic as nniqd
-import torch.nn.intrinsic.quantized._reference as nniqr
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized._reference as nnqr
nn.Conv1d: nnqr.Conv1d,
nn.Conv2d: nnqr.Conv2d,
nn.Conv3d: nnqr.Conv3d,
- nni.ConvReLU1d: nniqr.ConvReLU1d,
- nni.ConvReLU2d: nniqr.ConvReLU2d,
- nni.ConvReLU3d: nniqr.ConvReLU3d,
- # QAT Modules
- nnqat.Conv2d: nnqr.Conv2d,
- nnqat.Conv3d: nnqr.Conv3d,
- nniqat.ConvBn1d: nnqr.Conv1d,
- nniqat.ConvBn2d: nnqr.Conv2d,
- nniqat.ConvBn3d: nnqr.Conv3d,
- nniqat.ConvBnReLU1d: nniqr.ConvReLU1d,
- nniqat.ConvBnReLU2d: nniqr.ConvReLU2d,
- nniqat.ConvBnReLU3d: nniqr.ConvReLU3d,
- nniqat.ConvReLU2d: nniqr.ConvReLU2d,
- nniqat.ConvReLU3d: nniqr.ConvReLU3d,
}
# Default map for swapping float module to quantized ones