From 8f88f797dbff54aa4d2b153e9f0dc87794e4cf38 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Mon, 30 Aug 2021 14:21:39 -0700
Subject: [PATCH] [quant][graphmode][fx] Add reference quantized conv module (#63828)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63828

Added a reference quantized conv module for the custom backend flow. The reference quantized module will have the following code:
```
w(float) -- quant - dequant \
x(float) ------------- F.conv2d ---
```
In the full model, we will see
```
w(float) -- quant - *dequant \
x -- quant --- *dequant -- *F.conv2d --- *quant - dequant
```
and the backend should be able to fuse the ops with `*` into a quantized conv.

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_conv_linear_reference

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D30504749

fbshipit-source-id: e1d8c43a0e0d6d9ea2375b8ca59a9c0f455514fb
---
 test/quantization/core/test_quantized_module.py | 84 +++-----
 test/quantization/fx/test_quantize_fx.py | 68 ++++--
 .../nn/intrinsic/quantized/_reference/__init__.py | 1 -
 .../quantized/_reference/modules/__init__.py | 8 -
 .../quantized/_reference/modules/conv_relu.py | 58 -----
 torch/nn/quantized/_reference/modules/conv.py | 237 ++++++++++++++-------
 torch/quantization/fx/quantization_patterns.py | 30 ++-
 torch/quantization/quantization_mappings.py | 15 --
 8 files changed, 257 insertions(+), 244 deletions(-)
 delete mode 100644 torch/nn/intrinsic/quantized/_reference/__init__.py
 delete mode 100644 torch/nn/intrinsic/quantized/_reference/modules/__init__.py
 delete mode 100644 torch/nn/intrinsic/quantized/_reference/modules/conv_relu.py

diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py index bc8a6b3..b0bc782 100644 --- a/test/quantization/core/test_quantized_module.py +++ b/test/quantization/core/test_quantized_module.py @@ -2,9 +2,7 @@ import torch import torch.nn as nn import torch.nn.intrinsic as nni import torch.nn.intrinsic.quantized as nniq -import torch.nn.intrinsic.quantized._reference as nniqr import torch.nn.quantized as nnq -import torch.nn.quantized._reference as nnqr import torch.nn.quantized.dynamic as nnqd import torch.quantization @@ -211,12 +209,11 @@ class TestStaticQuantizedModule(QuantizationTestCase): self.assertEqual(rqr, rqr2) def _test_conv_api_impl( - self, module_name, qconv_module, conv_module, batch_size, - in_channels_per_group, input_feature_map_size, out_channels_per_group, - groups, kernel_size, stride, padding, padding_mode, dilation, - X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, - use_bias, use_fused, use_channelwise, is_reference - ): + self, module_name, qconv_module, conv_module, batch_size, + in_channels_per_group, input_feature_map_size, out_channels_per_group, + groups, kernel_size, stride, padding, padding_mode, dilation, + X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, + use_bias, use_fused, use_channelwise): for i in range(len(kernel_size)): assume(input_feature_map_size[i] + 2 * padding[i] >= dilation[i] * (kernel_size[i] - 1) + 1) @@ -245,8 +242,7 @@ class TestStaticQuantizedModule(QuantizationTestCase): # Test members self.assertTrue(module_name == qconv_module._get_name(), module_name + " " + qconv_module._get_name()) - if not is_reference: - self.assertTrue(hasattr(qconv_module, '_packed_params')) + self.assertTrue(hasattr(qconv_module, '_packed_params')) self.assertTrue(hasattr(qconv_module, 'scale'))
self.assertTrue(hasattr(qconv_module, 'zero_point')) @@ -275,9 +271,8 @@ class TestStaticQuantizedModule(QuantizationTestCase): # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is # 4 assuming the rounding mode is round-to-nearest, ties-to-even. # skip numerics checking for reference module - if not is_reference: - np.testing.assert_array_almost_equal( - Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0) + np.testing.assert_array_almost_equal( + Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0) # Test serialization of quantized Conv Module using state_dict model_dict = qconv_module.state_dict() @@ -297,8 +292,7 @@ class TestStaticQuantizedModule(QuantizationTestCase): self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module)) self.assertTrue(module_name == loaded_qconv_module._get_name()) - if not is_reference: - self.assertTrue(hasattr(loaded_qconv_module, '_packed_params')) + self.assertTrue(hasattr(loaded_qconv_module, '_packed_params')) self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias')) self.assertEqual(qconv_module.weight(), loaded_qconv_module.weight()) @@ -308,9 +302,8 @@ class TestStaticQuantizedModule(QuantizationTestCase): self.assertEqual(qconv_module.zero_point, loaded_qconv_module.zero_point) Y_loaded = loaded_qconv_module(X_q) - if not is_reference: - np.testing.assert_array_almost_equal( - Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0) + np.testing.assert_array_almost_equal( + Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0) # Test serialization b = io.BytesIO() @@ -330,9 +323,8 @@ class TestStaticQuantizedModule(QuantizationTestCase): self.assertEqual(copied_conv.zero_point, qconv_module.zero_point) Y_copied = copied_conv(X_q) - if not is_reference: - np.testing.assert_array_almost_equal( - Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0) + np.testing.assert_array_almost_equal( + Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0) deepcopied_conv = copy.deepcopy(qconv_module) self.assertEqual(deepcopied_conv.bias(), qconv_module.bias()) @@ -340,9 +332,8 @@ class TestStaticQuantizedModule(QuantizationTestCase): self.assertEqual(deepcopied_conv.zero_point, qconv_module.zero_point) Y_deepcopied = copied_conv(X_q) - if not is_reference: - np.testing.assert_array_almost_equal( - Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0) + np.testing.assert_array_almost_equal( + Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0) # JIT testing self.checkScriptable( @@ -377,9 +368,8 @@ class TestStaticQuantizedModule(QuantizationTestCase): [True, False], # use_bias [True, False], # use_fused [True, False], # use_channelwise - [True, False] # is_reference ) - for pad_mode, use_bias, use_fused, use_channelwise, is_reference in options: + for pad_mode, use_bias, use_fused, use_channelwise in options: if torch.backends.quantized.engine == "qnnpack": use_channelwise = False batch_size = 2 @@ -407,15 +397,13 @@ class TestStaticQuantizedModule(QuantizationTestCase): Y_zero_point = 4 if torch.backends.quantized.engine == 'qnnpack': use_channelwise = False - # (use_fused, is_reference) -> quantized class + # use_fused -> quantized class class_map = { - (True, True): (nniqr.ConvReLU1d, "QuantizedConvReLU1d(Reference)"), - (True, False): (nniq.ConvReLU1d, "QuantizedConvReLU1d"), - (False, True): (nnqr.Conv1d, "QuantizedConv1d(Reference)"), - (False, False): (nnq.Conv1d, "QuantizedConv1d") + True: (nniq.ConvReLU1d, "QuantizedConvReLU1d"), 
+ False: (nnq.Conv1d, "QuantizedConv1d") } - qconv_cls, module_name = class_map[(use_fused, is_reference)] + qconv_cls, module_name = class_map[use_fused] qconv_module = qconv_cls( in_channels, out_channels, kernel, stride, pad, dilation, groups, use_bias, padding_mode=pad_mode @@ -434,7 +422,7 @@ class TestStaticQuantizedModule(QuantizationTestCase): in_channels_per_group, input_feature_map_size, out_channels_per_group, groups, kernel_size, stride, pad, pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, - Y_zero_point, use_bias, use_fused, use_channelwise, is_reference) + Y_zero_point, use_bias, use_fused, use_channelwise) @override_qengines def test_conv2d_api(self): @@ -443,9 +431,8 @@ class TestStaticQuantizedModule(QuantizationTestCase): [True, False], # use_bias [True, False], # use_fused [True, False], # use_channelwise - [True, False] # is_reference ) - for pad_mode, use_bias, use_fused, use_channelwise, is_reference in options: + for pad_mode, use_bias, use_fused, use_channelwise in options: if torch.backends.quantized.engine == "qnnpack": use_channelwise = False batch_size = 2 @@ -475,15 +462,13 @@ class TestStaticQuantizedModule(QuantizationTestCase): W_zero_point = [3] Y_scale = 5.0 Y_zero_point = 4 - # (use_fused, is_reference) -> quantized class + # use_fused -> quantized class class_map = { - (True, True): (nniqr.ConvReLU2d, "QuantizedConvReLU2d(Reference)"), - (True, False): (nniq.ConvReLU2d, "QuantizedConvReLU2d"), - (False, True): (nnqr.Conv2d, "QuantizedConv2d(Reference)"), - (False, False): (nnq.Conv2d, "QuantizedConv2d") + True: (nniq.ConvReLU2d, "QuantizedConvReLU2d"), + False: (nnq.Conv2d, "QuantizedConv2d") } - qconv_cls, module_name = class_map[(use_fused, is_reference)] + qconv_cls, module_name = class_map[use_fused] qconv_module = qconv_cls( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, use_bias, padding_mode=pad_mode @@ -502,7 +487,7 @@ class TestStaticQuantizedModule(QuantizationTestCase): in_channels_per_group, input_feature_map_size, out_channels_per_group, groups, kernel_size, stride, padding, pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point, - Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise, is_reference) + Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise) @skipIfNoFBGEMM def test_conv3d_api(self): @@ -510,9 +495,8 @@ class TestStaticQuantizedModule(QuantizationTestCase): [True, False], # use_bias [True, False], # use_fused [True, False], # use_channelwise - [True, False] # is_reference ) - for use_bias, use_fused, use_channelwise, is_reference in options: + for use_bias, use_fused, use_channelwise in options: if torch.backends.quantized.engine == "qnnpack": use_channelwise = False batch_size = 2 @@ -547,16 +531,14 @@ class TestStaticQuantizedModule(QuantizationTestCase): W_zero_point = [3] Y_scale = 5.0 Y_zero_point = 4 - # (use_fused, is_reference) -> quantized class + # use_fused -> quantized class class_map = { - (True, True): (nniqr.ConvReLU3d, "QuantizedConvReLU3d(Reference)"), - (True, False): (nniq.ConvReLU3d, "QuantizedConvReLU3d"), - (False, True): (nnqr.Conv3d, "QuantizedConv3d(Reference)"), - (False, False): (nnq.Conv3d, "QuantizedConv3d") + True: (nniq.ConvReLU3d, "QuantizedConvReLU3d"), + False: (nnq.Conv3d, "QuantizedConv3d") } with override_quantized_engine('fbgemm'): - qconv_cls, module_name = class_map[(use_fused, is_reference)] + qconv_cls, module_name = class_map[use_fused] qconv_module = qconv_cls( in_channels, out_channels, kernel_size, stride, 
padding, dilation, groups, use_bias, padding_mode=pad_mode @@ -576,7 +558,7 @@ class TestStaticQuantizedModule(QuantizationTestCase): out_channels_per_group, groups, kernel_size, stride, padding, pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused, - use_channelwise, is_reference) + use_channelwise) def test_pool_api(self): """Tests the correctness of the pool module. diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 7ae29e0..9682da1 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -532,7 +532,7 @@ class TestQuantizeFx(QuantizationTestCase): Conv1d, conv1d_module_args, (conv1d_input,), - ns.call_module(nn.Conv1d if is_reference else nnq.Conv1d), + ns.call_module(nnqr.Conv1d if is_reference else nnq.Conv1d), None ), ( @@ -540,7 +540,7 @@ class TestQuantizeFx(QuantizationTestCase): Conv2d, conv2d_module_args, (conv2d_input,), - ns.call_module(nn.Conv2d if is_reference else nnq.Conv2d), + ns.call_module(nnqr.Conv2d if is_reference else nnq.Conv2d), None ), ( @@ -548,7 +548,7 @@ class TestQuantizeFx(QuantizationTestCase): Conv3d, conv3d_module_args, (conv3d_input,), - ns.call_module(nn.Conv3d if is_reference else nnq.Conv3d), + ns.call_module(nnqr.Conv3d if is_reference else nnq.Conv3d), None ), ( @@ -631,11 +631,7 @@ class TestQuantizeFx(QuantizationTestCase): qr = result_dict["quantized_reference"] def checkWeightQParams(model): - for module_name in ("conv",): - if hasattr(model, module_name): - self.assertTrue(hasattr(qr.get_submodule(module_name), "_weight_qparams")) - self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) - for module_name in ("linear",): + for module_name in ("linear", "conv"): if hasattr(model, module_name): self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme")) self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale")) @@ -643,19 +639,7 @@ class TestQuantizeFx(QuantizationTestCase): self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) def checkSerDeser(model, is_dynamic): - for module_name in ("conv",): - if hasattr(model, module_name): - # make sure seralization works - state_dict = copy.deepcopy(model.state_dict()) - self.assertTrue(module_name + "._weight_qparams" in state_dict) - - # check load_state_dict restores states - module = getattr(model, module_name) - prev_scale = module._weight_qparams["scale"] - module._weight_qparams["scale"] = None - model.load_state_dict(state_dict) - self.assertTrue(torch.equal(prev_scale, module._weight_qparams["scale"])) - for module_name in ("linear",): + for module_name in ("linear", "conv"): if hasattr(model, module_name): # make sure seralization works state_dict = copy.deepcopy(model.state_dict()) @@ -3001,6 +2985,44 @@ class TestQuantizeFx(QuantizationTestCase): result_ref = m_ref(data) self.assertTrue(torch.equal(result, result_ref)) + def test_ref_conv_module(self): + """ Make sure the numerics for models with ref conv module + matches models with fbgemm/qnnpack module + """ + convs = { + 1: nn.Conv1d, + 2: nn.Conv2d, + 3: nn.Conv3d, + } + + class M1(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = convs[dim](3, 3, 3) + + def forward(self, x): + return self.conv(x) + + class M2(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = convs[dim](3, 3, 3) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return 
self.relu(self.conv(x)) + + for dim, M in itertools.product([1, 2, 3], [M1, M2]): + m = M(dim).eval() + m = prepare_fx(m, {"": default_qconfig}) + m_copy = copy.deepcopy(m) + m = convert_fx(m, is_reference=False) + m_ref = convert_fx(m_copy, is_reference=True) + data = self.img_data_dict[dim][0][0] + result = m(data) + result_ref = m_ref(data) + self.assertTrue(torch.equal(result, result_ref)) + @skipIfNoFBGEMM class TestQuantizeFxOps(QuantizationTestCase): """Unit tests for individual ops @@ -4558,13 +4580,13 @@ class TestQuantizeFxOps(QuantizationTestCase): reference_order_check = [ ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize'), - ns.call_module(nn.Conv2d), + ns.call_module(nnqr.Conv2d), ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize'), ns.call_module(nn.Sigmoid), ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize'), - ns.call_module(nn.Conv2d), + ns.call_module(nnqr.Conv2d), ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize'), ] diff --git a/torch/nn/intrinsic/quantized/_reference/__init__.py b/torch/nn/intrinsic/quantized/_reference/__init__.py deleted file mode 100644 index 3d79bdb..0000000 --- a/torch/nn/intrinsic/quantized/_reference/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .modules import * # noqa: F403 diff --git a/torch/nn/intrinsic/quantized/_reference/modules/__init__.py b/torch/nn/intrinsic/quantized/_reference/modules/__init__.py deleted file mode 100644 index 33b18d8..0000000 --- a/torch/nn/intrinsic/quantized/_reference/modules/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -import torch -from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d - -__all__ = [ - 'ConvReLU1d', - 'ConvReLU2d', - 'ConvReLU3d', -] diff --git a/torch/nn/intrinsic/quantized/_reference/modules/conv_relu.py b/torch/nn/intrinsic/quantized/_reference/modules/conv_relu.py deleted file mode 100644 index b0305f6..0000000 --- a/torch/nn/intrinsic/quantized/_reference/modules/conv_relu.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import torch.nn.quantized._reference as nnqr -import torch.nn.functional as F - -class ConvReLU1d(nnqr.Conv1d): - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv1d( - x_dequant, weight_dequant, self._bias, self._conv1d_stride, # type: ignore[has-type] - self._conv1d_padding, self._conv1d_dilation, self.groups) # type: ignore[has-type] - float_result = F.relu(float_result, inplace=True) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) - return result - - def _get_name(self): - return "QuantizedConvReLU1d(Reference)" - - -class ConvReLU2d(nnqr.Conv2d): - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv2d( - x_dequant, weight_dequant, self._bias, self.stride, - self.padding, self.dilation, self.groups) - float_result = F.relu(float_result, inplace=True) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! 
- result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) - return result - - def _get_name(self): - return "QuantizedConvReLU2d(Reference)" - -class ConvReLU3d(nnqr.Conv3d): - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv3d( - x_dequant, weight_dequant, self._bias, self.stride, - self.padding, self.dilation, self.groups) - float_result = F.relu(float_result, inplace=True) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) - return result - - def _get_name(self): - return "QuantizedConvReLU3d(Reference)" diff --git a/torch/nn/quantized/_reference/modules/conv.py b/torch/nn/quantized/_reference/modules/conv.py index 036f8e4..6b03bb0 100644 --- a/torch/nn/quantized/_reference/modules/conv.py +++ b/torch/nn/quantized/_reference/modules/conv.py @@ -1,42 +1,101 @@ import torch -import torch.nn.quantized as nnq +import torch.nn as nn import torch.nn.functional as F -from typing import Optional +from typing import Optional, Dict, Any from torch.nn.common_types import _size_1_t -from torch.nn.modules.utils import _single +from .utils import _quantize_and_dequantize_weight +from .utils import _save_weight_qparams +from .utils import _get_weight_qparam_keys -class _ConvNd(nnq._ConvNd): +class _ConvNd(torch.nn.modules.conv._ConvNd): """ A reference version of nn.quantized.Conv2d we will not pack the parameters in this module, since weight packing is an optimization for quantized backends supported in PyTorch (fbgemm/qnnpack), this is useful when user want to use this module in other backends like Glow. 
""" - __annotations__ = {"_bias": Optional[torch.Tensor]} + __annotations__ = {"bias": Optional[torch.Tensor]} def _save_to_state_dict(self, destination, prefix, keep_vars): super()._save_to_state_dict(destination, prefix, keep_vars) - destination[prefix + '_qweight'] = self._qweight - destination[prefix + '_bias'] = self._bias + _save_weight_qparams( + destination, prefix, self.weight_qscheme, self.weight_dtype, + self.weight_scale, self.weight_zero_point, self.weight_axis) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - self._qweight = state_dict[prefix + '_qweight'] - self._bias = state_dict[prefix + '_bias'] - state_dict.pop(prefix + '_qweight') - state_dict.pop(prefix + '_bias') + for key in _get_weight_qparam_keys(state_dict, prefix): + setattr(self, key, state_dict[prefix + key]) + state_dict.pop(prefix + key) super()._load_from_state_dict( state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs) - def _weight_bias(self): - return self._qweight, self._bias - - def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: - self._qweight = w - self._bias = b - -class Conv1d(_ConvNd, nnq.Conv1d): + def _init_weight_qparams(self, weight_qparams, device): + if weight_qparams is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0 + } + self.weight_qscheme = weight_qparams["qscheme"] + self.weight_dtype = weight_qparams["dtype"] + assert self.weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \ + Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized linear module") + if self.weight_qscheme is not None: + self.register_buffer( + "weight_scale", + torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device)) + self.register_buffer( + "weight_zero_point", + torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device)) + if self.weight_qscheme == torch.per_channel_affine: + self.register_buffer( + "weight_axis", + torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device)) + else: + # added for TorchScriptability, not used + self.register_buffer( + "weight_axis", torch.tensor(0, dtype=torch.int, device=device)) + + def get_weight(self): + """ + Fake quantize (quantize and dequantize) the weight with + the quantization parameters for weight, this is used to + simulate the numerics for the quantized weight in a quantized + model + """ + # supress mypy warning + assert isinstance(self.weight, torch.Tensor) + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + assert isinstance(self.weight_axis, torch.Tensor) + return _quantize_and_dequantize_weight( + self.weight, self.weight_qscheme, + self.weight_dtype, self.weight_scale, self.weight_zero_point, self.weight_axis) + + @staticmethod + def from_float(cls, float_conv, weight_qparams): + qref_conv = cls( + float_conv.in_channels, + float_conv.out_channels, + float_conv.kernel_size, # type: ignore[arg-type] + float_conv.stride, # type: ignore[arg-type] + float_conv.padding, # type: ignore[arg-type] + float_conv.dilation, # type: ignore[arg-type] + float_conv.groups, + float_conv.bias is not None, # type: ignore[arg-type] + float_conv.padding_mode, + device=float_conv.weight.device, + dtype=float_conv.weight.dtype, + weight_qparams=weight_qparams) + qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach()) + 
if float_conv.bias is not None: + qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach()) + return qref_conv + +class Conv1d(_ConvNd, nn.Conv1d): def __init__(self, in_channels: int, out_channels: int, @@ -46,91 +105,107 @@ class Conv1d(_ConvNd, nnq.Conv1d): dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, - padding_mode: str = 'zeros'): - nnq.Conv1d.__init__( + padding_mode: str = "zeros", + device=None, + dtype=None, + weight_qparams: Optional[Dict[str, Any]] = None): + nn.Conv1d.__init__( self, in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, bias, padding_mode) - # self.stride, self.padding, self.dilation are 2d tuple since - # current quantized conv1d is using Conv2dPackedParams - # TODO: we should fix this if we implemenet Conv1dPackedParams - self._conv1d_stride = _single(self.stride[0]) - self._conv1d_padding = _single(self.padding[0]) - self._conv1d_dilation = _single(self.dilation[0]) + groups, bias, padding_mode, device, dtype) + self._init_weight_qparams(weight_qparams, device) def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv1d( - x_dequant, weight_dequant, self._bias, self._conv1d_stride, - self._conv1d_padding, self._conv1d_dilation, self.groups) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv1d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv1d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv1d + """ + weight_dequant = self.get_weight() + result = F.conv1d( + x, weight_dequant, self.bias, self.stride, + self.padding, self.dilation, self.groups) return result def _get_name(self): - return 'QuantizedConv1d(Reference)' - - @torch.jit.export - def __setstate__(self, state): - self.in_channels = state[0] - self.out_channels = state[1] - self.kernel_size = state[2] - self.stride = state[3] - self.padding = state[4] - self.dilation = state[5] - self.transposed = state[6] - self.output_padding = state[7] - self.groups = state[8] - self.padding_mode = state[9] - self.set_weight_bias(state[10], state[11]) - self.scale = state[12] - self.zero_point = state[13] - self.training = state[14] - self._conv1d_stride = (self.stride[0],) - self._conv1d_padding = (self.padding[0],) - self._conv1d_dilation = (self.dilation[0],) - -class Conv2d(_ConvNd, nnq.Conv2d): + return "QuantizedConv1d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): + return _ConvNd.from_float(cls, float_conv, weight_qparams) + +class Conv2d(_ConvNd, nn.Conv2d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, - padding_mode='zeros'): - nnq.Conv2d.__init__( + padding_mode='zeros', + device=None, + dtype=None, + weight_qparams: Optional[Dict[str, Any]] = None): + nn.Conv2d.__init__( self, in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, bias, padding_mode) + groups, bias, padding_mode, device, dtype) + self._init_weight_qparams(weight_qparams, device) def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv2d( - x_dequant, weight_dequant, self._bias, 
self.stride, + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv2d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv2d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv2d + """ + weight_dequant = self.get_weight() + result = F.conv2d( + x, weight_dequant, self.bias, self.stride, self.padding, self.dilation, self.groups) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) return result def _get_name(self): - return 'QuantizedConv2d(Reference)' + return "QuantizedConv2d(Reference)" -class Conv3d(_ConvNd, nnq.Conv3d): + @classmethod + def from_float(cls, float_conv, weight_qparams): + return _ConvNd.from_float(cls, float_conv, weight_qparams) + +class Conv3d(_ConvNd, nn.Conv3d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, - padding_mode='zeros'): - nnq.Conv3d.__init__( + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: Optional[Dict[str, Any]] = None): + nn.Conv3d.__init__( self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode) + self._init_weight_qparams(weight_qparams, device) def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv3d( - x_dequant, weight_dequant, self._bias, self.stride, + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv3d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv3d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv3d + """ + weight_dequant = self.get_weight() + result = F.conv3d( + x, weight_dequant, self.bias, self.stride, self.padding, self.dilation, self.groups) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) return result def _get_name(self): - return 'QuantizedConv3d(Reference)' + return "QuantizedConv3d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): + return _ConvNd.from_float(cls, float_conv, weight_qparams) diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 779dfcf..418cae1 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -638,19 +638,22 @@ class ConvReluQuantizeHandler(QuantizeHandler): # and qparam is a dictionary of # {"qscheme": ..., "scale": ..., "zero_point": ...} for per tensor quantization or # {"qscheme": ..., "scale": ..., "zero_point": ..., "axis": ...} for per channel quantization + float_conv = self.conv + fused_conv = None if isinstance( - self.conv, + float_conv, QAT_CONV_MODULE_CLASSES): # case 1. 
converting qat conv module to # a float conv module, we need to attch # weight fake_quant to the conv module, # weight fake_quant is assumed to be run during # QAT so we don't need to run it again here - float_conv = self.conv.to_float() + float_conv = self.conv.to_float() # type: ignore[operator] # change qat conv to conv parent_name, name = _parent_name(self.conv_node.target) setattr(modules[parent_name], name, float_conv) if isinstance(float_conv, torch.nn.intrinsic._FusedModule): + fused_conv = float_conv float_conv = float_conv[0] weight_post_process = self.conv.weight_fake_quant else: @@ -658,15 +661,28 @@ class ConvReluQuantizeHandler(QuantizeHandler): # to float conv module, we need to attach # weight observer to the conv module and run it # with conv weight - float_conv = self.conv - if isinstance(self.conv, torch.nn.intrinsic._FusedModule): - float_conv = self.conv[0] + if isinstance(float_conv, torch.nn.intrinsic._FusedModule): + fused_conv = float_conv + float_conv = float_conv[0] # type: ignore[index] assert qconfig is not None weight_post_process = qconfig.weight() # run weight observer - weight_post_process(float_conv.weight) + weight_post_process(float_conv.weight) # type: ignore[operator] weight_qparams = get_qparam_dict(weight_post_process) - _to_reference(float_conv, weight_qparams) + # hardcoded for now, TODO: expose the api to user, + # we can have a map from module to reference module + # and allow user to register new ones + qconv_cls = get_static_quant_module_class( + type(float_conv), is_reference=is_reference) + ref_conv = qconv_cls.from_float(float_conv, weight_qparams) # type: ignore[attr-defined] + # if the parent is a fused conv (Sequential), we can replace the first + # item to ref conv, otherwise we can update + # the conv instance in the module tree + if fused_conv is not None: + fused_conv[0] = ref_conv + else: + parent_name, name = _parent_name(self.conv_node.target) + setattr(modules[parent_name], name, ref_conv) op_out = quantized_graph.create_node( 'call_module', self.conv_node.target, diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 03b1778..6851ba7 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -7,7 +7,6 @@ import torch.nn.functional as F import torch.nn.intrinsic as nni import torch.nn.intrinsic.quantized as nniq import torch.nn.intrinsic.quantized.dynamic as nniqd -import torch.nn.intrinsic.quantized._reference as nniqr import torch.nn.intrinsic.qat as nniqat import torch.nn.quantized as nnq import torch.nn.quantized._reference as nnqr @@ -29,20 +28,6 @@ DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = { nn.Conv1d: nnqr.Conv1d, nn.Conv2d: nnqr.Conv2d, nn.Conv3d: nnqr.Conv3d, - nni.ConvReLU1d: nniqr.ConvReLU1d, - nni.ConvReLU2d: nniqr.ConvReLU2d, - nni.ConvReLU3d: nniqr.ConvReLU3d, - # QAT Modules - nnqat.Conv2d: nnqr.Conv2d, - nnqat.Conv3d: nnqr.Conv3d, - nniqat.ConvBn1d: nnqr.Conv1d, - nniqat.ConvBn2d: nnqr.Conv2d, - nniqat.ConvBn3d: nnqr.Conv3d, - nniqat.ConvBnReLU1d: nniqr.ConvReLU1d, - nniqat.ConvBnReLU2d: nniqr.ConvReLU2d, - nniqat.ConvBnReLU3d: nniqr.ConvReLU3d, - nniqat.ConvReLU2d: nniqr.ConvReLU2d, - nniqat.ConvReLU3d: nniqr.ConvReLU3d, } # Default map for swapping float module to quantized ones -- 2.7.4
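For readers skimming the patch, below is a minimal, self-contained sketch of the computation the reference quantized conv modules perform: the weight is quantized and immediately dequantized with its stored qparams, and the convolution itself runs in float through F.conv2d. The helper name `reference_conv2d`, the qint8 weight dtype, and the scale/zero_point values are illustrative assumptions, not part of this patch.

```python
import torch
import torch.nn.functional as F

def reference_conv2d(x, weight, bias,
                     weight_scale=0.02, weight_zero_point=0,
                     stride=1, padding=0, dilation=1, groups=1):
    # w(float) -- quant - dequant \
    # x(float) ------------- F.conv2d ---
    # Fake-quantize the weight (quantize then dequantize) so the float
    # convolution below sees the same weight values a quantized backend would.
    w_qdq = torch.quantize_per_tensor(
        weight, weight_scale, weight_zero_point, torch.qint8).dequantize()
    return F.conv2d(x, w_qdq, bias, stride, padding, dilation, groups)

x = torch.randn(1, 3, 8, 8)
w = torch.randn(3, 3, 3, 3)
b = torch.zeros(3)
print(reference_conv2d(x, w, b).shape)  # torch.Size([1, 3, 6, 6])
```

A backend such as Glow can then pattern-match the surrounding quant/dequant nodes together with the F.conv2d call and fuse them into its own quantized convolution, which is why the reference module keeps the weight unpacked instead of using packed params.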