[quant][graphmode][fx] Add reference quantized linear module (#63627)
author Jerry Zhang <jerryzh@fb.com>
Sat, 28 Aug 2021 03:58:20 +0000 (20:58 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sat, 28 Aug 2021 05:53:24 +0000 (22:53 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63627

Added a reference quantized linear module for the custom backend flow. The reference quantized module
has the following structure:
```
        w(float) -- quant - dequant \
        x(float) ------------- F.linear ---
```
In the full model, we will see
```
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.linear --- *quant - dequant
```
and the backend should be able to fuse the ops marked with `*` into a quantized linear.
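
As an illustration, here is a minimal sketch of what the reference forward computes, assuming per-tensor affine weight quantization with `torch.qint8`; the helper name `reference_linear_forward` is hypothetical and not part of this PR:
```
import torch
import torch.nn.functional as F

def reference_linear_forward(x, weight, bias, weight_scale, weight_zero_point):
    # Quantize then immediately dequantize the float weight so the numerics
    # match what a quantized backend would produce for the weight.
    q_weight = torch.quantize_per_tensor(
        weight, weight_scale, weight_zero_point, torch.qint8)
    weight_dequant = q_weight.dequantize()
    # The activation stays a float Tensor here; in the full model the
    # surrounding graph inserts the quant/dequant nodes around this call.
    return F.linear(x, weight_dequant, bias)

# Example usage (shapes only for illustration):
# x = torch.randn(2, 10); w = torch.randn(5, 10); b = torch.randn(5)
# y = reference_linear_forward(x, w, b, weight_scale=0.1, weight_zero_point=0)
```
In the module added below (`torch/nn/quantized/_reference/modules/linear.py`), this logic lives in `get_weight()` and `forward()`, with per-channel weights handled via `torch.quantize_per_channel`.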

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_conv_linear_reference

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D30504750

fbshipit-source-id: 5729921745c2b6a0fb344efc3689f3b170e89500

test/quantization/core/test_quantized_module.py
test/quantization/fx/test_quantize_fx.py
torch/nn/intrinsic/quantized/_reference/modules/__init__.py
torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py [deleted file]
torch/nn/quantized/_reference/modules/linear.py
torch/nn/quantized/_reference/modules/utils.py [new file with mode: 0644]
torch/quantization/fx/quantization_patterns.py
torch/quantization/quantization_mappings.py

index 10d5831..bc8a6b3 100644 (file)
@@ -6,7 +6,6 @@ import torch.nn.intrinsic.quantized._reference as nniqr
 import torch.nn.quantized as nnq
 import torch.nn.quantized._reference as nnqr
 import torch.nn.quantized.dynamic as nnqd
-import torch.nn.functional as F
 import torch.quantization
 
 from torch.quantization import (
@@ -70,24 +69,21 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             [4, 8],
             [True, False],
             [True, False],
-            [True, False],
             [True, False])
         for (batch_size, in_features, out_features, use_bias,
-             use_fused, per_channel, is_reference) in options:
+             use_fused, per_channel) in options:
             self._test_linear_api_impl(
                 batch_size, in_features, out_features, use_bias, use_fused,
-                per_channel, is_reference)
+                per_channel)
 
-    def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel, is_reference):
+    def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel):
         if torch.backends.quantized.engine == 'qnnpack':
             per_channel = False
 
-        # (use_fused, is_reference) -> quantized class
+        # use_fused -> quantized class
         class_map = {
-            (True, True) : nniqr.LinearReLU,
-            (True, False) : nniq.LinearReLU,
-            (False, True) : nnqr.Linear,
-            (False, False) : nnq.Linear,
+            True: nniq.LinearReLU,
+            False: nnq.Linear,
         }
 
         W = torch.rand(out_features, in_features).float()
@@ -107,7 +103,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         B = torch.rand(out_features).float() if use_bias else None
         scale = 0.5
         zero_point = 3
-        qlinear = class_map[(use_fused, is_reference)](in_features, out_features)
+        qlinear = class_map[use_fused](in_features, out_features)
 
         qlinear_copy = qlinear  # deepcopy does not work right now
         # qlinear_copy = copy.deepcopy(qlinear)
@@ -127,21 +123,11 @@ class TestStaticQuantizedModule(QuantizationTestCase):
 
         # Check if the module implementation matches calling the
         # ops directly
-        if is_reference:
-            weight = qlinear._qweight
-            bias = qlinear._bias
-            weight_dequant = weight.dequantize()
-            X_q_dq = X_q.dequantize()
-            Z_ref = F.linear(X_q_dq, weight_dequant, bias)
-            if use_fused:
-                Z_ref = F.relu(Z_ref, inplace=True)
-            Z_ref = torch.quantize_per_tensor(Z_ref, scale, zero_point, torch.quint8)
+        W_pack = qlinear._packed_params._packed_params
+        if use_fused:
+            Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)
         else:
-            W_pack = qlinear._packed_params._packed_params
-            if use_fused:
-                Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)
-            else:
-                Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)
+            Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)
 
         self.assertEqual(Z_ref, Z_q)
         self.assertTrue(
@@ -163,16 +149,12 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             else:
                 self.assertEqual(model_dict[key], loaded_dict[key])
 
-        loaded_qlinear = class_map[(use_fused, is_reference)](
+        loaded_qlinear = class_map[use_fused](
             in_features, out_features)
         loaded_qlinear.load_state_dict(loaded_dict)
-        if is_reference:
-            self.assertEqual(qlinear._qweight, loaded_qlinear._qweight)
-            self.assertEqual(qlinear._bias, loaded_qlinear._bias)
-        else:
-            linear_unpack = torch.ops.quantized.linear_unpack
-            self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
-                             linear_unpack(loaded_qlinear._packed_params._packed_params))
+        linear_unpack = torch.ops.quantized.linear_unpack
+        self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
+                         linear_unpack(loaded_qlinear._packed_params._packed_params))
         self.assertEqual(qlinear.scale, loaded_qlinear.scale)
         self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
         # make sure loaded_qlinear has the same dir as qlinear since
@@ -180,8 +162,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         self.checkScriptable(loaded_qlinear, [[X_q]], check_save_load=True)
         self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
         self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
-        if not is_reference:
-            self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
+        self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
         Z_q2 = loaded_qlinear(X_q)
         self.assertEqual(Z_q, Z_q2)
 
index 762919e..7ae29e0 100644 (file)
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 import torch.nn as nn
 import torch.nn.quantized as nnq
+import torch.nn.quantized._reference as nnqr
 import torch.nn.quantized.dynamic as nnqd
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
@@ -571,7 +572,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 LinearModule,
                 (),
                 (linear_module_input,),
-                ns.call_module(nn.Linear) if is_reference else ns.call_module(nnqd.Linear),
+                ns.call_module(nnqr.Linear) if is_reference else ns.call_module(nnqd.Linear),
                 None,
             ),
             (
@@ -579,7 +580,7 @@ class TestQuantizeFx(QuantizationTestCase):
                 LinearModule,
                 (),
                 (linear_module_input,),
-                ns.call_module(nn.Linear if is_reference else nnq.Linear),
+                ns.call_module(nnqr.Linear if is_reference else nnq.Linear),
                 None,
             ),
         ]
@@ -608,6 +609,13 @@ class TestQuantizeFx(QuantizationTestCase):
         """ Test quantizing functional conv and linear with reference option
         """
         tests = self._get_conv_linear_test_cases(is_reference=True)
+
+        def _get_keys(prefix, is_dynamic):
+            all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
+            if not is_dynamic:
+                all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
+            return all_keys
+
         for (is_dynamic, ModuleClass, module_constructor_inputs,
              inputs, quantized_node, weight_prepack_node) in tests:
             quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
@@ -623,13 +631,19 @@ class TestQuantizeFx(QuantizationTestCase):
             qr = result_dict["quantized_reference"]
 
             def checkWeightQParams(model):
-                for module_name in ("linear", "conv"):
+                for module_name in ("conv",):
                     if hasattr(model, module_name):
                         self.assertTrue(hasattr(qr.get_submodule(module_name), "_weight_qparams"))
                         self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())
+                for module_name in ("linear",):
+                    if hasattr(model, module_name):
+                        self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme"))
+                        self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale"))
+                        self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point"))
+                        self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())
 
-            def checkSerDeser(model):
-                for module_name in ("linear", "conv"):
+            def checkSerDeser(model, is_dynamic):
+                for module_name in ("conv",):
                     if hasattr(model, module_name):
                         # make sure serialization works
                         state_dict = copy.deepcopy(model.state_dict())
@@ -641,6 +655,20 @@ class TestQuantizeFx(QuantizationTestCase):
                         module._weight_qparams["scale"] = None
                         model.load_state_dict(state_dict)
                         self.assertTrue(torch.equal(prev_scale, module._weight_qparams["scale"]))
+                for module_name in ("linear",):
+                    if hasattr(model, module_name):
+                        # make sure serialization works
+                        state_dict = copy.deepcopy(model.state_dict())
+                        all_keys = _get_keys(module_name, is_dynamic)
+                        for key in all_keys:
+                            self.assertTrue(key in state_dict)
+                        # check load_state_dict restores states
+                        module = getattr(model, module_name)
+                        prev_scale = module.weight_scale
+                        module.weight_scale = None
+                        model.load_state_dict(state_dict)
+                        module = getattr(model, module_name)
+                        self.assertTrue(torch.equal(prev_scale, module.weight_scale))
 
 
             checkWeightQParams(qr)
@@ -648,7 +676,7 @@ class TestQuantizeFx(QuantizationTestCase):
             # make sure the qparams are preserved after copy
             checkWeightQParams(qr)
 
-            checkSerDeser(qr)
+            checkSerDeser(qr, is_dynamic)
 
     @skipIfNoFBGEMM
     def test_dynamic_quant_weight_observer(self):
@@ -2941,6 +2969,38 @@ class TestQuantizeFx(QuantizationTestCase):
             ]
             self.checkGraphModuleNodes(m, expected_node_list=node_list)
 
+    def test_ref_linear_module(self):
+        """ Make sure the numerics for models with the ref linear module
+        match models with the fbgemm/qnnpack module
+        """
+        class M1(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 5)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        class M2(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 5)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                return self.relu(self.linear(x))
+
+        for M in [M1, M2]:
+            m = M().eval()
+            m = prepare_fx(m, {"": default_qconfig})
+            m_copy = copy.deepcopy(m)
+            m = convert_fx(m, is_reference=False)
+            m_ref = convert_fx(m_copy, is_reference=True)
+            data = torch.randn(5, 10)
+            result = m(data)
+            result_ref = m_ref(data)
+            self.assertTrue(torch.equal(result, result_ref))
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
index bf8ff3a..33b18d8 100644 (file)
@@ -1,9 +1,7 @@
 import torch
-from .linear_relu import LinearReLU
 from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
 
 __all__ = [
-    'LinearReLU',
     'ConvReLU1d',
     'ConvReLU2d',
     'ConvReLU3d',
diff --git a/torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py b/torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py
deleted file mode 100644 (file)
index 39c5953..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-import torch
-import torch.nn.intrinsic as nni
-import torch.nn.quantized._reference as nnqr
-import torch.nn.functional as F
-
-class LinearReLU(nnqr.Linear):
-    _FLOAT_MODULE = nni.LinearReLU
-
-    def __init__(
-            self,
-            in_features,
-            out_features,
-            bias=True,
-            dtype=torch.qint8):
-        super().__init__(in_features, out_features, bias, dtype)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_dequant = x.dequantize()
-        weight_dequant = self._qweight.dequantize()
-        float_result = F.linear(x_dequant, weight_dequant, self._bias)
-        float_result = F.relu(float_result, inplace=True)
-        # NEEDFIX: we don't have dtype in the Linear module APIs right now!
-        result = torch.quantize_per_tensor(
-            float_result, self.scale, self.zero_point, torch.quint8)
-        return result
-
-    def _get_name(self):
-        return "QuantizedLinearReLU(Reference)"
index 276dc01..1df5499 100644 (file)
 import torch
-import torch.nn.quantized as nnq
+import torch.nn as nn
 import torch.nn.functional as F
-from typing import Optional
+from typing import Optional, Dict, Any
+from .utils import _quantize_and_dequantize_weight
+from .utils import _save_weight_qparams
+from .utils import _get_weight_qparam_keys
 
-class Linear(nnq.Linear):
-    """ A backend independent version of nn.quantized.Linear
-        we will not pack the parameters in this module, since weight packing is an
-        optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
-        this is useful when user want to use this module in other backends like Glow.
+class Linear(nn.Linear):
+    """ A reference quantized linear module that fits into the FX
+    Graph Mode Quantization workflow.
+    The activation will be a floating point Tensor; we also store the
+    floating point weight in the module, but in forward we quantize
+    and dequantize the weight before running the floating point
+    functional linear operator.
     """
-    def __init__(self, in_features, out_features, bias_=True,
-                 dtype=torch.qint8):
-        super().__init__(in_features, out_features, bias_, dtype)
-        self._qweight, self._bias = self._packed_params._weight_bias()
-        del self._packed_params
+    def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias_: bool = True,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+            weight_qparams: Optional[Dict[str, Any]] = None):
+        super().__init__(in_features, out_features, bias_, device, dtype)
+        if weight_qparams is None:
+            weight_qparams = {
+                "qscheme": torch.per_tensor_affine,
+                "dtype": torch.quint8,
+                "scale": 1.0,
+                "zero_point": 0
+            }
+        self.weight_qscheme = weight_qparams["qscheme"]
+        self.weight_dtype = weight_qparams["dtype"]
+        assert self.weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
+            Exception(f"qscheme: {self.weight_qscheme} is not supported in reference quantized linear module")
+        if self.weight_qscheme is not None:
+            self.register_buffer(
+                "weight_scale",
+                torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device))
+            self.register_buffer(
+                "weight_zero_point",
+                torch.tensor(
+                    weight_qparams["zero_point"],
+                    dtype=torch.int, device=device))
+            if self.weight_qscheme == torch.per_channel_affine:
+                self.register_buffer(
+                    "weight_axis",
+                    torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device))
+            else:
+                # added for TorchScriptability, not used
+                self.register_buffer(
+                    "weight_axis",
+                    torch.tensor(0, dtype=torch.int, device=device))
 
     def _get_name(self):
         return "QuantizedLinear(Reference)"
 
+    def get_weight(self):
+        """
+        Fake quantize (quantize and dequantize) the weight with
+        the quantization parameters for the weight. This is used
+        to simulate the numerics of the quantized weight in a
+        quantized model.
+        """
+        # suppress mypy warning
+        assert isinstance(self.weight, torch.Tensor)
+        assert isinstance(self.weight_scale, torch.Tensor)
+        assert isinstance(self.weight_zero_point, torch.Tensor)
+        assert isinstance(self.weight_axis, torch.Tensor)
+        return _quantize_and_dequantize_weight(
+            self.weight, self.weight_qscheme, self.weight_dtype, self.weight_scale,
+            self.weight_zero_point, self.weight_axis)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_dequant = x.dequantize()
-        weight_dequant = self._qweight.dequantize()
-        float_result = F.linear(x_dequant, weight_dequant, self._bias)
-        # NEEDFIX: we don't have dtype in the Linear module APIs right now!
-        result = torch.quantize_per_tensor(
-            float_result, self.scale, self.zero_point, torch.quint8)
+        """
+        we have:
+        w(float) -- quant - dequant \
+        x(float) ------------- F.linear ---
+
+        In the full model, we will see
+        w(float) -- quant - *dequant \
+        x -- quant --- *dequant --  *F.linear --- *quant - dequant
+        and the backend should be able to fuse the ops with `*` into a quantized linear
+        """
+        weight_dequant = self.get_weight()
+        result = F.linear(x, weight_dequant, self.bias)
         return result
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         super()._save_to_state_dict(destination, prefix, keep_vars)
-        destination[prefix + '_qweight'] = self._qweight
-        destination[prefix + '_bias'] = self._bias
+        _save_weight_qparams(
+            destination, prefix, self.weight_qscheme, self.weight_dtype,
+            self.weight_scale, self.weight_zero_point, self.weight_axis)
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
-        self._qweight = state_dict[prefix + '_qweight']
-        self._bias = state_dict[prefix + '_bias']
-        state_dict.pop(prefix + '_qweight')
-        state_dict.pop(prefix + '_bias')
+        for key in _get_weight_qparam_keys(state_dict, prefix):
+            setattr(self, key, state_dict[prefix + key])
+            state_dict.pop(prefix + key)
 
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, False,
             missing_keys, unexpected_keys, error_msgs)
 
-    def _weight_bias(self):
-        return self._qweight, self._bias
-
-    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
-        self._qweight = w
-        self._bias = b
+    @classmethod
+    def from_float(cls, float_linear, weight_qparams):
+        qref_linear = Linear(
+            float_linear.in_features, float_linear.out_features,
+            float_linear.bias is not None, device=float_linear.weight.device,
+            dtype=float_linear.weight.dtype, weight_qparams=weight_qparams)
+        qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
+        if float_linear.bias is not None:
+            qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
+        return qref_linear
diff --git a/torch/nn/quantized/_reference/modules/utils.py b/torch/nn/quantized/_reference/modules/utils.py
new file mode 100644 (file)
index 0000000..7c36650
--- /dev/null
@@ -0,0 +1,45 @@
+import torch
+from typing import Dict, Any
+
+def _quantize_and_dequantize_weight(
+        weight: torch.Tensor,
+        weight_qscheme: torch.qscheme,
+        weight_dtype: torch.dtype,
+        weight_scale: torch.Tensor,
+        weight_zero_point: torch.Tensor,
+        weight_axis: torch.Tensor):
+    """ Quantize and then dequantize the weight based on
+    the quantization parameters
+    """
+    if weight_qscheme == torch.per_tensor_affine:
+        weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
+        weight_dequant = weight.dequantize()
+    elif weight_qscheme == torch.per_channel_affine:
+        weight = torch.quantize_per_channel(
+            weight, weight_scale,
+            weight_zero_point, weight_axis.item(), weight_dtype)  # type: ignore[arg-type]
+        weight_dequant = weight.dequantize()
+    else:
+        weight_dequant = weight
+    return weight_dequant
+
+def _save_weight_qparams(destination, prefix, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis):
+    destination[prefix + "weight_qscheme"] = weight_qscheme
+    destination[prefix + "weight_dtype"] = weight_dtype
+    if weight_qscheme is not None:
+        destination[prefix + "weight_scale"] = weight_scale
+        destination[prefix + "weight_zero_point"] = weight_zero_point
+        if weight_qscheme == torch.per_channel_affine:
+            destination[prefix + "weight_axis"] = weight_axis
+
+def _get_weight_qparam_keys(
+        state_dict: Dict[str, Any],
+        prefix: str):
+    keys = ["weight_qscheme", "weight_dtype"]
+    weight_qscheme = state_dict[prefix + "weight_qscheme"]
+    if weight_qscheme is not None:
+        keys.append("weight_scale")
+        keys.append("weight_zero_point")
+        if weight_qscheme == torch.per_channel_affine:
+            keys.append("weight_axis")
+    return keys
index 6362961..e8b8736 100644 (file)
@@ -869,6 +869,7 @@ class LinearReLUQuantizeHandler(QuantizeHandler):
                 # Get the float linear and attach qscheme and qparams
                 # to the module
                 float_linear = self.linear
+                fused_linear = None
                 if isinstance(float_linear, (torch.nn.qat.Linear, torch.nn.intrinsic.qat.LinearReLU)):
                     float_linear = float_linear.to_float()
                     # change qat linear to linear
@@ -876,10 +877,12 @@ class LinearReLUQuantizeHandler(QuantizeHandler):
                     setattr(modules[parent_name], name, float_linear)
                     # Attach weight fake quant to the linear module
                     if isinstance(float_linear, torch.nn.intrinsic.LinearReLU):
+                        fused_linear = float_linear
                         float_linear = float_linear[0]
                     weight_post_process = self.linear.weight_fake_quant
                 else:
                     if isinstance(float_linear, torch.nn.intrinsic.LinearReLU):
+                        fused_linear = float_linear
                         float_linear = self.linear[0]  # type: ignore[index]
                     # Attach the weight observer to the module
                     weight_post_process = qconfig.weight()  # type: ignore[union-attr]
@@ -887,7 +890,21 @@ class LinearReLUQuantizeHandler(QuantizeHandler):
                     weight_post_process(float_linear.weight)  # type: ignore[operator]
 
                 weight_qparams = get_qparam_dict(weight_post_process)
-                _to_reference(float_linear, weight_qparams)
+                # TODO: include the configuration in backend_config_dict
+                # we can have a map from module to reference module
+                # and allow users to register new ones
+                qlinear_cls = get_static_quant_module_class(
+                    type(float_linear), is_reference=is_reference)
+                ref_linear = qlinear_cls.from_float(float_linear, weight_qparams)
+
+                # if the parent is a fused linear (Sequential), we can replace the
+                # first item with the ref linear, otherwise we update
+                # the linear instance in the module tree
+                if fused_linear is not None:
+                    fused_linear[0] = ref_linear
+                else:
+                    parent_name, name = _parent_name(self.linear_node.target)
+                    setattr(modules[parent_name], name, ref_linear)
                 op_out = quantized_graph.create_node(
                     'call_module',
                     self.linear_node.target,
index 775d40b..03b1778 100644 (file)
@@ -25,16 +25,14 @@ from .utils import get_combined_dict
 
 # Default map for swapping float module to reference quantized modules
 DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    nn.Linear: nnqr.Linear,
     nn.Conv1d: nnqr.Conv1d,
     nn.Conv2d: nnqr.Conv2d,
     nn.Conv3d: nnqr.Conv3d,
-    nn.Linear: nnqr.Linear,
     nni.ConvReLU1d: nniqr.ConvReLU1d,
     nni.ConvReLU2d: nniqr.ConvReLU2d,
     nni.ConvReLU3d: nniqr.ConvReLU3d,
-    nni.LinearReLU: nniqr.LinearReLU,
     # QAT Modules
-    nnqat.Linear: nnqr.Linear,
     nnqat.Conv2d: nnqr.Conv2d,
     nnqat.Conv3d: nnqr.Conv3d,
     nniqat.ConvBn1d: nnqr.Conv1d,
@@ -45,7 +43,6 @@ DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nniqat.ConvBnReLU3d: nniqr.ConvReLU3d,
     nniqat.ConvReLU2d: nniqr.ConvReLU2d,
     nniqat.ConvReLU3d: nniqr.ConvReLU3d,
-    nniqat.LinearReLU: nniqr.LinearReLU,
 }
 
 # Default map for swapping float module to quantized ones