From 95d98dfeec09a2a8685519cb76c1caa2cff7de49 Mon Sep 17 00:00:00 2001
From: Shirong Wu
Date: Thu, 9 Sep 2021 21:02:15 -0700
Subject: [PATCH] Add TRTSplitter (#64762)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64762

Extract and format TRTSplitter from the fx2trt_example code. The current
implementation is tentative and subject to change based on the feeds model
lowering progress.

Test Plan: Manually printed the supported-operator dict, which maps each
registered converter key (e.g. `'dequantize'`) to `None`.
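For reference (not part of the original test plan), the sketch below shows how the new splitter is expected to be driven. The toy module, the input shape, and the `splitter()` call, an entry point inherited from `splitter_base._SplitterBase`, are illustrative assumptions rather than part of this diff:

```python
import torch
import torch.fx
from torch.fx.experimental.fx2trt.tools.trt_splitter import TRTSplitter

class ToyModel(torch.nn.Module):
    # Hypothetical module, used only to exercise the splitter.
    def forward(self, x):
        return torch.relu(x) + x

mod = torch.fx.symbolic_trace(ToyModel())
sample_input = (torch.randn(2048, 10),)

# The default TRTOperatorSupport marks a node as TensorRT-supported whenever
# its target has an entry in CONVERTERS; all other nodes stay on the
# original backend.
splitter = TRTSplitter(mod, sample_input)
split_mod = splitter()  # assumed _SplitterBase entry point; returns the split GraphModule
```

The supported-operator dict printed in the Test Plan is `TRTOperatorSupport()._support_dict`, built by mapping each `CONVERTERS` key to `None`.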
Reviewed By: 842974287

Differential Revision: D30798047

fbshipit-source-id: 69076a550874425b7186fbbf2ecf03da4a99b42f
---
 torch/fx/experimental/fx2trt/tools/__init__.py     |  1 +
 .../fx/experimental/fx2trt/tools/trt_minimizer.py  | 64 +++++++++++++++++++
 torch/fx/experimental/fx2trt/tools/trt_splitter.py | 72 ++++++++++++++++++++++
 3 files changed, 137 insertions(+)
 create mode 100644 torch/fx/experimental/fx2trt/tools/__init__.py
 create mode 100644 torch/fx/experimental/fx2trt/tools/trt_minimizer.py
 create mode 100644 torch/fx/experimental/fx2trt/tools/trt_splitter.py

diff --git a/torch/fx/experimental/fx2trt/tools/__init__.py b/torch/fx/experimental/fx2trt/tools/__init__.py
new file mode 100644
index 0000000..118b2ee
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/tools/__init__.py
@@ -0,0 +1 @@
+from .trt_minimizer import *  # noqa: F403
diff --git a/torch/fx/experimental/fx2trt/tools/trt_minimizer.py b/torch/fx/experimental/fx2trt/tools/trt_minimizer.py
new file mode 100644
index 0000000..0ae886c
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/tools/trt_minimizer.py
@@ -0,0 +1,64 @@
+from typing import Tuple, Callable, Any
+
+import torch
+import torch.fx.passes.net_min_base as net_min_base
+from torch.fx.experimental.fx2trt.fx2trt import (
+    TRTModule,
+    TRTInterpreter,
+    InputTensorSpec,
+)
+from torch.fx.passes.tools_common import Tensors
+
+
+def lower_mod_default(
+    mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048
+) -> TRTModule:
+    interp = TRTInterpreter(
+        mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
+    )
+    res_mod = TRTModule(*interp.run(max_batch_size=batch_size))
+    return res_mod
+
+
+class TensorRTMinizerSetting(net_min_base._MinimizerSettingBase):
+    def __init__(self, explicit_batch_dimension: Any = True):
+        self.explicit_batch_dimension = explicit_batch_dimension
+        super(TensorRTMinizerSetting, self).__init__()
+
+
+class TensorRTMinimizer(net_min_base._MinimizerBase):
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Tensors,
+        compare_fn: Callable[[Any, Any, Any], Tuple[float, bool]],
+        settings: TensorRTMinizerSetting = TensorRTMinizerSetting(),
+        max_batch_size: Any = 2048,
+        lower_fn: Callable[[torch.fx.GraphModule, Tensors, Any], TRTModule] = lower_mod_default,
+    ):
+        self.lower_fn = lower_fn
+        self.max_batch_size = max_batch_size
+        super().__init__(module, sample_input, compare_fn, settings)
+
+    def run_a(self, mod, inputs):
+        mod.eval()
+        with torch.no_grad():
+            return mod(*inputs)
+
+    def run_b(self, mod, inputs):
+        mod.eval()
+        try:
+            mod = self.lower_fn(mod, inputs, self.max_batch_size)
+            output = mod(*inputs)
+        except RuntimeError as e:
+            raise net_min_base.FxNetMinimizerRunFuncError(
+                f"Encountered an error when processing\n{mod.graph}\n{e}"
+            )
+        else:
+            return output
+
+    def get_nodes(self, start=None, end=None, enable_print=False):
+        nodes = self._collect_nodes(start, end)
+        if enable_print:
+            print(f"Nodes fetched from start {start} to end {end} as: {nodes}")
+        return nodes
diff --git a/torch/fx/experimental/fx2trt/tools/trt_splitter.py b/torch/fx/experimental/fx2trt/tools/trt_splitter.py
new file mode 100644
index 0000000..97af41a
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/tools/trt_splitter.py
@@ -0,0 +1,72 @@
+from typing import Iterable, Tuple
+
+import torch
+import torch.fx.passes.splitter_base as splitter_base
+from torch.fx.experimental.fx2trt.tools.trt_minimizer import TensorRTMinimizer
+from torch.fx.experimental.fx2trt.fx2trt import (
+    InputTensorSpec,
+    TRTModule,
+    TRTInterpreter,
+    CONVERTERS,
+)
+from torch.fx.passes.operator_support import OperatorSupport
+from torch.fx.passes.tools_common import Tensors
+
+
+class TRTOperatorSupport(OperatorSupport):
+    def __init__(self):
+        self._support_dict = {}
+        for k in CONVERTERS.keys():
+            self._support_dict[k] = None
+
+
+class TRTSplitter(splitter_base._SplitterBase):
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Tuple[torch.Tensor],
+        operator_support: OperatorSupport = None,
+        settings: splitter_base._SplitterSettingBase = None,
+    ):
+        if not operator_support:
+            operator_support = TRTOperatorSupport()
+        if not settings:
+            settings = splitter_base._SplitterSettingBase()
+
+        super().__init__(module, sample_input, operator_support, settings)
+
+    def _lower_model_to_backend(
+        self,
+        mod: torch.fx.GraphModule,
+        inputs: Iterable[torch.Tensor]
+    ):
+        """
+        Lower a GraphModule `mod` to TensorRT with `inputs`.
+        """
+        # The current lowering code is a placeholder, subject to future
+        # change based on the feeds model's actual status.
+        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
+        engine, input_names, output_names = interp.run(*inputs)
+        return TRTModule(engine, input_names, output_names)
+
+    def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors):
+        """
+        This function serves the preview functionality in Splitter. When
+        previewing the splitting result, if something goes wrong while
+        lowering the model to TensorRT or running the TensorRT model, this
+        function is called to find the culprit nodes responsible for the
+        error.
+        """
+        # Since we don't care about accuracy here, we pass in a dummy compare function.
+        minimizer = TensorRTMinimizer(mod, inputs, lambda a, b, c: (1, True))
+        minimizer.settings.traverse_method = "sequential"
+        minimizer.settings.find_all = True
+        culprits = minimizer.minimize()
+
+        if len(culprits) == 0:
+            reports = "Unable to find a culprit!\n"
+        else:
+            reports = "Found some problematic nodes:\n"
+            for node in culprits:
+                reports += f"{node.format_node()}\n"
+
+        return reports
--
2.7.4
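A closing note on the minimizer: `_find_culprit` above passes a dummy `compare_fn` because it only needs to catch crashes, but `TensorRTMinimizer` can also localize accuracy gaps between eager execution (`run_a`) and TensorRT (`run_b`). A minimal sketch, assuming a hypothetical toy module and an arbitrary error tolerance:

```python
import torch
import torch.fx
from torch.fx.experimental.fx2trt.tools.trt_minimizer import TensorRTMinimizer

class ToyModel(torch.nn.Module):
    # Hypothetical module, for illustration only.
    def forward(self, x):
        return torch.relu(x) + x

def compare_outputs(a, b, name):
    # compare_fn contract per the annotation in trt_minimizer.py:
    # (result_a, result_b, name) -> (numeric_distance, passed).
    max_err = (a - b).abs().max().item()
    return max_err, max_err < 1e-2  # the tolerance is an arbitrary assumption

mod = torch.fx.symbolic_trace(ToyModel())
sample_input = (torch.randn(2048, 10),)

minimizer = TensorRTMinimizer(mod, sample_input, compare_outputs)
minimizer.settings.traverse_method = "sequential"  # same knob _find_culprit sets
culprits = minimizer.minimize()  # nodes whose TensorRT output diverges from eager
for node in culprits:
    print(node.format_node())
```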