[quant][tensorrt] Add tensorrt backend config (#64623)
authorJerry Zhang <jerryzh@fb.com>
Tue, 14 Sep 2021 22:26:03 +0000 (15:26 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 14 Sep 2021 22:27:33 +0000 (15:27 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64623

The config api will change, but we'll add configs gradually for TensorRT to unblock experimentation

Test Plan:
python torch/fx/experimental/fx2trt/example/unittests.py

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D30800474

fbshipit-source-id: 3c4640de1205a0f19b62943ab84f386d80394ec2

torch/fx/experimental/fx2trt/example/unittests.py [new file with mode: 0644]
torch/quantization/fx/__init__.py
torch/quantization/fx/backend_config_dict/__init__.py
torch/quantization/fx/backend_config_dict/tensorrt.py [new file with mode: 0644]
torch/quantization/fx/quantization_patterns.py
torch/quantization/quantize_fx.py

diff --git a/torch/fx/experimental/fx2trt/example/unittests.py b/torch/fx/experimental/fx2trt/example/unittests.py
new file mode 100644 (file)
index 0000000..d625675
--- /dev/null
@@ -0,0 +1,103 @@
+import torch
+from torch.quantization.quantize_fx import (
+    prepare_fx,
+    convert_fx,
+    get_tensorrt_backend_config_dict
+)
+import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer
+from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule
+from torch.testing._internal.common_quantization import QuantizationTestCase
+from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_quantization import NodeSpec as ns
+
+import unittest
+
+def lower_to_trt(model, sample_input, shape_ranges):
+    model = acc_tracer.trace(model, [sample_input])  # type: ignore[attr-defined]
+    interp = TRTInterpreter(
+        model,
+        [InputTensorSpec(
+            torch.Size([-1, *sample_input.shape[1:]]), torch.float,
+            shape_ranges=shape_ranges, has_batch_dim=True)],
+        explicit_batch_dimension=True, explicit_precision=True)
+    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
+    trt_mod = TRTModule(engine, input_names, output_names)
+    return trt_mod
+
+
+
+@unittest.skipIf(not TEST_CUDA, "gpu is not available.")
+class TestQuantizeFxTRT(QuantizationTestCase):
+    def test_conv(self):
+        class Conv2d(torch.nn.Module):
+            def __init__(self, *args):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(*args)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        conv2d_input = torch.rand(1, 3, 224, 224)
+        conv2d_module_args = (3, 3, 3)
+
+        m = Conv2d(*conv2d_module_args).eval()
+        qconfig = torch.quantization.QConfig(
+            activation=torch.quantization.observer.HistogramObserver.with_args(
+                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
+            ),
+            weight=torch.quantization.default_weight_observer
+        )
+        prepared = prepare_fx(m, {"": qconfig}, backend_config_dict=get_tensorrt_backend_config_dict())
+        # calibration
+        prepared(conv2d_input)
+        quantized = convert_fx(prepared, is_reference=True)
+        node_occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 1,
+            ns.call_method("dequantize"): 1
+        }
+        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
+        # lower to trt
+        trt_mod = lower_to_trt(quantized, conv2d_input, [((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))])
+        # make sure it runs
+        trt_mod(conv2d_input.cuda())
+
+    def test_linear(self):
+        class LinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        linear_module_input = torch.rand(8, 5)
+
+        m = LinearModule().eval()
+        qconfig = torch.quantization.QConfig(
+            activation=torch.quantization.observer.HistogramObserver.with_args(
+                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
+            ),
+            weight=torch.quantization.default_weight_observer
+        )
+        prepared = prepare_fx(m, {"": qconfig}, backend_config_dict=get_tensorrt_backend_config_dict())
+        # calibration
+        prepared(linear_module_input)
+        quantized = convert_fx(prepared, is_reference=True)
+        node_occurrence = {
+            ns.call_function(torch.quantize_per_tensor): 1,
+            ns.call_method("dequantize"): 1
+        }
+        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
+        # lower to trt
+        trt_mod = lower_to_trt(
+            quantized,
+            linear_module_input,
+            [((1, *linear_module_input.shape[1:]),
+              (5, *linear_module_input.shape[1:]),
+              (10, *linear_module_input.shape[1:]))])
+        # make sure it runs
+        trt_mod(linear_module_input.cuda())
+
+if __name__ == '__main__':
+    run_tests()
index 8ba1b6d..01ee7e8 100644 (file)
@@ -1,3 +1,5 @@
 from .prepare import prepare
 from .convert import convert
 from .fuse import Fuser
+from .backend_config_dict import get_fbgemm_backend_config_dict
+from .backend_config_dict import get_tensorrt_backend_config_dict
index edb2b95..c50e61a 100644 (file)
@@ -1,4 +1,5 @@
 from .fbgemm import get_fbgemm_backend_config_dict
+from .tensorrt import get_tensorrt_backend_config_dict
 
 def validate_backend_config_dict(backend_config_dict):
     return "quant_patterns" in backend_config_dict
diff --git a/torch/quantization/fx/backend_config_dict/tensorrt.py b/torch/quantization/fx/backend_config_dict/tensorrt.py
new file mode 100644 (file)
index 0000000..932f491
--- /dev/null
@@ -0,0 +1,15 @@
+import torch
+from ..quantization_patterns import ConvReLUQuantizeHandlerNew, LinearReLUQuantizeHandlerNew
+
+def get_tensorrt_backend_config_dict():
+    """ Get the backend config dictionary for tensorrt backend
+    NOTE: Current api will change in the future, it's just to unblock experimentation for
+    new backends, please don't use it right now.
+    """
+    quant_patterns = {
+        torch.nn.Conv2d: ConvReLUQuantizeHandlerNew,
+        torch.nn.Linear: LinearReLUQuantizeHandlerNew
+    }
+    return {
+        "quant_patterns": quant_patterns
+    }
index 7ae07bf..df4c21a 100644 (file)
@@ -90,7 +90,6 @@ class QuantizeHandler(ABC):
                     return maybe_obs
         return None
 
-
     def input_output_observed(self) -> bool:
         """
         Returns True if the pattern matched to this qhandler could be
@@ -1690,3 +1689,309 @@ class StandaloneModuleQuantizeHandler(QuantizeHandler):
         setattr(modules[parent_name], name, quantized_standalone_module)
         modules[str(node.target)] = quantized_standalone_module
         return quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs))
+
+
+class ConvReLUQuantizeHandlerNew(QuantizeHandler):
+    """ This is to unblock perf testing for TensorRT, this will be
+    changed in the future so don't depend on this.
+    """
+    def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]):
+        super().__init__(node, modules)
+        self.relu_node = None
+        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
+           (node.op == 'call_module' and isinstance(modules[str(node.target)], torch.nn.ReLU)):
+            self.relu_node = node
+            node = node.args[0]  # type: ignore[assignment]
+        self.conv_node = node
+        if node.op == "call_module":
+            self.conv = modules[str(self.conv_node.target)]
+        elif node.op == "call_function":
+            self.conv = node.target  # type: ignore[assignment]
+
+    def should_insert_observer_for_output(
+        self,
+        qconfig: Any,
+        model_is_training: bool,
+    ) -> bool:
+        return False
+
+    def is_output_quantized(self, qconfig, is_reference):
+        return False
+
+    def convert(self,
+                node: Node,
+                qconfig: QConfigAny,
+                modules: Dict[str, torch.nn.Module],
+                quantized_graph: Graph,
+                node_name_to_scope: Dict[str, Tuple[str, type]],
+                load_arg: Callable,
+                is_reference: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
+        assert is_reference, "ConvReLUQuantizeHandlerNew only works for the case when is_reference=True"
+        activation_int8_quantized = activation_is_int8_quantized(qconfig)
+        if self.conv_node.op == 'call_module':
+            output_activation_post_process = \
+                self._maybe_get_last_node_only_observer(modules)
+            # note that relu should already be fused into conv module in the fusion step
+            assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
+                'please make sure to run fusion before prepare'
+            # produce dequant - float_op - quant pattern
+            dtype = torch.float
+            if activation_int8_quantized:
+                dtype = activation_dtype(qconfig)
+            activation = load_arg(quantized=dtype)(self.conv_node.args[0])
+            args = load_arg(quantized=torch.float)(self.conv_node.args)
+            # Get the float conv and attach quantization scheme and quantization
+            # parameters of weight to the module
+            # and qparam is a dictionary of
+            # {"qscheme": ..., "scale": ..., "zero_point": ...} for per tensor quantization or
+            # {"qscheme": ..., "scale": ..., "zero_point": ..., "axis": ...} for per channel quantization
+            float_conv = self.conv
+            fused_conv = None
+            if isinstance(
+                    float_conv,
+                    QAT_CONV_MODULE_CLASSES):
+                # case 1. converting qat conv module to
+                # a float conv module, we need to attch
+                # weight fake_quant to the conv module,
+                # weight fake_quant is assumed to be run during
+                # QAT so we don't need to run it again here
+                float_conv = self.conv.to_float()  # type: ignore[operator]
+                # change qat conv to conv
+                parent_name, name = _parent_name(self.conv_node.target)
+                setattr(modules[parent_name], name, float_conv)
+                if isinstance(float_conv, torch.nn.intrinsic._FusedModule):
+                    fused_conv = float_conv
+                    float_conv = float_conv[0]
+                weight_post_process = self.conv.weight_fake_quant
+            else:
+                # case 2. converting a conv module/fused conv module
+                # to float conv module, we need to attach
+                # weight observer to the conv module and run it
+                # with conv weight
+                if isinstance(float_conv, torch.nn.intrinsic._FusedModule):
+                    fused_conv = float_conv
+                    float_conv = float_conv[0]  # type: ignore[index]
+                assert qconfig is not None
+                weight_post_process = qconfig.weight()
+                # run weight observer
+                weight_post_process(float_conv.weight)  # type: ignore[operator]
+            weight_qparams = get_qparam_dict(weight_post_process)
+            # hardcoded for now, TODO: expose the api to user,
+            # we can have a map from module to reference module
+            # and allow user to register new ones
+            qconv_cls = get_static_quant_module_class(
+                type(float_conv), is_reference=is_reference)
+            ref_conv = qconv_cls.from_float(float_conv, weight_qparams)  # type: ignore[attr-defined]
+            # if the parent is a fused conv (Sequential), we can replace the first
+            # item to ref conv, otherwise we can update
+            # the conv instance in the module tree
+            if fused_conv is not None:
+                fused_conv[0] = ref_conv
+            else:
+                parent_name, name = _parent_name(self.conv_node.target)
+                setattr(modules[parent_name], name, ref_conv)
+            op_out = quantized_graph.create_node(
+                'call_module',
+                self.conv_node.target,
+                args, {})
+            # disabling quantize node for output for now, this will be controlled by the
+            # backend_config_dict in the final design
+            if output_activation_post_process:
+                op_out = quantize_node(
+                    op_out,
+                    output_activation_post_process,
+                    node,
+                    modules,
+                    quantized_graph,
+                    node_name_to_scope,
+                    is_input=False)
+            return op_out
+        else:
+            assert self.conv_node.op == "call_function"
+            # make sure the input and weight are quantized to torch.quint8, torch.qint8, respectively
+            load_arg(quantized={0: torch.quint8, 1: torch.qint8})(self.conv_node.args)
+            args = load_arg(quantized=torch.float)(self.conv_node.args)
+            kwargs = load_arg(quantized=torch.float)(self.conv_node.kwargs)
+            op_out = quantized_graph.create_node(
+                "call_function", self.conv, args, kwargs)
+            if self.relu_node:
+                relu_args = [op_out]
+                relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:]))
+                relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs)
+                op_out = quantized_graph.create_node(
+                    "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
+
+            # disabling quantize node for output for now, this will be controlled by the
+            # backend_config_dict in the final design
+            # if activation_int8_quantized:
+            #     root_module = modules['']
+            #     act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
+            #     act_post_process_node = self.relu_node if self.relu_node else self.conv_node
+            #     activation_post_process = \
+            #         self._maybe_get_last_node_only_observer(modules)
+            #     assert activation_post_process is not None
+            #     return quantize_node(
+            #         op_out,
+            #         activation_post_process,
+            #         act_post_process_node,
+            #         modules,
+            #         quantized_graph,
+            #         node_name_to_scope,
+            #         is_input=False)
+            # else:
+            #     # output for dynamically quantized conv op is not quantized
+            #     return op_out
+            return op_out
+
+class LinearReLUQuantizeHandlerNew(QuantizeHandler):
+    """ This is to unblock perf testing for TensorRT, this will be
+    changed in the future so don't depend on this.
+    """
+    def __init__(
+            self,
+            node: Node,
+            modules: Dict[str, torch.nn.Module]):
+        super().__init__(node, modules)
+        self.relu_node = None
+        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
+           (node.op == 'call_module' and isinstance(modules[str(node.target)], torch.nn.ReLU)):
+            self.relu_node = node
+            node = node.args[0]  # type: ignore[assignment]
+        self.linear_node = node
+        if node.op == 'call_module':
+            self.linear = modules[str(self.linear_node.target)]
+
+    def should_insert_observer_for_output(
+        self,
+        qconfig: Any,
+        model_is_training: bool,
+    ) -> bool:
+        return False
+
+    def is_output_quantized(self, qconfig, is_reference):
+        return False
+
+    def convert(self,
+                node: Node,
+                qconfig: QConfigAny,
+                modules: Dict[str, torch.nn.Module],
+                quantized_graph: Graph,
+                node_name_to_scope: Dict[str, Tuple[str, type]],
+                load_arg: Callable,
+                is_reference: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
+        assert is_reference, "LinearReLUQuantizeHandlerNew only works for the case when is_reference=True"
+        if convert_custom_config_dict is None:
+            convert_custom_config_dict = {}
+        dtypes = get_qconfig_dtypes(qconfig)
+        activation_int8_quantized = activation_is_int8_quantized(qconfig)
+        activation_statically_quantized = activation_is_statically_quantized(qconfig)
+        weight_dtype = dtypes[1]
+        if self.linear_node.op == 'call_module':
+
+            output_activation_post_process = \
+                self._maybe_get_last_node_only_observer(modules)
+
+            # note that relu should already be fused into linear modul in the fusion step
+            assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \
+                'please make sure to run fusion before prepare'
+            # produce dequant - float_op - quant pattern
+            dtype = torch.float
+            if activation_int8_quantized:
+                dtype = activation_dtype(qconfig)
+            activation = load_arg(quantized=dtype)(self.linear_node.args[0])
+            args = load_arg(quantized=torch.float)(self.linear_node.args)
+            # Get the float linear and attach qscheme and qparams
+            # the the module
+            float_linear = self.linear
+            fused_linear = None
+            if isinstance(float_linear, (torch.nn.qat.Linear, torch.nn.intrinsic.qat.LinearReLU)):
+                float_linear = float_linear.to_float()
+                # change qat linear to linear
+                parent_name, name = _parent_name(self.linear_node.target)
+                setattr(modules[parent_name], name, float_linear)
+                # Attach weight fake quant to the linear module
+                if isinstance(float_linear, torch.nn.intrinsic.LinearReLU):
+                    fused_linear = float_linear
+                    float_linear = float_linear[0]
+                weight_post_process = self.linear.weight_fake_quant
+            else:
+                if isinstance(float_linear, torch.nn.intrinsic.LinearReLU):
+                    fused_linear = float_linear
+                    float_linear = self.linear[0]  # type: ignore[index]
+                # Attach the weight observer to the module
+                weight_post_process = qconfig.weight()  # type: ignore[union-attr]
+                # Run weight observer
+                weight_post_process(float_linear.weight)  # type: ignore[operator]
+
+            weight_qparams = get_qparam_dict(weight_post_process)
+            # TODO: include the configuration in backend_config_dict
+            # we can have a map from module to reference module
+            # and allow user to register new ones
+            qlinear_cls = get_static_quant_module_class(
+                type(float_linear), is_reference=is_reference)
+            ref_linear = qlinear_cls.from_float(float_linear, weight_qparams)
+
+            # if the parent is a fused linear (Sequential), we can replace the first
+            # item to ref linear, otherwise we can update
+            # the linear instance in the module tree
+            if fused_linear is not None:
+                fused_linear[0] = ref_linear
+            else:
+                parent_name, name = _parent_name(self.linear_node.target)
+                setattr(modules[parent_name], name, ref_linear)
+            op_out = quantized_graph.create_node(
+                'call_module',
+                self.linear_node.target,
+                args, {})
+            if output_activation_post_process:
+                op_out = quantize_node(
+                    op_out,
+                    output_activation_post_process,
+                    node,
+                    modules,
+                    quantized_graph,
+                    node_name_to_scope,
+                    is_input=False)
+            return op_out
+        else:  # call_function
+            assert self.linear_node.op == 'call_function'
+            quantized_input_dtypes = [torch.float, torch.float]
+            if activation_int8_quantized:
+                quantized_input_dtypes[0] = torch.quint8
+            if weight_is_statically_quantized(qconfig):
+                quantized_input_dtypes[1] = torch.qint8
+            args = load_arg(quantized=quantized_input_dtypes)(self.linear_node.args)
+            args = load_arg(quantized=torch.float)(self.linear_node.args)
+            kwargs = load_arg(quantized=torch.float)(self.linear_node.kwargs)
+            op_out = quantized_graph.create_node(
+                "call_function", torch.nn.functional.linear, args, kwargs)
+            if self.relu_node:
+                relu_args = [op_out]
+                relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:]))
+                relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs)
+                op_out = quantized_graph.create_node(
+                    "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
+
+            return op_out
+            # TODO: enable later
+            # if activation_statically_quantized:
+            #     # quantize output for statically quantized linear op
+            #     root_module = modules['']
+            #     act_post_process_name = self.relu_node.name if self.relu_node else self.linear_node.name
+            #     act_post_process_node = self.relu_node if self.relu_node else self.linear_node
+            #     activation_post_process = \
+            #         self._maybe_get_last_node_only_observer(modules)
+            #     assert activation_post_process is not None
+            #     return quantize_node(
+            #         op_out,
+            #         activation_post_process,
+            #         act_post_process_node,
+            #         modules,
+            #         quantized_graph,
+            #         node_name_to_scope,
+            #         is_input=False)
+            #     else:
+            #         # output for dynamically quantized linear op is not quantized
+            #         return op_out
index 2dd98ea..c6cb824 100644 (file)
@@ -4,6 +4,8 @@ from torch.fx._symbolic_trace import Tracer
 from torch.fx.node import Target, Node, Argument
 from .fx import Fuser  # noqa: F401
 from .fx import prepare, convert  # noqa: F401
+from .fx import get_fbgemm_backend_config_dict  # noqa: F401
+from .fx import get_tensorrt_backend_config_dict  # noqa: F401
 from .fx.utils import graph_pretty_str  # noqa: F401
 from .fx.utils import get_custom_module_class_keys  # noqa: F401
 from .fx.graph_module import ObservedGraphModule, QuantizedGraphModule
@@ -139,7 +141,8 @@ class QuantizationTracer(Tracer):
         self.node_name_to_scope[node.name] = (self.scope.module_path, self.scope.module_type)
         return node
 
-def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
+def _prepare_fx(model: torch.nn.Module,
+                qconfig_dict: Any,
                 prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
                 equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
                 backend_config_dict: Optional[Dict[str, Any]] = None,
@@ -195,6 +198,7 @@ forward graph of the parent module,
         tracer.node_name_to_scope,
         prepare_custom_config_dict=prepare_custom_config_dict,
         equalization_qconfig_dict=equalization_qconfig_dict,
+        backend_config_dict=backend_config_dict,
         is_standalone_module=is_standalone_module)
 
     for attr_name in preserved_attributes:
@@ -428,7 +432,12 @@ def prepare_fx(
     torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
     assert not model.training, 'prepare_fx only works for models in ' + \
         'eval mode'
-    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, equalization_qconfig_dict, backend_config_dict)
+    return _prepare_fx(
+        model,
+        qconfig_dict,
+        prepare_custom_config_dict,
+        equalization_qconfig_dict,
+        backend_config_dict)
 
 def prepare_qat_fx(
         model: torch.nn.Module, qconfig_dict: Any,
@@ -467,7 +476,11 @@ def prepare_qat_fx(
     torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
     assert model.training, 'prepare_qat_fx only works for models in  ' + \
         'train mode'
-    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, backend_config_dict)
+    return _prepare_fx(
+        model,
+        qconfig_dict,
+        prepare_custom_config_dict,
+        backend_config_dict=backend_config_dict)
 
 def _convert_fx(
         graph_module: GraphModule, is_reference: bool,