[quant][fx] Add support for dynamic linear + relu fusion (INT8) (#63799)
authorSupriya Rao <supriyar@fb.com>
Fri, 27 Aug 2021 04:05:56 +0000 (21:05 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 04:10:46 +0000 (21:10 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63799

Add a new module that can be used for module swap with the nni.LinearReLU module in convert function.
Supports INT8 currently (since FP16 op doesn't have relu fusion yet).

Fixes #55393

Test Plan:
python test/test_quantization.py test_dynamic_fusion

Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D30502812

fbshipit-source-id: 3668e4f001a0626d469e17ac323acf582ee28a51

test/quantization/eager/test_quantize_eager_ptq.py
test/quantization/fx/test_quantize_fx.py
torch/nn/intrinsic/quantized/dynamic/__init__.py [new file with mode: 0644]
torch/nn/intrinsic/quantized/dynamic/modules/__init__.py [new file with mode: 0644]
torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py [new file with mode: 0644]
torch/nn/quantized/dynamic/modules/linear.py
torch/quantization/fx/quantization_patterns.py
torch/quantization/ns/mappings.py
torch/quantization/quantization_mappings.py
torch/testing/_internal/common_quantization.py

index 1824da5..10cbd92 100644 (file)
@@ -42,6 +42,7 @@ from torch.testing._internal.common_quantization import (
     EmbeddingBagModule,
     EmbeddingModule,
     EmbeddingWithLinear,
+    LinearReluLinearModel,
 )
 
 # annotated models
@@ -995,6 +996,23 @@ class TestPostTrainingDynamic(QuantizationTestCase):
         model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
         checkQuantized(model)
 
+    def test_linear_relu_fusion(self):
+        dtype = torch.qint8
+        model = LinearReluLinearModel().eval()
+        qconfig = default_dynamic_qconfig
+        qconfig_dict = {'' : qconfig}
+        torch.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)
+        prepare_dynamic(model, qconfig_dict)
+        convert_dynamic(model)
+
+        def checkQuantized(model):
+            self.checkDynamicQuantizedLinearRelu(model.fc1, dtype)
+            self.checkDynamicQuantizedLinear(model.fc2, dtype)
+            self.checkScriptable(model, self.calib_data, check_save_load=True)
+            self.checkNoQconfig(model)
+
+        checkQuantized(model)
+
     @given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]),
            dtype=st.sampled_from([torch.qint8, torch.float16]))
     def test_quantized_rnn(self, qconfig, dtype):
index 08474d2..cdf2e7b 100644 (file)
@@ -6,6 +6,7 @@ import torch.nn.quantized as nnq
 import torch.nn.quantized.dynamic as nnqd
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.quantized.dynamic as nniqd
 import torch.multiprocessing as mp
 
 # graph mode quantization based on fx
@@ -2883,6 +2884,57 @@ class TestQuantizeFx(QuantizationTestCase):
         self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
         self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
 
+    @skipIfNoFBGEMM
+    def test_dynamic_with_fusion(self):
+        """
+        Tests that dynamic quantization APIs work with Linear + Relu fusion
+        """
+        class LinearRelu(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                x = self.linear(x)
+                return self.relu(x)
+
+        class Linear(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = torch.ones(5, 5)
+                self.b = torch.zeros(5)
+
+            def forward(self, x):
+                return torch.nn.functional.linear(x, self.w, self.b)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mods1 = torch.nn.Sequential(LinearRelu(), LinearRelu())
+                self.mods2 = Linear()
+                self.relu = F.relu
+
+            def forward(self, x):
+                x = self.mods1(x)
+                x = self.mods2(x)
+                x = self.relu(x)
+                return x
+
+        model = M().eval()
+        qconfig = {
+            "": default_dynamic_qconfig,
+        }
+        m = prepare_fx(model, qconfig)
+        m = convert_fx(m)
+        m(torch.rand(5, 5))
+        node_list = [
+            ns.call_module(nniqd.LinearReLU),
+            ns.call_module(nniqd.LinearReLU),
+            ns.call_function(torch.ops.quantized.linear_relu_dynamic),
+        ]
+        self.checkGraphModuleNodes(m, expected_node_list=node_list)
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
@@ -2956,7 +3008,7 @@ class TestQuantizeFxOps(QuantizationTestCase):
         }
         quant_type_to_qlinear_relu_fun = {
             # we don't have linear_relu_dynamic
-            QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
+            QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_relu_dynamic),
             QuantType.STATIC: ns.call_function(torch.ops.quantized.linear_relu),
             QuantType.QAT: ns.call_function(torch.ops.quantized.linear_relu),
         }
diff --git a/torch/nn/intrinsic/quantized/dynamic/__init__.py b/torch/nn/intrinsic/quantized/dynamic/__init__.py
new file mode 100644 (file)
index 0000000..3d79bdb
--- /dev/null
@@ -0,0 +1 @@
+from .modules import *  # noqa: F403
diff --git a/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py b/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py
new file mode 100644 (file)
index 0000000..ce57186
--- /dev/null
@@ -0,0 +1,6 @@
+import torch
+from .linear_relu import LinearReLU
+
+__all__ = [
+    'LinearReLU',
+]
diff --git a/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py
new file mode 100644 (file)
index 0000000..04c4c95
--- /dev/null
@@ -0,0 +1,47 @@
+import torch
+import torch.nn.quantized.dynamic as nnqd
+import torch.nn.intrinsic as nni
+
+class LinearReLU(nnqd.Linear):
+    r"""
+    A LinearReLU module fused from Linear and ReLU modules that can be used
+    for dynamic quantization.
+    Supports both, FP16 and INT8 quantization.
+
+    We adopt the same interface as :class:`torch.nn.quantized.dynamic.Linear`.
+
+    Attributes:
+        Same as torch.nn.quantized.dynamic.Linear
+
+    Examples::
+
+        >>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 30])
+    """
+    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]
+
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
+        super().__init__(in_features, out_features, bias, dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._packed_params.dtype == torch.qint8:
+            # TODO check if we should set reduce_rage = True by default here
+            Y = torch.ops.quantized.linear_relu_dynamic(
+                x, self._packed_params._packed_params, reduce_range=True)
+        # TODO Support this in a later PR
+        # elif self._packed_params.dtype == torch.float16:
+        #     Y = torch.ops.quantized.linear_relu_dynamic_fp16(
+        #         x, self._packed_params._packed_params)
+        else:
+            raise RuntimeError('Unsupported dtype on dynamic quantized linear relu!')
+        return Y.to(x.dtype)
+
+    def _get_name(self):
+        return 'DynamicQuantizedLinearReLU'
+
+    @classmethod
+    def from_float(cls, mod):
+        return super(LinearReLU, cls).from_float(mod)
index 07cfdfe..ee153b1 100644 (file)
@@ -1,5 +1,6 @@
 import torch
 import torch.nn.quantized as nnq
+import torch.nn.intrinsic as nni
 from torch.nn.quantized.modules.utils import _quantize_weight
 
 class Linear(nnq.Linear):
@@ -79,11 +80,15 @@ class Linear(nnq.Linear):
             mod (Module): a float module, either produced by torch.quantization
                           utilities or provided by the user
         """
-        float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
+        float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
+                         torch.nn.intrinsic.modules.fused.LinearReLU]
+
         assert type(mod) in float_modules, \
             'nn.quantized.dynamic.Linear.from_float only works for one of' + \
             str([float_mod.__name__ for float_mod in float_modules])
         assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        if type(mod) == nni.LinearReLU:
+            mod = mod[0]
         if mod.qconfig is not None and mod.qconfig.weight is not None:
             weight_observer = mod.qconfig.weight()
         else:
@@ -102,6 +107,6 @@ class Linear(nnq.Linear):
             qweight = mod.weight.float()
         else:
             raise RuntimeError('Unsupported dtype specified for dynamic quantized Linear!')
-        qlinear = Linear(mod.in_features, mod.out_features, dtype=dtype)
+        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
         qlinear.set_weight_bias(qweight, mod.bias)
         return qlinear
index 09ca190..b7c39ca 100644 (file)
@@ -1022,9 +1022,14 @@ class LinearReLUQuantizeHandler(QuantizeHandler):
                 elif dtypes in [(torch.float32, torch.qint8, torch.quint8),
                                 (torch.float32, torch.float16, None)]:
                     # choose linear dynamic or linear dynamic fp16 op based on weight dtype
-                    qlinear_op = torch.ops.quantized.linear_dynamic \
-                        if weight_dtype == torch.qint8 \
-                        else torch.ops.quantized.linear_dynamic_fp16
+                    if weight_dtype == torch.qint8:
+                        if self.relu_node:
+                            qlinear_op = torch.ops.quantized.linear_relu_dynamic
+                        else:
+                            qlinear_op = torch.ops.quantized.linear_dynamic
+                    else:  # TODO add support for fp16 + relu fusion in a later PR
+                        qlinear_op = torch.ops.quantized.linear_dynamic_fp16
+
                     linear_input = load_arg(quantized=torch.float)(self.linear_node.args[0])
                     qlinear_args = (linear_input, packed_weight)  # type: ignore[assignment]
                     op_out = quantized_graph.create_node(
@@ -1033,7 +1038,7 @@ class LinearReLUQuantizeHandler(QuantizeHandler):
                     # TODO: may need to change the key to Node regenerate the map in each transformation,
                     # since we might not be able to rely on the name
                     node_name_to_scope[op_out.name] = node_name_to_scope[self.linear_node.name]
-                    if self.relu_node:
+                    if self.relu_node and weight_dtype is not torch.qint8:
                         op_out = quantized_graph.create_node("call_function", torch.nn.functional.relu, (op_out,), {})
                     return op_out
                 else:
index 399ddca..e97d771 100644 (file)
@@ -8,6 +8,7 @@ toq = torch.ops.quantized
 import torch.nn.quantized as nnq
 import torch.nn.quantized.dynamic as nnqd
 import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.quantized.dynamic as nniqd
 import torch.nn.intrinsic.qat as nniqat
 import torch.nn.intrinsic as nni
 import torch.nn.qat as nnqat
@@ -70,6 +71,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
             nnq.Linear,
             nni.LinearReLU,
             nniq.LinearReLU,
+            nniqd.LinearReLU,
             nnqat.Linear,
             nnqd.Linear,
             nniqat.LinearReLU,
@@ -529,6 +531,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nniqat.ConvReLU2d,
         nniqat.ConvReLU3d,
         nniqat.LinearReLU,
+        nniqd.LinearReLU,
     ])
 
     MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([
index 6179398..775d40b 100644 (file)
@@ -6,6 +6,7 @@ from torch import nn
 import torch.nn.functional as F
 import torch.nn.intrinsic as nni
 import torch.nn.intrinsic.quantized as nniq
+import torch.nn.intrinsic.quantized.dynamic as nniqd
 import torch.nn.intrinsic.quantized._reference as nniqr
 import torch.nn.intrinsic.qat as nniqat
 import torch.nn.quantized as nnq
@@ -122,6 +123,7 @@ DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nn.GRU: nnqd.GRU,
     nn.LSTMCell: nnqd.LSTMCell,
     nn.RNNCell: nnqd.RNNCell,
+    nni.LinearReLU: nniqd.LinearReLU,
 }
 
 # Allowlist for propagating the qconfig
index 6b2d1dd..77512f7 100644 (file)
@@ -5,6 +5,7 @@ checking quantization api and properties of resulting modules.
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.nn.intrinsic.quantized.dynamic as nniqd
 import torch.nn.quantized as nnq
 import torch.nn.quantized.dynamic as nnqd
 from torch.nn.intrinsic import _FusedModule
@@ -422,6 +423,13 @@ class QuantizationTestCase(TestCase):
         self.assertEqual(type(mod), nnqd.Linear)
         self.assertEqual(mod._packed_params.dtype, dtype)
 
+    def checkDynamicQuantizedLinearRelu(self, mod, dtype):
+        r"""Checks that mod has been swapped for an nnqd.Linear
+            module, the bias is float.
+        """
+        self.assertEqual(type(mod), nniqd.LinearReLU)
+        self.assertEqual(mod._packed_params.dtype, dtype)
+
     def check_eager_serialization(self, ref_model, loaded_model, x):
         # Check state dict serialization and torch.save APIs
         model_dict = ref_model.state_dict()