From 0d0605eaa9243c938faddd3fb60f922c4a48c953 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 27 Aug 2021 20:58:20 -0700 Subject: [PATCH] [quant][graphmode][fx] Add reference quantized linear module (#63627) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63627 Added reference quantized linear module for the custom backend flow, the reference quantized module will have the following code: ``` w(float) -- quant - dequant \ x(float) ------------- F.linear --- ``` In the full model, we will see ``` w(float) -- quant - *dequant \ x -- quant --- *dequant -- *F.linear --- *quant - dequant ``` and the backend should be able to fuse the ops with `*` into a quantized linear Test Plan: python test/test_quantization.py TestQuantizeFx.test_conv_linear_reference Imported from OSS Reviewed By: vkuzo Differential Revision: D30504750 fbshipit-source-id: 5729921745c2b6a0fb344efc3689f3b170e89500 --- test/quantization/core/test_quantized_module.py | 51 +++------ test/quantization/fx/test_quantize_fx.py | 72 +++++++++++- .../quantized/_reference/modules/__init__.py | 2 - .../quantized/_reference/modules/linear_relu.py | 28 ----- torch/nn/quantized/_reference/modules/linear.py | 124 ++++++++++++++++----- torch/nn/quantized/_reference/modules/utils.py | 45 ++++++++ torch/quantization/fx/quantization_patterns.py | 19 +++- torch/quantization/quantization_mappings.py | 5 +- 8 files changed, 240 insertions(+), 106 deletions(-) delete mode 100644 torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py create mode 100644 torch/nn/quantized/_reference/modules/utils.py diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py index 10d5831..bc8a6b3 100644 --- a/test/quantization/core/test_quantized_module.py +++ b/test/quantization/core/test_quantized_module.py @@ -6,7 +6,6 @@ import torch.nn.intrinsic.quantized._reference as nniqr import torch.nn.quantized as nnq import torch.nn.quantized._reference as nnqr import torch.nn.quantized.dynamic as nnqd -import torch.nn.functional as F import torch.quantization from torch.quantization import ( @@ -70,24 +69,21 @@ class TestStaticQuantizedModule(QuantizationTestCase): [4, 8], [True, False], [True, False], - [True, False], [True, False]) for (batch_size, in_features, out_features, use_bias, - use_fused, per_channel, is_reference) in options: + use_fused, per_channel) in options: self._test_linear_api_impl( batch_size, in_features, out_features, use_bias, use_fused, - per_channel, is_reference) + per_channel) - def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel, is_reference): + def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel): if torch.backends.quantized.engine == 'qnnpack': per_channel = False - # (use_fused, is_reference) -> quantized class + # use_fused -> quantized class class_map = { - (True, True) : nniqr.LinearReLU, - (True, False) : nniq.LinearReLU, - (False, True) : nnqr.Linear, - (False, False) : nnq.Linear, + True: nniq.LinearReLU, + False: nnq.Linear, } W = torch.rand(out_features, in_features).float() @@ -107,7 +103,7 @@ class TestStaticQuantizedModule(QuantizationTestCase): B = torch.rand(out_features).float() if use_bias else None scale = 0.5 zero_point = 3 - qlinear = class_map[(use_fused, is_reference)](in_features, out_features) + qlinear = class_map[use_fused](in_features, out_features) qlinear_copy = qlinear # deepcopy does not work right now # qlinear_copy = 
copy.deepcopy(qlinear) @@ -127,21 +123,11 @@ class TestStaticQuantizedModule(QuantizationTestCase): # Check if the module implementation matches calling the # ops directly - if is_reference: - weight = qlinear._qweight - bias = qlinear._bias - weight_dequant = weight.dequantize() - X_q_dq = X_q.dequantize() - Z_ref = F.linear(X_q_dq, weight_dequant, bias) - if use_fused: - Z_ref = F.relu(Z_ref, inplace=True) - Z_ref = torch.quantize_per_tensor(Z_ref, scale, zero_point, torch.quint8) + W_pack = qlinear._packed_params._packed_params + if use_fused: + Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point) else: - W_pack = qlinear._packed_params._packed_params - if use_fused: - Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point) - else: - Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point) + Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point) self.assertEqual(Z_ref, Z_q) self.assertTrue( @@ -163,16 +149,12 @@ class TestStaticQuantizedModule(QuantizationTestCase): else: self.assertEqual(model_dict[key], loaded_dict[key]) - loaded_qlinear = class_map[(use_fused, is_reference)]( + loaded_qlinear = class_map[use_fused]( in_features, out_features) loaded_qlinear.load_state_dict(loaded_dict) - if is_reference: - self.assertEqual(qlinear._qweight, loaded_qlinear._qweight) - self.assertEqual(qlinear._bias, loaded_qlinear._bias) - else: - linear_unpack = torch.ops.quantized.linear_unpack - self.assertEqual(linear_unpack(qlinear._packed_params._packed_params), - linear_unpack(loaded_qlinear._packed_params._packed_params)) + linear_unpack = torch.ops.quantized.linear_unpack + self.assertEqual(linear_unpack(qlinear._packed_params._packed_params), + linear_unpack(loaded_qlinear._packed_params._packed_params)) self.assertEqual(qlinear.scale, loaded_qlinear.scale) self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point) # make sure loaded_qlinear has the same dir as qlinear since @@ -180,8 +162,7 @@ class TestStaticQuantizedModule(QuantizationTestCase): self.checkScriptable(loaded_qlinear, [[X_q]], check_save_load=True) self.assertTrue(dir(qlinear) == dir(loaded_qlinear)) self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias()) - if not is_reference: - self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params)) + self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params)) Z_q2 = loaded_qlinear(X_q) self.assertEqual(Z_q, Z_q2) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 762919e..7ae29e0 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F import torch.nn as nn import torch.nn.quantized as nnq +import torch.nn.quantized._reference as nnqr import torch.nn.quantized.dynamic as nnqd import torch.nn.intrinsic as nni import torch.nn.intrinsic.quantized as nniq @@ -571,7 +572,7 @@ class TestQuantizeFx(QuantizationTestCase): LinearModule, (), (linear_module_input,), - ns.call_module(nn.Linear) if is_reference else ns.call_module(nnqd.Linear), + ns.call_module(nnqr.Linear) if is_reference else ns.call_module(nnqd.Linear), None, ), ( @@ -579,7 +580,7 @@ class TestQuantizeFx(QuantizationTestCase): LinearModule, (), (linear_module_input,), - ns.call_module(nn.Linear if is_reference else nnq.Linear), + ns.call_module(nnqr.Linear if is_reference else nnq.Linear), 
None, ), ] @@ -608,6 +609,13 @@ class TestQuantizeFx(QuantizationTestCase): """ Test quantizing functional conv and linear with reference option """ tests = self._get_conv_linear_test_cases(is_reference=True) + + def _get_keys(prefix, is_dynamic): + all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]] + if not is_dynamic: + all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]]) + return all_keys + for (is_dynamic, ModuleClass, module_constructor_inputs, inputs, quantized_node, weight_prepack_node) in tests: quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC @@ -623,13 +631,19 @@ class TestQuantizeFx(QuantizationTestCase): qr = result_dict["quantized_reference"] def checkWeightQParams(model): - for module_name in ("linear", "conv"): + for module_name in ("conv",): if hasattr(model, module_name): self.assertTrue(hasattr(qr.get_submodule(module_name), "_weight_qparams")) self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) + for module_name in ("linear",): + if hasattr(model, module_name): + self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme")) + self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale")) + self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point")) + self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) - def checkSerDeser(model): - for module_name in ("linear", "conv"): + def checkSerDeser(model, is_dynamic): + for module_name in ("conv",): if hasattr(model, module_name): # make sure seralization works state_dict = copy.deepcopy(model.state_dict()) @@ -641,6 +655,20 @@ class TestQuantizeFx(QuantizationTestCase): module._weight_qparams["scale"] = None model.load_state_dict(state_dict) self.assertTrue(torch.equal(prev_scale, module._weight_qparams["scale"])) + for module_name in ("linear",): + if hasattr(model, module_name): + # make sure seralization works + state_dict = copy.deepcopy(model.state_dict()) + all_keys = _get_keys(module_name, is_dynamic) + for key in all_keys: + self.assertTrue(key in state_dict) + # check load_state_dict restores states + module = getattr(model, module_name) + prev_scale = module.weight_scale + module.weight_scale = None + model.load_state_dict(state_dict) + module = getattr(model, module_name) + self.assertTrue(torch.equal(prev_scale, module.weight_scale)) checkWeightQParams(qr) @@ -648,7 +676,7 @@ class TestQuantizeFx(QuantizationTestCase): # make sure the qparams are preserved after copy checkWeightQParams(qr) - checkSerDeser(qr) + checkSerDeser(qr, is_dynamic) @skipIfNoFBGEMM def test_dynamic_quant_weight_observer(self): @@ -2941,6 +2969,38 @@ class TestQuantizeFx(QuantizationTestCase): ] self.checkGraphModuleNodes(m, expected_node_list=node_list) + def test_ref_linear_module(self): + """ Make sure the numerics for models with ref linear module + matches models with fbgemm/qnnpack module + """ + class M1(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + class M2(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.linear(x)) + + for M in [M1, M2]: + m = M().eval() + m = prepare_fx(m, {"": default_qconfig}) + m_copy = copy.deepcopy(m) + m = convert_fx(m, is_reference=False) + m_ref = convert_fx(m_copy, is_reference=True) + data = torch.randn(5, 10) + result = 
m(data) + result_ref = m_ref(data) + self.assertTrue(torch.equal(result, result_ref)) + @skipIfNoFBGEMM class TestQuantizeFxOps(QuantizationTestCase): """Unit tests for individual ops diff --git a/torch/nn/intrinsic/quantized/_reference/modules/__init__.py b/torch/nn/intrinsic/quantized/_reference/modules/__init__.py index bf8ff3a..33b18d8 100644 --- a/torch/nn/intrinsic/quantized/_reference/modules/__init__.py +++ b/torch/nn/intrinsic/quantized/_reference/modules/__init__.py @@ -1,9 +1,7 @@ import torch -from .linear_relu import LinearReLU from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d __all__ = [ - 'LinearReLU', 'ConvReLU1d', 'ConvReLU2d', 'ConvReLU3d', diff --git a/torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py b/torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py deleted file mode 100644 index 39c5953..0000000 --- a/torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -import torch.nn.intrinsic as nni -import torch.nn.quantized._reference as nnqr -import torch.nn.functional as F - -class LinearReLU(nnqr.Linear): - _FLOAT_MODULE = nni.LinearReLU - - def __init__( - self, - in_features, - out_features, - bias=True, - dtype=torch.qint8): - super().__init__(in_features, out_features, bias, dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.linear(x_dequant, weight_dequant, self._bias) - float_result = F.relu(float_result, inplace=True) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) - return result - - def _get_name(self): - return "QuantizedLinearReLU(Reference)" diff --git a/torch/nn/quantized/_reference/modules/linear.py b/torch/nn/quantized/_reference/modules/linear.py index 276dc01..1df5499 100644 --- a/torch/nn/quantized/_reference/modules/linear.py +++ b/torch/nn/quantized/_reference/modules/linear.py @@ -1,51 +1,115 @@ import torch -import torch.nn.quantized as nnq +import torch.nn as nn import torch.nn.functional as F -from typing import Optional +from typing import Optional, Dict, Any +from .utils import _quantize_and_dequantize_weight +from .utils import _save_weight_qparams +from .utils import _get_weight_qparam_keys -class Linear(nnq.Linear): - """ A backend independent version of nn.quantized.Linear - we will not pack the parameters in this module, since weight packing is an - optimization for quantized backends supported in PyTorch (fbgemm/qnnpack), - this is useful when user want to use this module in other backends like Glow. +class Linear(nn.Linear): + """ A reference quantized linear module that fits into the FX + Graph Mode Quantization workflow + activation will be floating point Tensor, we will store floating + point weight as well in the module, but in forward we'll quantize + and dequantize the weight before running the floating point functional + linear operator. 
""" - def __init__(self, in_features, out_features, bias_=True, - dtype=torch.qint8): - super().__init__(in_features, out_features, bias_, dtype) - self._qweight, self._bias = self._packed_params._weight_bias() - del self._packed_params + def __init__( + self, + in_features: int, + out_features: int, + bias_: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + weight_qparams: Optional[Dict[str, Any]] = None): + super().__init__(in_features, out_features, bias_, device, dtype) + if weight_qparams is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0 + } + self.weight_qscheme = weight_qparams["qscheme"] + self.weight_dtype = weight_qparams["dtype"] + assert self.weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \ + Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized linear module") + if self.weight_qscheme is not None: + self.register_buffer( + "weight_scale", + torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device)) + self.register_buffer( + "weight_zero_point", + torch.tensor( + weight_qparams["zero_point"], + dtype=torch.int, device=device)) + if self.weight_qscheme == torch.per_channel_affine: + self.register_buffer( + "weight_axis", + torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device)) + else: + # added for TorchScriptability, not used + self.register_buffer( + "weight_axis", + torch.tensor(0, dtype=torch.int, device=device)) def _get_name(self): return "QuantizedLinear(Reference)" + def get_weight(self): + """ + Fake quantize (quantize and dequantize) the weight with + the quantization parameters for weight, this is used to + simulate the numerics for the quantized weight in a quantized + model + """ + # supress mypy warning + assert isinstance(self.weight, torch.Tensor) + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + assert isinstance(self.weight_axis, torch.Tensor) + return _quantize_and_dequantize_weight( + self.weight, self.weight_qscheme, self.weight_dtype, self.weight_scale, + self.weight_zero_point, self.weight_axis) + def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.linear(x_dequant, weight_dequant, self._bias) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! 
- result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.linear --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.linear --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized linear + """ + weight_dequant = self.get_weight() + result = F.linear(x, weight_dequant, self.bias) return result def _save_to_state_dict(self, destination, prefix, keep_vars): super()._save_to_state_dict(destination, prefix, keep_vars) - destination[prefix + '_qweight'] = self._qweight - destination[prefix + '_bias'] = self._bias + _save_weight_qparams( + destination, prefix, self.weight_qscheme, self.weight_dtype, + self.weight_scale, self.weight_zero_point, self.weight_axis) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - self._qweight = state_dict[prefix + '_qweight'] - self._bias = state_dict[prefix + '_bias'] - state_dict.pop(prefix + '_qweight') - state_dict.pop(prefix + '_bias') + for key in _get_weight_qparam_keys(state_dict, prefix): + setattr(self, key, state_dict[prefix + key]) + state_dict.pop(prefix + key) super()._load_from_state_dict( state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs) - def _weight_bias(self): - return self._qweight, self._bias - - def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: - self._qweight = w - self._bias = b + @classmethod + def from_float(cls, float_linear, weight_qparams): + qref_linear = Linear( + float_linear.in_features, float_linear.out_features, + float_linear.bias is not None, device=float_linear.weight.device, + dtype=float_linear.weight.dtype, weight_qparams=weight_qparams) + qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach()) + if float_linear.bias is not None: + qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach()) + return qref_linear diff --git a/torch/nn/quantized/_reference/modules/utils.py b/torch/nn/quantized/_reference/modules/utils.py new file mode 100644 index 0000000..7c36650 --- /dev/null +++ b/torch/nn/quantized/_reference/modules/utils.py @@ -0,0 +1,45 @@ +import torch +from typing import Dict, Any + +def _quantize_and_dequantize_weight( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis: torch.Tensor): + """ Quantize and then dequantize the weight based on + the quantization parameters + """ + if weight_qscheme == torch.per_tensor_affine: + weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype) + weight_dequant = weight.dequantize() + elif weight_qscheme == torch.per_channel_affine: + weight = torch.quantize_per_channel( + weight, weight_scale, + weight_zero_point, weight_axis.item(), weight_dtype) # type: ignore[arg-type] + weight_dequant = weight.dequantize() + else: + weight_dequant = weight + return weight_dequant + +def _save_weight_qparams(destination, prefix, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis): + destination[prefix + "weight_qscheme"] = weight_qscheme + destination[prefix + "weight_dtype"] = weight_dtype + if weight_qscheme is not None: + destination[prefix + "weight_scale"] = weight_scale + destination[prefix + "weight_zero_point"] = weight_zero_point + if weight_qscheme == 
torch.per_channel_affine:
+            destination[prefix + "weight_axis"] = weight_axis
+
+def _get_weight_qparam_keys(
+        state_dict: Dict[str, Any],
+        prefix: str):
+    keys = ["weight_qscheme", "weight_dtype"]
+    weight_qscheme = state_dict[prefix + "weight_qscheme"]
+    if weight_qscheme is not None:
+        keys.append("weight_scale")
+        keys.append("weight_zero_point")
+        if weight_qscheme == torch.per_channel_affine:
+            keys.append("weight_axis")
+    return keys
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 6362961..e8b8736 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -869,6 +869,7 @@ class LinearReLUQuantizeHandler(QuantizeHandler):
             # Get the float linear and attach qscheme and qparams
             # the the module
             float_linear = self.linear
+            fused_linear = None
             if isinstance(float_linear, (torch.nn.qat.Linear, torch.nn.intrinsic.qat.LinearReLU)):
                 float_linear = float_linear.to_float()
                 # change qat linear to linear
@@ -876,10 +877,12 @@
                 setattr(modules[parent_name], name, float_linear)
                 # Attach weight fake quant to the linear module
                 if isinstance(float_linear, torch.nn.intrinsic.LinearReLU):
+                    fused_linear = float_linear
                     float_linear = float_linear[0]
                 weight_post_process = self.linear.weight_fake_quant
             else:
                 if isinstance(float_linear, torch.nn.intrinsic.LinearReLU):
+                    fused_linear = float_linear
                     float_linear = self.linear[0]  # type: ignore[index]
                 # Attach the weight observer to the module
                 weight_post_process = qconfig.weight()  # type: ignore[union-attr]
@@ -887,7 +890,21 @@
             weight_post_process(float_linear.weight)  # type: ignore[operator]
             weight_qparams = get_qparam_dict(weight_post_process)
-            _to_reference(float_linear, weight_qparams)
+            # TODO: include the configuration in backend_config_dict
+            # we can have a map from module to reference module
+            # and allow users to register new ones
+            qlinear_cls = get_static_quant_module_class(
+                type(float_linear), is_reference=is_reference)
+            ref_linear = qlinear_cls.from_float(float_linear, weight_qparams)
+
+            # if the parent is a fused linear (Sequential), we can replace the first
+            # item with the ref linear, otherwise we can update
+            # the linear instance in the module tree
+            if fused_linear is not None:
+                fused_linear[0] = ref_linear
+            else:
+                parent_name, name = _parent_name(self.linear_node.target)
+                setattr(modules[parent_name], name, ref_linear)
             op_out = quantized_graph.create_node(
                 'call_module',
                 self.linear_node.target,
diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py
index 775d40b..03b1778 100644
--- a/torch/quantization/quantization_mappings.py
+++ b/torch/quantization/quantization_mappings.py
@@ -25,16 +25,14 @@ from .utils import get_combined_dict
 
 # Default map for swapping float module to reference quantized modules
 DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    nn.Linear: nnqr.Linear,
     nn.Conv1d: nnqr.Conv1d,
     nn.Conv2d: nnqr.Conv2d,
     nn.Conv3d: nnqr.Conv3d,
-    nn.Linear: nnqr.Linear,
     nni.ConvReLU1d: nniqr.ConvReLU1d,
     nni.ConvReLU2d: nniqr.ConvReLU2d,
     nni.ConvReLU3d: nniqr.ConvReLU3d,
-    nni.LinearReLU: nniqr.LinearReLU,
     # QAT Modules
-    nnqat.Linear: nnqr.Linear,
     nnqat.Conv2d: nnqr.Conv2d,
     nnqat.Conv3d: nnqr.Conv3d,
     nniqat.ConvBn1d: nnqr.Conv1d,
@@ -45,7 +43,6 @@ DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nniqat.ConvBnReLU3d:
nniqr.ConvReLU3d, nniqat.ConvReLU2d: nniqr.ConvReLU2d, nniqat.ConvReLU3d: nniqr.ConvReLU3d, - nniqat.LinearReLU: nniqr.LinearReLU, } # Default map for swapping float module to quantized ones -- 2.7.4
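
For readers who want to try the reference module this patch adds, here is a minimal usage sketch. It is not part of the patch or its test plan; the qparam values are made up for illustration, whereas in the FX flow they come from the weight observer via `get_qparam_dict`.

```
# Minimal sketch, assuming the patch above is applied.
import torch
import torch.nn as nn
import torch.nn.quantized._reference as nnqr

float_linear = nn.Linear(10, 5).eval()

# Illustrative per-tensor qparams for the weight (normally produced by the
# weight observer attached through the qconfig).
weight_qparams = {
    "qscheme": torch.per_tensor_affine,
    "dtype": torch.qint8,
    "scale": 0.02,
    "zero_point": 0,
}

ref_linear = nnqr.Linear.from_float(float_linear, weight_qparams)

x = torch.randn(2, 10)   # activation stays a float tensor
out = ref_linear(x)      # weight is quantized and dequantized inside forward
print(out.shape)         # torch.Size([2, 5])
```

The output stays in floating point; a backend that consumes the reference model is expected to pattern-match the surrounding quant/dequant nodes and the `F.linear` call and fuse them into its own quantized linear kernel, as described in the summary above.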