[quant][graphmode][fx] Add reference quantized conv module (#63828)
authorJerry Zhang <jerryzh@fb.com>
Mon, 30 Aug 2021 21:21:39 +0000 (14:21 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 21:23:17 +0000 (14:23 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63828

Added reference quantized conv module for the custom backend flow, the reference quantized module will
have the following code:
```
        w(float) -- quant - dequant \
        x(float) ------------- F.conv2d ---
```
In the full model, we will see
```
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.conv2d --- *quant - dequant
```
and the backend should be able to fuse the ops with `*` into a quantized linear

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_conv_linear_reference

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D30504749

fbshipit-source-id: e1d8c43a0e0d6d9ea2375b8ca59a9c0f455514fb

test/quantization/core/test_quantized_module.py
test/quantization/fx/test_quantize_fx.py
torch/nn/intrinsic/quantized/_reference/__init__.py [deleted file]
torch/nn/intrinsic/quantized/_reference/modules/__init__.py [deleted file]
torch/nn/intrinsic/quantized/_reference/modules/conv_relu.py [deleted file]
torch/nn/quantized/_reference/modules/conv.py
torch/quantization/fx/quantization_patterns.py
torch/quantization/quantization_mappings.py

index bc8a6b3..b0bc782 100644 (file)
@@ -2,9 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
-import torch.nn.intrinsic.quantized._reference as nniqr
 import torch.nn.quantized as nnq
-import torch.nn.quantized._reference as nnqr
 import torch.nn.quantized.dynamic as nnqd
 import torch.quantization
 
@@ -211,12 +209,11 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         self.assertEqual(rqr, rqr2)
 
     def _test_conv_api_impl(
-        self, module_name, qconv_module, conv_module, batch_size,
-        in_channels_per_group, input_feature_map_size, out_channels_per_group,
-        groups, kernel_size, stride, padding, padding_mode, dilation,
-        X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
-        use_bias, use_fused, use_channelwise, is_reference
-    ):
+            self, module_name, qconv_module, conv_module, batch_size,
+            in_channels_per_group, input_feature_map_size, out_channels_per_group,
+            groups, kernel_size, stride, padding, padding_mode, dilation,
+            X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
+            use_bias, use_fused, use_channelwise):
         for i in range(len(kernel_size)):
             assume(input_feature_map_size[i] + 2 * padding[i]
                    >= dilation[i] * (kernel_size[i] - 1) + 1)
@@ -245,8 +242,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
 
         # Test members
         self.assertTrue(module_name == qconv_module._get_name(), module_name + " " + qconv_module._get_name())
-        if not is_reference:
-            self.assertTrue(hasattr(qconv_module, '_packed_params'))
+        self.assertTrue(hasattr(qconv_module, '_packed_params'))
         self.assertTrue(hasattr(qconv_module, 'scale'))
         self.assertTrue(hasattr(qconv_module, 'zero_point'))
 
@@ -275,9 +271,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is
         # 4 assuming the rounding mode is round-to-nearest, ties-to-even.
         # skip numerics checking for reference module
-        if not is_reference:
-            np.testing.assert_array_almost_equal(
-                Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)
+        np.testing.assert_array_almost_equal(
+            Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)
 
         # Test serialization of quantized Conv Module using state_dict
         model_dict = qconv_module.state_dict()
@@ -297,8 +292,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
 
         self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module))
         self.assertTrue(module_name == loaded_qconv_module._get_name())
-        if not is_reference:
-            self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
+        self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
         self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))
 
         self.assertEqual(qconv_module.weight(), loaded_qconv_module.weight())
@@ -308,9 +302,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         self.assertEqual(qconv_module.zero_point,
                          loaded_qconv_module.zero_point)
         Y_loaded = loaded_qconv_module(X_q)
-        if not is_reference:
-            np.testing.assert_array_almost_equal(
-                Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)
+        np.testing.assert_array_almost_equal(
+            Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)
 
         # Test serialization
         b = io.BytesIO()
@@ -330,9 +323,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         self.assertEqual(copied_conv.zero_point,
                          qconv_module.zero_point)
         Y_copied = copied_conv(X_q)
-        if not is_reference:
-            np.testing.assert_array_almost_equal(
-                Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)
+        np.testing.assert_array_almost_equal(
+            Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)
 
         deepcopied_conv = copy.deepcopy(qconv_module)
         self.assertEqual(deepcopied_conv.bias(), qconv_module.bias())
@@ -340,9 +332,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         self.assertEqual(deepcopied_conv.zero_point,
                          qconv_module.zero_point)
         Y_deepcopied = copied_conv(X_q)
-        if not is_reference:
-            np.testing.assert_array_almost_equal(
-                Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)
+        np.testing.assert_array_almost_equal(
+            Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)
 
         # JIT testing
         self.checkScriptable(
@@ -377,9 +368,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             [True, False],  # use_bias
             [True, False],  # use_fused
             [True, False],  # use_channelwise
-            [True, False]  # is_reference
         )
-        for pad_mode, use_bias, use_fused, use_channelwise, is_reference in options:
+        for pad_mode, use_bias, use_fused, use_channelwise in options:
             if torch.backends.quantized.engine == "qnnpack":
                 use_channelwise = False
             batch_size = 2
@@ -407,15 +397,13 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             Y_zero_point = 4
             if torch.backends.quantized.engine == 'qnnpack':
                 use_channelwise = False
-            # (use_fused, is_reference) -> quantized class
+            # use_fused -> quantized class
             class_map = {
-                (True, True): (nniqr.ConvReLU1d, "QuantizedConvReLU1d(Reference)"),
-                (True, False): (nniq.ConvReLU1d, "QuantizedConvReLU1d"),
-                (False, True): (nnqr.Conv1d, "QuantizedConv1d(Reference)"),
-                (False, False): (nnq.Conv1d, "QuantizedConv1d")
+                True: (nniq.ConvReLU1d, "QuantizedConvReLU1d"),
+                False: (nnq.Conv1d, "QuantizedConv1d")
             }
 
-            qconv_cls, module_name = class_map[(use_fused, is_reference)]
+            qconv_cls, module_name = class_map[use_fused]
             qconv_module = qconv_cls(
                 in_channels, out_channels, kernel, stride, pad,
                 dilation, groups, use_bias, padding_mode=pad_mode
@@ -434,7 +422,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
                 in_channels_per_group, input_feature_map_size,
                 out_channels_per_group, groups, kernel_size, stride, pad, pad_mode,
                 dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
-                Y_zero_point, use_bias, use_fused, use_channelwise, is_reference)
+                Y_zero_point, use_bias, use_fused, use_channelwise)
 
     @override_qengines
     def test_conv2d_api(self):
@@ -443,9 +431,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             [True, False],  # use_bias
             [True, False],  # use_fused
             [True, False],  # use_channelwise
-            [True, False]  # is_reference
         )
-        for pad_mode, use_bias, use_fused, use_channelwise, is_reference in options:
+        for pad_mode, use_bias, use_fused, use_channelwise in options:
             if torch.backends.quantized.engine == "qnnpack":
                 use_channelwise = False
             batch_size = 2
@@ -475,15 +462,13 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             W_zero_point = [3]
             Y_scale = 5.0
             Y_zero_point = 4
-            # (use_fused, is_reference) -> quantized class
+            # use_fused -> quantized class
             class_map = {
-                (True, True): (nniqr.ConvReLU2d, "QuantizedConvReLU2d(Reference)"),
-                (True, False): (nniq.ConvReLU2d, "QuantizedConvReLU2d"),
-                (False, True): (nnqr.Conv2d, "QuantizedConv2d(Reference)"),
-                (False, False): (nnq.Conv2d, "QuantizedConv2d")
+                True: (nniq.ConvReLU2d, "QuantizedConvReLU2d"),
+                False: (nnq.Conv2d, "QuantizedConv2d")
             }
 
-            qconv_cls, module_name = class_map[(use_fused, is_reference)]
+            qconv_cls, module_name = class_map[use_fused]
             qconv_module = qconv_cls(
                 in_channels, out_channels, kernel_size, stride, padding,
                 dilation, groups, use_bias, padding_mode=pad_mode
@@ -502,7 +487,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
                 in_channels_per_group, input_feature_map_size,
                 out_channels_per_group, groups, kernel_size, stride, padding,
                 pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
-                Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise, is_reference)
+                Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise)
 
     @skipIfNoFBGEMM
     def test_conv3d_api(self):
@@ -510,9 +495,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             [True, False],  # use_bias
             [True, False],  # use_fused
             [True, False],  # use_channelwise
-            [True, False]  # is_reference
         )
-        for use_bias, use_fused, use_channelwise, is_reference in options:
+        for use_bias, use_fused, use_channelwise in options:
             if torch.backends.quantized.engine == "qnnpack":
                 use_channelwise = False
             batch_size = 2
@@ -547,16 +531,14 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             W_zero_point = [3]
             Y_scale = 5.0
             Y_zero_point = 4
-            # (use_fused, is_reference) -> quantized class
+            # use_fused -> quantized class
             class_map = {
-                (True, True): (nniqr.ConvReLU3d, "QuantizedConvReLU3d(Reference)"),
-                (True, False): (nniq.ConvReLU3d, "QuantizedConvReLU3d"),
-                (False, True): (nnqr.Conv3d, "QuantizedConv3d(Reference)"),
-                (False, False): (nnq.Conv3d, "QuantizedConv3d")
+                True: (nniq.ConvReLU3d, "QuantizedConvReLU3d"),
+                False: (nnq.Conv3d, "QuantizedConv3d")
             }
 
             with override_quantized_engine('fbgemm'):
-                qconv_cls, module_name = class_map[(use_fused, is_reference)]
+                qconv_cls, module_name = class_map[use_fused]
                 qconv_module = qconv_cls(
                     in_channels, out_channels, kernel_size, stride, padding,
                     dilation, groups, use_bias, padding_mode=pad_mode
@@ -576,7 +558,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
                     out_channels_per_group, groups, kernel_size, stride, padding,
                     pad_mode, dilation, X_scale, X_zero_point, W_scale,
                     W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused,
-                    use_channelwise, is_reference)
+                    use_channelwise)
 
     def test_pool_api(self):
         """Tests the correctness of the pool module.
index 7ae29e0..9682da1 100644 (file)
@@ -532,7 +532,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 Conv1d,
                 conv1d_module_args,
                 (conv1d_input,),
-                ns.call_module(nn.Conv1d if is_reference else nnq.Conv1d),
+                ns.call_module(nnqr.Conv1d if is_reference else nnq.Conv1d),
                 None
             ),
             (
@@ -540,7 +540,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 Conv2d,
                 conv2d_module_args,
                 (conv2d_input,),
-                ns.call_module(nn.Conv2d if is_reference else nnq.Conv2d),
+                ns.call_module(nnqr.Conv2d if is_reference else nnq.Conv2d),
                 None
             ),
             (
@@ -548,7 +548,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 Conv3d,
                 conv3d_module_args,
                 (conv3d_input,),
-                ns.call_module(nn.Conv3d if is_reference else nnq.Conv3d),
+                ns.call_module(nnqr.Conv3d if is_reference else nnq.Conv3d),
                 None
             ),
             (
@@ -631,11 +631,7 @@ class TestQuantizeFx(QuantizationTestCase):
             qr = result_dict["quantized_reference"]
 
             def checkWeightQParams(model):
-                for module_name in ("conv",):
-                    if hasattr(model, module_name):
-                        self.assertTrue(hasattr(qr.get_submodule(module_name), "_weight_qparams"))
-                        self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())
-                for module_name in ("linear",):
+                for module_name in ("linear", "conv"):
                     if hasattr(model, module_name):
                         self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme"))
                         self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale"))
@@ -643,19 +639,7 @@ class TestQuantizeFx(QuantizationTestCase):
                         self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())
 
             def checkSerDeser(model, is_dynamic):
-                for module_name in ("conv",):
-                    if hasattr(model, module_name):
-                        # make sure seralization works
-                        state_dict = copy.deepcopy(model.state_dict())
-                        self.assertTrue(module_name + "._weight_qparams" in state_dict)
-
-                        # check load_state_dict restores states
-                        module = getattr(model, module_name)
-                        prev_scale = module._weight_qparams["scale"]
-                        module._weight_qparams["scale"] = None
-                        model.load_state_dict(state_dict)
-                        self.assertTrue(torch.equal(prev_scale, module._weight_qparams["scale"]))
-                for module_name in ("linear",):
+                for module_name in ("linear", "conv"):
                     if hasattr(model, module_name):
                         # make sure seralization works
                         state_dict = copy.deepcopy(model.state_dict())
@@ -3001,6 +2985,44 @@ class TestQuantizeFx(QuantizationTestCase):
             result_ref = m_ref(data)
             self.assertTrue(torch.equal(result, result_ref))
 
+    def test_ref_conv_module(self):
+        """ Make sure the numerics for models with ref conv module
+        matches models with fbgemm/qnnpack module
+        """
+        convs = {
+            1: nn.Conv1d,
+            2: nn.Conv2d,
+            3: nn.Conv3d,
+        }
+
+        class M1(torch.nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.conv = convs[dim](3, 3, 3)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        class M2(torch.nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.conv = convs[dim](3, 3, 3)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                return self.relu(self.conv(x))
+
+        for dim, M in itertools.product([1, 2, 3], [M1, M2]):
+            m = M(dim).eval()
+            m = prepare_fx(m, {"": default_qconfig})
+            m_copy = copy.deepcopy(m)
+            m = convert_fx(m, is_reference=False)
+            m_ref = convert_fx(m_copy, is_reference=True)
+            data = self.img_data_dict[dim][0][0]
+            result = m(data)
+            result_ref = m_ref(data)
+            self.assertTrue(torch.equal(result, result_ref))
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
@@ -4558,13 +4580,13 @@ class TestQuantizeFxOps(QuantizationTestCase):
             reference_order_check = [
                 ns.call_function(torch.quantize_per_tensor),
                 ns.call_method('dequantize'),
-                ns.call_module(nn.Conv2d),
+                ns.call_module(nnqr.Conv2d),
                 ns.call_function(torch.quantize_per_tensor),
                 ns.call_method('dequantize'),
                 ns.call_module(nn.Sigmoid),
                 ns.call_function(torch.quantize_per_tensor),
                 ns.call_method('dequantize'),
-                ns.call_module(nn.Conv2d),
+                ns.call_module(nnqr.Conv2d),
                 ns.call_function(torch.quantize_per_tensor),
                 ns.call_method('dequantize'),
             ]
diff --git a/torch/nn/intrinsic/quantized/_reference/__init__.py b/torch/nn/intrinsic/quantized/_reference/__init__.py
deleted file mode 100644 (file)
index 3d79bdb..0000000
+++ /dev/null
@@ -1 +0,0 @@
-from .modules import *  # noqa: F403
diff --git a/torch/nn/intrinsic/quantized/_reference/modules/__init__.py b/torch/nn/intrinsic/quantized/_reference/modules/__init__.py
deleted file mode 100644 (file)
index 33b18d8..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-import torch
-from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
-
-__all__ = [
-    'ConvReLU1d',
-    'ConvReLU2d',
-    'ConvReLU3d',
-]
diff --git a/torch/nn/intrinsic/quantized/_reference/modules/conv_relu.py b/torch/nn/intrinsic/quantized/_reference/modules/conv_relu.py
deleted file mode 100644 (file)
index b0305f6..0000000
+++ /dev/null
@@ -1,58 +0,0 @@
-import torch
-import torch.nn.quantized._reference as nnqr
-import torch.nn.functional as F
-
-class ConvReLU1d(nnqr.Conv1d):
-    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_dequant = x.dequantize()
-        weight_dequant = self._qweight.dequantize()
-        float_result = F.conv1d(
-            x_dequant, weight_dequant, self._bias, self._conv1d_stride,  # type: ignore[has-type]
-            self._conv1d_padding, self._conv1d_dilation, self.groups)  # type: ignore[has-type]
-        float_result = F.relu(float_result, inplace=True)
-        # NEEDFIX: we don't have dtype in the Linear module APIs right now!
-        result = torch.quantize_per_tensor(
-            float_result, self.scale, self.zero_point, torch.quint8)
-        return result
-
-    def _get_name(self):
-        return "QuantizedConvReLU1d(Reference)"
-
-
-class ConvReLU2d(nnqr.Conv2d):
-    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_dequant = x.dequantize()
-        weight_dequant = self._qweight.dequantize()
-        float_result = F.conv2d(
-            x_dequant, weight_dequant, self._bias, self.stride,
-            self.padding, self.dilation, self.groups)
-        float_result = F.relu(float_result, inplace=True)
-        # NEEDFIX: we don't have dtype in the Linear module APIs right now!
-        result = torch.quantize_per_tensor(
-            float_result, self.scale, self.zero_point, torch.quint8)
-        return result
-
-    def _get_name(self):
-        return "QuantizedConvReLU2d(Reference)"
-
-class ConvReLU3d(nnqr.Conv3d):
-    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_dequant = x.dequantize()
-        weight_dequant = self._qweight.dequantize()
-        float_result = F.conv3d(
-            x_dequant, weight_dequant, self._bias, self.stride,
-            self.padding, self.dilation, self.groups)
-        float_result = F.relu(float_result, inplace=True)
-        # NEEDFIX: we don't have dtype in the Linear module APIs right now!
-        result = torch.quantize_per_tensor(
-            float_result, self.scale, self.zero_point, torch.quint8)
-        return result
-
-    def _get_name(self):
-        return "QuantizedConvReLU3d(Reference)"
index 036f8e4..6b03bb0 100644 (file)
 import torch
-import torch.nn.quantized as nnq
+import torch.nn as nn
 import torch.nn.functional as F
-from typing import Optional
+from typing import Optional, Dict, Any
 from torch.nn.common_types import _size_1_t
-from torch.nn.modules.utils import _single
+from .utils import _quantize_and_dequantize_weight
+from .utils import _save_weight_qparams
+from .utils import _get_weight_qparam_keys
 
-class _ConvNd(nnq._ConvNd):
+class _ConvNd(torch.nn.modules.conv._ConvNd):
     """ A reference version of nn.quantized.Conv2d
         we will not pack the parameters in this module, since weight packing is an
         optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
         this is useful when user want to use this module in other backends like Glow.
     """
-    __annotations__ = {"_bias": Optional[torch.Tensor]}
+    __annotations__ = {"bias": Optional[torch.Tensor]}
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         super()._save_to_state_dict(destination, prefix, keep_vars)
-        destination[prefix + '_qweight'] = self._qweight
-        destination[prefix + '_bias'] = self._bias
+        _save_weight_qparams(
+            destination, prefix, self.weight_qscheme, self.weight_dtype,
+            self.weight_scale, self.weight_zero_point, self.weight_axis)
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
-        self._qweight = state_dict[prefix + '_qweight']
-        self._bias = state_dict[prefix + '_bias']
-        state_dict.pop(prefix + '_qweight')
-        state_dict.pop(prefix + '_bias')
+        for key in _get_weight_qparam_keys(state_dict, prefix):
+            setattr(self, key, state_dict[prefix + key])
+            state_dict.pop(prefix + key)
 
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, False,
             missing_keys, unexpected_keys, error_msgs)
 
-    def _weight_bias(self):
-        return self._qweight, self._bias
-
-    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
-        self._qweight = w
-        self._bias = b
-
-class Conv1d(_ConvNd, nnq.Conv1d):
+    def _init_weight_qparams(self, weight_qparams, device):
+        if weight_qparams is None:
+            weight_qparams = {
+                "qscheme": torch.per_tensor_affine,
+                "dtype": torch.quint8,
+                "scale": 1.0,
+                "zero_point": 0
+            }
+        self.weight_qscheme = weight_qparams["qscheme"]
+        self.weight_dtype = weight_qparams["dtype"]
+        assert self.weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
+            Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized linear module")
+        if self.weight_qscheme is not None:
+            self.register_buffer(
+                "weight_scale",
+                torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device))
+            self.register_buffer(
+                "weight_zero_point",
+                torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device))
+            if self.weight_qscheme == torch.per_channel_affine:
+                self.register_buffer(
+                    "weight_axis",
+                    torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device))
+            else:
+                # added for TorchScriptability, not used
+                self.register_buffer(
+                    "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
+
+    def get_weight(self):
+        """
+        Fake quantize (quantize and dequantize) the weight with
+        the quantization parameters for weight, this is used to
+        simulate the numerics for the quantized weight in a quantized
+        model
+        """
+        # supress mypy warning
+        assert isinstance(self.weight, torch.Tensor)
+        assert isinstance(self.weight_scale, torch.Tensor)
+        assert isinstance(self.weight_zero_point, torch.Tensor)
+        assert isinstance(self.weight_axis, torch.Tensor)
+        return _quantize_and_dequantize_weight(
+            self.weight, self.weight_qscheme,
+            self.weight_dtype, self.weight_scale, self.weight_zero_point, self.weight_axis)
+
+    @staticmethod
+    def from_float(cls, float_conv, weight_qparams):
+        qref_conv = cls(
+            float_conv.in_channels,
+            float_conv.out_channels,
+            float_conv.kernel_size,  # type: ignore[arg-type]
+            float_conv.stride,  # type: ignore[arg-type]
+            float_conv.padding,  # type: ignore[arg-type]
+            float_conv.dilation,  # type: ignore[arg-type]
+            float_conv.groups,
+            float_conv.bias is not None,  # type: ignore[arg-type]
+            float_conv.padding_mode,
+            device=float_conv.weight.device,
+            dtype=float_conv.weight.dtype,
+            weight_qparams=weight_qparams)
+        qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
+        if float_conv.bias is not None:
+            qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
+        return qref_conv
+
+class Conv1d(_ConvNd, nn.Conv1d):
     def __init__(self,
                  in_channels: int,
                  out_channels: int,
@@ -46,91 +105,107 @@ class Conv1d(_ConvNd, nnq.Conv1d):
                  dilation: _size_1_t = 1,
                  groups: int = 1,
                  bias: bool = True,
-                 padding_mode: str = 'zeros'):
-        nnq.Conv1d.__init__(
+                 padding_mode: str = "zeros",
+                 device=None,
+                 dtype=None,
+                 weight_qparams: Optional[Dict[str, Any]] = None):
+        nn.Conv1d.__init__(
             self, in_channels, out_channels, kernel_size, stride, padding, dilation,
-            groups, bias, padding_mode)
-        # self.stride, self.padding, self.dilation are 2d tuple since
-        # current quantized conv1d is using Conv2dPackedParams
-        # TODO: we should fix this if we implemenet Conv1dPackedParams
-        self._conv1d_stride = _single(self.stride[0])
-        self._conv1d_padding = _single(self.padding[0])
-        self._conv1d_dilation = _single(self.dilation[0])
+            groups, bias, padding_mode, device, dtype)
+        self._init_weight_qparams(weight_qparams, device)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_dequant = x.dequantize()
-        weight_dequant = self._qweight.dequantize()
-        float_result = F.conv1d(
-            x_dequant, weight_dequant, self._bias, self._conv1d_stride,
-            self._conv1d_padding, self._conv1d_dilation, self.groups)
-        # NEEDFIX: we don't have dtype in the Linear module APIs right now!
-        result = torch.quantize_per_tensor(
-            float_result, self.scale, self.zero_point, torch.quint8)
+        """
+        we have:
+        w(float) -- quant - dequant \
+        x(float) ------------- F.conv1d ---
+
+        In the full model, we will see
+        w(float) -- quant - *dequant \
+        x -- quant --- *dequant --  *F.conv1d --- *quant - dequant
+        and the backend should be able to fuse the ops with `*` into a quantized conv1d
+        """
+        weight_dequant = self.get_weight()
+        result = F.conv1d(
+            x, weight_dequant, self.bias, self.stride,
+            self.padding, self.dilation, self.groups)
         return result
 
     def _get_name(self):
-        return 'QuantizedConv1d(Reference)'
-
-    @torch.jit.export
-    def __setstate__(self, state):
-        self.in_channels = state[0]
-        self.out_channels = state[1]
-        self.kernel_size = state[2]
-        self.stride = state[3]
-        self.padding = state[4]
-        self.dilation = state[5]
-        self.transposed = state[6]
-        self.output_padding = state[7]
-        self.groups = state[8]
-        self.padding_mode = state[9]
-        self.set_weight_bias(state[10], state[11])
-        self.scale = state[12]
-        self.zero_point = state[13]
-        self.training = state[14]
-        self._conv1d_stride = (self.stride[0],)
-        self._conv1d_padding = (self.padding[0],)
-        self._conv1d_dilation = (self.dilation[0],)
-
-class Conv2d(_ConvNd, nnq.Conv2d):
+        return "QuantizedConv1d(Reference)"
+
+    @classmethod
+    def from_float(cls, float_conv, weight_qparams):
+        return _ConvNd.from_float(cls, float_conv, weight_qparams)
+
+class Conv2d(_ConvNd, nn.Conv2d):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True,
-                 padding_mode='zeros'):
-        nnq.Conv2d.__init__(
+                 padding_mode='zeros',
+                 device=None,
+                 dtype=None,
+                 weight_qparams: Optional[Dict[str, Any]] = None):
+        nn.Conv2d.__init__(
             self, in_channels, out_channels, kernel_size, stride, padding, dilation,
-            groups, bias, padding_mode)
+            groups, bias, padding_mode, device, dtype)
+        self._init_weight_qparams(weight_qparams, device)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_dequant = x.dequantize()
-        weight_dequant = self._qweight.dequantize()
-        float_result = F.conv2d(
-            x_dequant, weight_dequant, self._bias, self.stride,
+        """
+        we have:
+        w(float) -- quant - dequant \
+        x(float) ------------- F.conv2d ---
+
+        In the full model, we will see
+        w(float) -- quant - *dequant \
+        x -- quant --- *dequant --  *F.conv2d --- *quant - dequant
+        and the backend should be able to fuse the ops with `*` into a quantized conv2d
+        """
+        weight_dequant = self.get_weight()
+        result = F.conv2d(
+            x, weight_dequant, self.bias, self.stride,
             self.padding, self.dilation, self.groups)
-        # NEEDFIX: we don't have dtype in the Linear module APIs right now!
-        result = torch.quantize_per_tensor(
-            float_result, self.scale, self.zero_point, torch.quint8)
         return result
 
     def _get_name(self):
-        return 'QuantizedConv2d(Reference)'
+        return "QuantizedConv2d(Reference)"
 
-class Conv3d(_ConvNd, nnq.Conv3d):
+    @classmethod
+    def from_float(cls, float_conv, weight_qparams):
+        return _ConvNd.from_float(cls, float_conv, weight_qparams)
+
+class Conv3d(_ConvNd, nn.Conv3d):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True,
-                 padding_mode='zeros'):
-        nnq.Conv3d.__init__(
+                 padding_mode="zeros",
+                 device=None,
+                 dtype=None,
+                 weight_qparams: Optional[Dict[str, Any]] = None):
+        nn.Conv3d.__init__(
             self, in_channels, out_channels, kernel_size, stride, padding, dilation,
             groups, bias, padding_mode)
+        self._init_weight_qparams(weight_qparams, device)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_dequant = x.dequantize()
-        weight_dequant = self._qweight.dequantize()
-        float_result = F.conv3d(
-            x_dequant, weight_dequant, self._bias, self.stride,
+        """
+        we have:
+        w(float) -- quant - dequant \
+        x(float) ------------- F.conv3d ---
+
+        In the full model, we will see
+        w(float) -- quant - *dequant \
+        x -- quant --- *dequant --  *F.conv3d --- *quant - dequant
+        and the backend should be able to fuse the ops with `*` into a quantized conv3d
+        """
+        weight_dequant = self.get_weight()
+        result = F.conv3d(
+            x, weight_dequant, self.bias, self.stride,
             self.padding, self.dilation, self.groups)
-        # NEEDFIX: we don't have dtype in the Linear module APIs right now!
-        result = torch.quantize_per_tensor(
-            float_result, self.scale, self.zero_point, torch.quint8)
         return result
 
     def _get_name(self):
-        return 'QuantizedConv3d(Reference)'
+        return "QuantizedConv3d(Reference)"
+
+    @classmethod
+    def from_float(cls, float_conv, weight_qparams):
+        return _ConvNd.from_float(cls, float_conv, weight_qparams)
index 779dfcf..418cae1 100644 (file)
@@ -638,19 +638,22 @@ class ConvReluQuantizeHandler(QuantizeHandler):
                 # and qparam is a dictionary of
                 # {"qscheme": ..., "scale": ..., "zero_point": ...} for per tensor quantization or
                 # {"qscheme": ..., "scale": ..., "zero_point": ..., "axis": ...} for per channel quantization
+                float_conv = self.conv
+                fused_conv = None
                 if isinstance(
-                        self.conv,
+                        float_conv,
                         QAT_CONV_MODULE_CLASSES):
                     # case 1. converting qat conv module to
                     # a float conv module, we need to attch
                     # weight fake_quant to the conv module,
                     # weight fake_quant is assumed to be run during
                     # QAT so we don't need to run it again here
-                    float_conv = self.conv.to_float()
+                    float_conv = self.conv.to_float()  # type: ignore[operator]
                     # change qat conv to conv
                     parent_name, name = _parent_name(self.conv_node.target)
                     setattr(modules[parent_name], name, float_conv)
                     if isinstance(float_conv, torch.nn.intrinsic._FusedModule):
+                        fused_conv = float_conv
                         float_conv = float_conv[0]
                     weight_post_process = self.conv.weight_fake_quant
                 else:
@@ -658,15 +661,28 @@ class ConvReluQuantizeHandler(QuantizeHandler):
                     # to float conv module, we need to attach
                     # weight observer to the conv module and run it
                     # with conv weight
-                    float_conv = self.conv
-                    if isinstance(self.conv, torch.nn.intrinsic._FusedModule):
-                        float_conv = self.conv[0]
+                    if isinstance(float_conv, torch.nn.intrinsic._FusedModule):
+                        fused_conv = float_conv
+                        float_conv = float_conv[0]  # type: ignore[index]
                     assert qconfig is not None
                     weight_post_process = qconfig.weight()
                     # run weight observer
-                    weight_post_process(float_conv.weight)
+                    weight_post_process(float_conv.weight)  # type: ignore[operator]
                 weight_qparams = get_qparam_dict(weight_post_process)
-                _to_reference(float_conv, weight_qparams)
+                # hardcoded for now, TODO: expose the api to user,
+                # we can have a map from module to reference module
+                # and allow user to register new ones
+                qconv_cls = get_static_quant_module_class(
+                    type(float_conv), is_reference=is_reference)
+                ref_conv = qconv_cls.from_float(float_conv, weight_qparams)  # type: ignore[attr-defined]
+                # if the parent is a fused conv (Sequential), we can replace the first
+                # item to ref conv, otherwise we can update
+                # the conv instance in the module tree
+                if fused_conv is not None:
+                    fused_conv[0] = ref_conv
+                else:
+                    parent_name, name = _parent_name(self.conv_node.target)
+                    setattr(modules[parent_name], name, ref_conv)
                 op_out = quantized_graph.create_node(
                     'call_module',
                     self.conv_node.target,
index 03b1778..6851ba7 100644 (file)
@@ -7,7 +7,6 @@ import torch.nn.functional as F
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
 import torch.nn.intrinsic.quantized.dynamic as nniqd
-import torch.nn.intrinsic.quantized._reference as nniqr
 import torch.nn.intrinsic.qat as nniqat
 import torch.nn.quantized as nnq
 import torch.nn.quantized._reference as nnqr
@@ -29,20 +28,6 @@ DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nn.Conv1d: nnqr.Conv1d,
     nn.Conv2d: nnqr.Conv2d,
     nn.Conv3d: nnqr.Conv3d,
-    nni.ConvReLU1d: nniqr.ConvReLU1d,
-    nni.ConvReLU2d: nniqr.ConvReLU2d,
-    nni.ConvReLU3d: nniqr.ConvReLU3d,
-    # QAT Modules
-    nnqat.Conv2d: nnqr.Conv2d,
-    nnqat.Conv3d: nnqr.Conv3d,
-    nniqat.ConvBn1d: nnqr.Conv1d,
-    nniqat.ConvBn2d: nnqr.Conv2d,
-    nniqat.ConvBn3d: nnqr.Conv3d,
-    nniqat.ConvBnReLU1d: nniqr.ConvReLU1d,
-    nniqat.ConvBnReLU2d: nniqr.ConvReLU2d,
-    nniqat.ConvBnReLU3d: nniqr.ConvReLU3d,
-    nniqat.ConvReLU2d: nniqr.ConvReLU2d,
-    nniqat.ConvReLU3d: nniqr.ConvReLU3d,
 }
 
 # Default map for swapping float module to quantized ones