Add TRTSplitter (#64762)
authorShirong Wu <shirong@fb.com>
Fri, 10 Sep 2021 04:02:15 +0000 (21:02 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 04:08:57 +0000 (21:08 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64762

Extract and format TRTSplitter from fx2trt_example code, current implementation is tentative, subject to changed based on feeds model lowering progress.

Test Plan:
manul print of supported operator:
`{<class 'torch.nn.modules.activation.ReLU'>: None, <function relu at 0x7f9b1abd0790>: None, <class 'torch.nn.modules.activation.Sigmoid'>: None, <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>: None, <built-in method add of type object at 0x7f9b7f402498>: None, <built-in function add>: None, <built-in method add of PyCapsule object at 0x7f9b1a3dc690>: None, <built-in method add_relu of PyCapsule object at 0x7f9b1a34cf90>: None, <class 'torch.nn.modules.batchnorm.BatchNorm2d'>: None, <class 'torch.nn.quantized.modules.batchnorm.BatchNorm2d'>: None, <class 'torch.nn.modules.conv.Conv2d'>: None, <class 'torch.nn.quantized.modules.conv.Conv2d'>: None, <class 'torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d'>: None, <class 'torch.nn.modules.linear.Linear'>: None, <class 'torch.nn.quantized.modules.linear.Linear'>: None, <class 'torch.nn.modules.pooling.MaxPool2d'>: None, <built-in function mul>: None, <built-in method mul of type object at 0x7f9b7f402498>: None, <built-in method mul of PyCapsule object at 0x7f9b1a3dc6c0>: None, <built-in method flatten of type object at 0x7f9b7f402498>: None, <class 'torch.nn.quantized.modules.DeQuantize'>: None, <built-in method dequantize of type object at 0x7f9b7f402498>: None, 'dequantize': None, <class 'torch.nn.quantized.modules.Quantize'>: None, <built-in method quantize_per_tensor of type object at 0x7f9b7f402498>: None, <class 'torch.nn.modules.linear.Identity'>: None, <function conv2d at 0x7f9b1a1fe9d0>: None, <function flatten at 0x7f9b1a1f5ca0>: None, <function size at 0x7f9b1a1f5b80>: None, <function batch_norm at 0x7f9b1a1feaf0>: None, <function layer_norm at 0x7f9b1a1feb80>: None, <function softmax at 0x7f9b1a1f9550>: None, <function relu at 0x7f9b1a1fe040>: None, <function sin at 0x7f9b1a2030d0>: None, <function cos at 0x7f9b1a203160>: None, <function tan at 0x7f9b1a2031f0>: None, <function sinh at 0x7f9b1a1fe160>: None, <function cosh at 0x7f9b1a1fe280>: None, <function tanh at 0x7f9b1a1fe310>: None, <function asin at 0x7f9b1a1fe3a0>: None, <function acos at 0x7f9b1a1fe430>: None, <function atan at 0x7f9b1a1fe4c0>: None, <function exp at 0x7f9b1a1fe550>: None, <function log at 0x7f9b1a1fe5e0>: None, <function sqrt at 0x7f9b1a1fe670>: None, <function reciprocal at 0x7f9b1a1fe700>: None, <function abs at 0x7f9b1a1fe790>: None, <function neg at 0x7f9b1a1fe820>: None, <function floor at 0x7f9b1a1fe8b0>: None, <function ceil at 0x7f9b1a1fe940>: None, <function sum at 0x7f9b1a1f9c10>: None, <function max_pool2d at 0x7f9b1a1f5d30>: None, <function squeeze at 0x7f9b1a1f5c10>: None, <function add at 0x7f9b1a1f91f0>: None, <function sub at 0x7f9b1a1f9ca0>: None, <function div at 0x7f9b1a1f9dc0>: None, <function mul at 0x7f9b1a1f9d30>: None, <function pow at 0x7f9b1a1f9e50>: None, <function min_two_tensors_input at 0x7f9b1a1f9940>: None, <function unsqueeze at 0x7f9b1a1f9280>: None, <function topk at 0x7f9b1a203280>: None, <function adaptive_avg_pool2d at 0x7f9b1a1f5dc0>: None, <function avg_pool2d at 0x7f9b1a1f5e50>: None, <function reshape at 0x7f9b1a203550>: None, <function slice_tensor at 0x7f9b1a1fee50>: None, <function split at 0x7f9b1a1fec10>: None, <function linear at 0x7f9b1a1f51f0>: None, <function clamp at 0x7f9b1a1f93a0>: None, <function tuple_construct at 0x7f9b1a1fed30>: None, <function contiguous at 0x7f9b1a1f9430>: None, <function getitem at 0x7f9b1a203310>: None, <function cat at 0x7f9b1a1f9310>: None, <function transpose at 0x7f9b1a1f94c0>: None, <function matmul at 0x7f9b1a1f98b0>: None, <function sigmoid at 0x7f9b1a1fe1f0>: None, <function permute at 0x7f9b1a1f9670>: None, <function quantize_per_tensor at 0x7f9b1a1f9b80>: None, <function dequantize at 0x7f9b1a1f99d0>: None, <function sign at 0x7f9b1a1f5ee0>: None}`

Reviewed By: 842974287

Differential Revision: D30798047

fbshipit-source-id: 69076a550874425b7186fbbf2ecf03da4a99b42f

torch/fx/experimental/fx2trt/tools/__init__.py [new file with mode: 0644]
torch/fx/experimental/fx2trt/tools/trt_minimizer.py [new file with mode: 0644]
torch/fx/experimental/fx2trt/tools/trt_splitter.py [new file with mode: 0644]

diff --git a/torch/fx/experimental/fx2trt/tools/__init__.py b/torch/fx/experimental/fx2trt/tools/__init__.py
new file mode 100644 (file)
index 0000000..118b2ee
--- /dev/null
@@ -0,0 +1 @@
+from .trt_minimizer import *  # noqa: F403
diff --git a/torch/fx/experimental/fx2trt/tools/trt_minimizer.py b/torch/fx/experimental/fx2trt/tools/trt_minimizer.py
new file mode 100644 (file)
index 0000000..0ae886c
--- /dev/null
@@ -0,0 +1,64 @@
+from typing import Tuple, Callable, Any
+
+import torch
+import torch.fx.passes.net_min_base as net_min_base
+from torch.fx.experimental.fx2trt.fx2trt import (
+    TRTModule,
+    TRTInterpreter,
+    InputTensorSpec,
+)
+from torch.fx.passes.tools_common import Tensors
+
+
+def lower_mod_default(
+    mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048
+) -> TRTModule:
+    interp = TRTInterpreter(
+        mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
+    )
+    res_mod = TRTModule(*interp.run(max_batch_size=batch_size))
+    return res_mod
+
+
+class TensorRTMinizerSetting(net_min_base._MinimizerSettingBase):
+    def __init__(self, explicit_batch_dimension: Any = True):
+        self.explicit_batch_dimension = explicit_batch_dimension
+        super(TensorRTMinizerSetting, self).__init__()
+
+
+class TensorRTMinimizer(net_min_base._MinimizerBase):
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Tensors,
+        compare_fn: Callable[[Any, Any, Any], Tuple[float, bool]],
+        settings: TensorRTMinizerSetting = TensorRTMinizerSetting(),
+        max_batch_size: Any = 2048,
+        lower_fn: Callable[[torch.fx.GraphModule, Tensors, Any], TRTModule] = lower_mod_default,
+    ):
+        self.lower_fn = lower_fn
+        self.max_batch_size = max_batch_size
+        super().__init__(module, sample_input, compare_fn, settings)
+
+    def run_a(self, mod, inputs):
+        mod.eval()
+        with torch.no_grad():
+            return mod(*inputs)
+
+    def run_b(self, mod, inputs):
+        mod.eval()
+        try:
+            mod = self.lower_fn(mod, inputs, self.max_batch_size)
+            output = mod(*inputs)
+        except RuntimeError as e:
+            raise net_min_base.FxNetMinimizerRunFuncError(
+                f"Encounter an error when processing \n{mod.graph}\n {e}"
+            )
+        else:
+            return output
+
+    def get_nodes(self, start=None, end=None, enable_print=False):
+        nodes = self._collect_nodes(start, end)
+        if enable_print:
+            print(f"Nodes fetched from start {start} to end {end} as: {nodes}")
+        return nodes
diff --git a/torch/fx/experimental/fx2trt/tools/trt_splitter.py b/torch/fx/experimental/fx2trt/tools/trt_splitter.py
new file mode 100644 (file)
index 0000000..97af41a
--- /dev/null
@@ -0,0 +1,72 @@
+from typing import Iterable, Tuple
+
+import torch
+import torch.fx.passes.splitter_base as splitter_base
+from torch.fx.experimental.fx2trt.tools.trt_minimizer import TensorRTMinimizer
+from torch.fx.experimental.fx2trt.fx2trt import (
+    InputTensorSpec,
+    TRTModule,
+    TRTInterpreter,
+    CONVERTERS,
+)
+from torch.fx.passes.operator_support import OperatorSupport
+from torch.fx.passes.tools_common import Tensors
+
+
+class TRTOperatorSupport(OperatorSupport):
+    def __init__(self):
+        self._support_dict = {}
+        for k in CONVERTERS.keys():
+            self._support_dict[k] = None
+
+
+class TRTSplitter(splitter_base._SplitterBase):
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Tuple[torch.Tensor],
+        operator_support: OperatorSupport = None,
+        settings: splitter_base._SplitterSettingBase = None,
+    ):
+        if not operator_support:
+            operator_support = TRTOperatorSupport()
+        if not settings:
+            settings = splitter_base._SplitterSettingBase()
+
+        super().__init__(module, sample_input, operator_support, settings)
+
+    def _lower_model_to_backend(
+        self,
+        mod: torch.fx.GraphModule,
+        inputs: Iterable[torch.Tensor]
+    ):
+        """
+        Lower a GraphModule `mod` to TensorRT with `inputs`.
+        """
+        # Current code for lowering is place-holder, subject to future change
+        # based on feeds model's actual status
+        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
+        engine, input_names, output_names = interp.run(*inputs)
+        return TRTModule(engine, input_names, output_names)
+
+    def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors):
+        """
+        This function serves the preview functionality in Splitter. When previewing
+        splitting result, if something wrong happens during lowering model to TensorRT
+        or running a TensorRT model, this function will be called to find any culprit
+        that is responsible for the error.
+        """
+        # Since we don't care about accuracy here, we pass in a dummy compare function.
+        minimizer = TensorRTMinimizer(mod, inputs, lambda a, b, c: (1, True))
+        minimizer.settings.traverse_method = "sequential"
+        minimizer.settings.find_all = True
+        culprits = minimizer.minimize()
+
+        if len(culprits) == 0:
+            reports = "Unable to find a culprit!\n"
+        else:
+            reports = "Found some problematic nodes:\n"
+            for node in culprits:
+                reports += f"{node.format_node()}\n"
+
+        return reports