From: Supriya Rao
Date: Fri, 27 Aug 2021 04:05:56 +0000 (-0700)
Subject: [quant][fx] Add support for dynamic linear + relu fusion (INT8) (#63799)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~663
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c7027f19efbb2f7b274c9e5fc0e87fe4b084e6ae;p=platform%2Fupstream%2Fpytorch.git

[quant][fx] Add support for dynamic linear + relu fusion (INT8) (#63799)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63799

Add a new module that can be used for module swap with the nni.LinearReLU module in the convert function. Currently supports INT8 only, since the FP16 op doesn't have relu fusion yet.

Fixes #55393

Test Plan:
python test/test_quantization.py test_dynamic_fusion

Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D30502812

fbshipit-source-id: 3668e4f001a0626d469e17ac323acf582ee28a51
---

diff --git a/test/quantization/eager/test_quantize_eager_ptq.py b/test/quantization/eager/test_quantize_eager_ptq.py
index 1824da5..10cbd92 100644
--- a/test/quantization/eager/test_quantize_eager_ptq.py
+++ b/test/quantization/eager/test_quantize_eager_ptq.py
@@ -42,6 +42,7 @@ from torch.testing._internal.common_quantization import (
     EmbeddingBagModule,
     EmbeddingModule,
     EmbeddingWithLinear,
+    LinearReluLinearModel,
 )

 # annotated models
@@ -995,6 +996,23 @@ class TestPostTrainingDynamic(QuantizationTestCase):
         model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
         checkQuantized(model)

+    def test_linear_relu_fusion(self):
+        dtype = torch.qint8
+        model = LinearReluLinearModel().eval()
+        qconfig = default_dynamic_qconfig
+        qconfig_dict = {'' : qconfig}
+        torch.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)
+        prepare_dynamic(model, qconfig_dict)
+        convert_dynamic(model)
+
+        def checkQuantized(model):
+            self.checkDynamicQuantizedLinearRelu(model.fc1, dtype)
+            self.checkDynamicQuantizedLinear(model.fc2, dtype)
+            self.checkScriptable(model, self.calib_data, check_save_load=True)
+            self.checkNoQconfig(model)
+
+        checkQuantized(model)
+
     @given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]),
            dtype=st.sampled_from([torch.qint8, torch.float16]))
     def test_quantized_rnn(self, qconfig, dtype):
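
A minimal eager-mode sketch of the flow the new test above exercises, using only
public APIs. This is illustrative, not part of the diff: it assumes the patch is
applied and a build with quantized operator support (e.g. fbgemm), and
`DemoLinearRelu` and its attribute names are hypothetical placeholders:

    import torch
    import torch.nn as nn
    import torch.nn.intrinsic.quantized.dynamic as nniqd
    from torch.quantization import (default_dynamic_qconfig, fuse_modules,
                                    quantize_dynamic)

    class DemoLinearRelu(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(16, 16)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.fc1(x))

    model = DemoLinearRelu().eval()
    # Fuse fc1 + relu into a single nni.LinearReLU so that convert can swap
    # in the new nniqd.LinearReLU module.
    fuse_modules(model, [['fc1', 'relu']], inplace=True)
    qmodel = quantize_dynamic(model, qconfig_spec={'': default_dynamic_qconfig})
    assert isinstance(qmodel.fc1, nniqd.LinearReLU)  # fused dynamic INT8 module
    qmodel(torch.randn(4, 16))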
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index 08474d2..cdf2e7b 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -6,6 +6,7 @@ import torch.nn.quantized as nnq
 import torch.nn.quantized.dynamic as nnqd
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.quantized.dynamic as nniqd
 import torch.multiprocessing as mp

 # graph mode quantization based on fx
@@ -2883,6 +2884,57 @@ class TestQuantizeFx(QuantizationTestCase):
         self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
         self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)

+    @skipIfNoFBGEMM
+    def test_dynamic_with_fusion(self):
+        """
+        Tests that dynamic quantization APIs work with Linear + Relu fusion
+        """
+        class LinearRelu(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                x = self.linear(x)
+                return self.relu(x)
+
+        class Linear(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = torch.ones(5, 5)
+                self.b = torch.zeros(5)
+
+            def forward(self, x):
+                return torch.nn.functional.linear(x, self.w, self.b)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mods1 = torch.nn.Sequential(LinearRelu(), LinearRelu())
+                self.mods2 = Linear()
+                self.relu = F.relu
+
+            def forward(self, x):
+                x = self.mods1(x)
+                x = self.mods2(x)
+                x = self.relu(x)
+                return x
+
+        model = M().eval()
+        qconfig = {
+            "": default_dynamic_qconfig,
+        }
+        m = prepare_fx(model, qconfig)
+        m = convert_fx(m)
+        m(torch.rand(5, 5))
+        node_list = [
+            ns.call_module(nniqd.LinearReLU),
+            ns.call_module(nniqd.LinearReLU),
+            ns.call_function(torch.ops.quantized.linear_relu_dynamic),
+        ]
+        self.checkGraphModuleNodes(m, expected_node_list=node_list)
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
@@ -2956,7 +3008,7 @@ class TestQuantizeFxOps(QuantizationTestCase):
         }
         quant_type_to_qlinear_relu_fun = {
             # we don't have linear_relu_dynamic
-            QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
+            QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_relu_dynamic),
             QuantType.STATIC: ns.call_function(torch.ops.quantized.linear_relu),
             QuantType.QAT: ns.call_function(torch.ops.quantized.linear_relu),
         }
diff --git a/torch/nn/intrinsic/quantized/dynamic/__init__.py b/torch/nn/intrinsic/quantized/dynamic/__init__.py
new file mode 100644
index 0000000..3d79bdb
--- /dev/null
+++ b/torch/nn/intrinsic/quantized/dynamic/__init__.py
@@ -0,0 +1 @@
+from .modules import *  # noqa: F403
diff --git a/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py b/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py
new file mode 100644
index 0000000..ce57186
--- /dev/null
+++ b/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py
@@ -0,0 +1,6 @@
+import torch
+from .linear_relu import LinearReLU
+
+__all__ = [
+    'LinearReLU',
+]
diff --git a/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py
new file mode 100644
index 0000000..04c4c95
--- /dev/null
+++ b/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn.quantized.dynamic as nnqd
+import torch.nn.intrinsic as nni
+
+class LinearReLU(nnqd.Linear):
+    r"""
+    A LinearReLU module fused from Linear and ReLU modules that can be used
+    for dynamic quantization.
+    Currently supports INT8 quantization only (the FP16 op doesn't have relu fusion yet).
+
+    We adopt the same interface as :class:`torch.nn.quantized.dynamic.Linear`.
+
+    Attributes:
+        Same as torch.nn.quantized.dynamic.Linear
+
+    Examples::
+
+        >>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 30])
+    """
+    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]
+
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
+        super().__init__(in_features, out_features, bias, dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._packed_params.dtype == torch.qint8:
+            # TODO check if we should set reduce_range = True by default here
+            Y = torch.ops.quantized.linear_relu_dynamic(
+                x, self._packed_params._packed_params, reduce_range=True)
+        # TODO support this in a later PR
+        # elif self._packed_params.dtype == torch.float16:
+        #     Y = torch.ops.quantized.linear_relu_dynamic_fp16(
+        #         x, self._packed_params._packed_params)
+        else:
+            raise RuntimeError('Unsupported dtype on dynamic quantized linear relu!')
+        return Y.to(x.dtype)
+
+    def _get_name(self):
+        return 'DynamicQuantizedLinearReLU'
+
+    @classmethod
+    def from_float(cls, mod):
+        return super(LinearReLU, cls).from_float(mod)
diff --git a/torch/nn/quantized/dynamic/modules/linear.py b/torch/nn/quantized/dynamic/modules/linear.py
index 07cfdfe..ee153b1 100644
--- a/torch/nn/quantized/dynamic/modules/linear.py
+++ b/torch/nn/quantized/dynamic/modules/linear.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn.quantized as nnq
+import torch.nn.intrinsic as nni
 from torch.nn.quantized.modules.utils import _quantize_weight

 class Linear(nnq.Linear):
@@ -79,11 +80,15 @@ class Linear(nnq.Linear):
             mod (Module): a float module, either produced by torch.quantization
                           utilities or provided by the user
         """
-        float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
+        float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
+                         torch.nn.intrinsic.modules.fused.LinearReLU]
+
         assert type(mod) in float_modules, \
             'nn.quantized.dynamic.Linear.from_float only works for one of' + \
             str([float_mod.__name__ for float_mod in float_modules])
         assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        if type(mod) == nni.LinearReLU:
+            mod = mod[0]
         if mod.qconfig is not None and mod.qconfig.weight is not None:
             weight_observer = mod.qconfig.weight()
         else:
@@ -102,6 +107,6 @@ class Linear(nnq.Linear):
             qweight = mod.weight.float()
         else:
             raise RuntimeError('Unsupported dtype specified for dynamic quantized Linear!')
-        qlinear = Linear(mod.in_features, mod.out_features, dtype=dtype)
+        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
         qlinear.set_weight_bias(qweight, mod.bias)
         return qlinear
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 09ca190..b7c39ca 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -1022,9 +1022,14 @@ class LinearReLUQuantizeHandler(QuantizeHandler):
         elif dtypes in [(torch.float32, torch.qint8, torch.quint8),
                         (torch.float32, torch.float16, None)]:
             # choose linear dynamic or linear dynamic fp16 op based on weight dtype
-            qlinear_op = torch.ops.quantized.linear_dynamic \
-                if weight_dtype == torch.qint8 \
-                else torch.ops.quantized.linear_dynamic_fp16
+            if weight_dtype == torch.qint8:
+                if self.relu_node:
+                    qlinear_op = torch.ops.quantized.linear_relu_dynamic
+                else:
+                    qlinear_op = torch.ops.quantized.linear_dynamic
+            else:  # TODO add support for fp16 + relu fusion in a later PR
+                qlinear_op = torch.ops.quantized.linear_dynamic_fp16
+
             linear_input = load_arg(quantized=torch.float)(self.linear_node.args[0])
             qlinear_args = (linear_input, packed_weight)  # type: ignore[assignment]
             op_out = quantized_graph.create_node(
@@ -1033,7 +1038,7 @@
             # TODO: may need to change the key to Node regenerate the map in each transformation,
             # since we might not be able to rely on the name
             node_name_to_scope[op_out.name] = node_name_to_scope[self.linear_node.name]
-            if self.relu_node:
+            if self.relu_node and weight_dtype is not torch.qint8:
                 op_out = quantized_graph.create_node("call_function", torch.nn.functional.relu, (op_out,), {})
             return op_out
         else:
diff --git a/torch/quantization/ns/mappings.py b/torch/quantization/ns/mappings.py
index 399ddca..e97d771 100644
--- a/torch/quantization/ns/mappings.py
+++ b/torch/quantization/ns/mappings.py
@@ -8,6 +8,7 @@ toq = torch.ops.quantized
 import torch.nn.quantized as nnq
 import torch.nn.quantized.dynamic as nnqd
 import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.quantized.dynamic as nniqd
 import torch.nn.intrinsic.qat as nniqat
 import torch.nn.intrinsic as nni
 import torch.nn.qat as nnqat
@@ -70,6 +71,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
         nnq.Linear,
         nni.LinearReLU,
         nniq.LinearReLU,
+        nniqd.LinearReLU,
         nnqat.Linear,
         nnqd.Linear,
         nniqat.LinearReLU,
@@ -529,6 +531,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nniqat.ConvReLU2d,
         nniqat.ConvReLU3d,
         nniqat.LinearReLU,
+        nniqd.LinearReLU,
     ])

     MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([
diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py
index 6179398..775d40b 100644
--- a/torch/quantization/quantization_mappings.py
+++ b/torch/quantization/quantization_mappings.py
@@ -6,6 +6,7 @@ from torch import nn
 import torch.nn.functional as F
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.quantized.dynamic as nniqd
 import torch.nn.intrinsic.quantized._reference as nniqr
 import torch.nn.intrinsic.qat as nniqat
 import torch.nn.quantized as nnq
@@ -122,6 +123,7 @@ DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nn.GRU: nnqd.GRU,
     nn.LSTMCell: nnqd.LSTMCell,
     nn.RNNCell: nnqd.RNNCell,
+    nni.LinearReLU: nniqd.LinearReLU,
 }

 # Allowlist for propagating the qconfig
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index 6b2d1dd..77512f7 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -5,6 +5,7 @@ checking quantization api and properties of resulting modules.
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.nn.intrinsic.quantized.dynamic as nniqd
 import torch.nn.quantized as nnq
 import torch.nn.quantized.dynamic as nnqd
 from torch.nn.intrinsic import _FusedModule
@@ -422,6 +423,13 @@ class QuantizationTestCase(TestCase):
             self.assertEqual(type(mod), nnqd.Linear)
             self.assertEqual(mod._packed_params.dtype, dtype)

+    def checkDynamicQuantizedLinearRelu(self, mod, dtype):
+        r"""Checks that mod has been swapped for an nniqd.LinearReLU
+            module, and that the bias is float.
+        """
+        self.assertEqual(type(mod), nniqd.LinearReLU)
+        self.assertEqual(mod._packed_params.dtype, dtype)
+
     def check_eager_serialization(self, ref_model, loaded_model, x):
         # Check state dict serialization and torch.save APIs
         model_dict = ref_model.state_dict()
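
For the FX graph mode path, the equivalent end-to-end flow is sketched below.
This is illustrative and mirrors test_dynamic_with_fusion above; `DemoModule`
is a hypothetical placeholder, and no calibration step is needed since the
quantization is dynamic:

    import torch
    from torch.quantization import default_dynamic_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    class DemoModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(5, 5)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.fc(x))

    m = DemoModule().eval()
    # An empty-key qconfig dict applies default_dynamic_qconfig globally.
    m = prepare_fx(m, {"": default_dynamic_qconfig})
    m = convert_fx(m)
    m(torch.rand(2, 5))
    # The Linear + ReLU pair is matched by LinearReLUQuantizeHandler and
    # lowered to the fused dynamic module added in this patch; the printed
    # graph should show a single nniqd.LinearReLU call_module node for fc.
    print(m.graph)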