--- /dev/null
+from typing import Tuple, Callable, Any
+
+import torch
+import torch.fx.passes.net_min_base as net_min_base
+from torch.fx.experimental.fx2trt.fx2trt import (
+ TRTModule,
+ TRTInterpreter,
+ InputTensorSpec,
+)
+from torch.fx.passes.tools_common import Tensors
+
+
def lower_mod_default(
    mod: torch.fx.GraphModule, inputs: Tensors, batch_size: int = 2048
) -> TRTModule:
    """Lower `mod` to a TensorRT engine and wrap it as a `TRTModule`.

    Args:
        mod: The fx GraphModule to lower.
        inputs: Sample inputs used to infer input specs for the interpreter.
        batch_size: Forwarded to the interpreter as `max_batch_size`.
            (Was annotated `Any`; it is used as an int everywhere, default 2048.)

    Returns:
        A `TRTModule` built from the interpreter's (engine, input_names,
        output_names) results.
    """
    interp = TRTInterpreter(
        mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
    )
    return TRTModule(*interp.run(max_batch_size=batch_size))
+
+
class TensorRTMinizerSetting(net_min_base._MinimizerSettingBase):
    """Minimizer settings extended with TensorRT's explicit-batch-dimension flag.

    NOTE(review): the class name is missing an "i" ("Minizer" vs "Minimizer"),
    but it is part of the public interface (used as a default argument by
    callers), so it is left unchanged here.
    """

    def __init__(self, explicit_batch_dimension: bool = True):
        # Whether TensorRT networks should be built with an explicit batch
        # dimension (see `TRTInterpreter(..., explicit_batch_dimension=...)`).
        self.explicit_batch_dimension = explicit_batch_dimension
        # Python 3 zero-argument form; equivalent to the old
        # `super(TensorRTMinizerSetting, self).__init__()`.
        super().__init__()
+
+
class TensorRTMinimizer(net_min_base._MinimizerBase):
    """Minimizer that finds culprit nodes by comparing eager vs TensorRT runs.

    `run_a` executes a submodule eagerly (reference path); `run_b` lowers the
    submodule with `lower_fn` and executes the TensorRT result. The base class
    uses `compare_fn` on the two outputs to narrow down problematic nodes.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Tensors,
        compare_fn: Callable[[Any, Any, Any], Tuple[float, bool]],
        settings: TensorRTMinizerSetting = None,
        max_batch_size: int = 2048,
        lower_fn: Callable[[torch.fx.GraphModule, Tensors, Any], TRTModule] = lower_mod_default,
    ):
        # BUG FIX: the previous default `settings=TensorRTMinizerSetting()` was
        # evaluated once at function-definition time, so every minimizer created
        # without an explicit `settings` shared (and could mutate) the same
        # object. Use a None sentinel and build a fresh instance per call.
        if settings is None:
            settings = TensorRTMinizerSetting()
        self.lower_fn = lower_fn
        self.max_batch_size = max_batch_size
        super().__init__(module, sample_input, compare_fn, settings)

    def run_a(self, mod, inputs):
        """Run `mod` eagerly under `torch.no_grad()` and return its output."""
        mod.eval()
        with torch.no_grad():
            return mod(*inputs)

    def run_b(self, mod, inputs):
        """Lower `mod` to TensorRT via `self.lower_fn`, then run it.

        Raises:
            net_min_base.FxNetMinimizerRunFuncError: if lowering or execution
                fails with a RuntimeError; the original error is chained.
        """
        mod.eval()
        try:
            mod = self.lower_fn(mod, inputs, self.max_batch_size)
            output = mod(*inputs)
        except RuntimeError as e:
            # Chain the underlying TensorRT error so its traceback is not lost.
            raise net_min_base.FxNetMinimizerRunFuncError(
                f"Encounter an error when processing \n{mod.graph}\n {e}"
            ) from e
        else:
            return output

    def get_nodes(self, start=None, end=None, enable_print=False):
        """Return nodes between `start` and `end`, optionally printing them."""
        nodes = self._collect_nodes(start, end)
        if enable_print:
            print(f"Nodes fetched from start {start} to end {end} as: {nodes}")
        return nodes
--- /dev/null
+from typing import Iterable, Tuple
+
+import torch
+import torch.fx.passes.splitter_base as splitter_base
+from torch.fx.experimental.fx2trt.tools.trt_minimizer import TensorRTMinimizer
+from torch.fx.experimental.fx2trt.fx2trt import (
+ InputTensorSpec,
+ TRTModule,
+ TRTInterpreter,
+ CONVERTERS,
+)
+from torch.fx.passes.operator_support import OperatorSupport
+from torch.fx.passes.tools_common import Tensors
+
+
class TRTOperatorSupport(OperatorSupport):
    """Marks every op with a registered TensorRT converter as supported.

    `_support_dict` maps each key of `CONVERTERS` to None (presumably meaning
    "no further dtype restriction" per OperatorSupport's convention — TODO
    confirm against the base class).

    NOTE(review): `super().__init__()` is not called, matching the original
    code; confirm the `OperatorSupport` base does not require it.
    """

    def __init__(self):
        # dict.fromkeys gives every converter key the value None — identical
        # to the original manual loop, in one idiomatic call.
        self._support_dict = dict.fromkeys(CONVERTERS)
+
+
class TRTSplitter(splitter_base._SplitterBase):
    """Splits a GraphModule into TensorRT-supported and unsupported subgraphs."""

    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Tuple[torch.Tensor],
        operator_support: OperatorSupport = None,
        settings: splitter_base._SplitterSettingBase = None,
    ):
        # Explicit None checks: the previous `if not x` would also have
        # replaced an explicitly-passed object that happened to be falsy.
        if operator_support is None:
            operator_support = TRTOperatorSupport()
        if settings is None:
            settings = splitter_base._SplitterSettingBase()

        super().__init__(module, sample_input, operator_support, settings)

    def _lower_model_to_backend(
        self,
        mod: torch.fx.GraphModule,
        inputs: Iterable[torch.Tensor]
    ):
        """
        Lower a GraphModule `mod` to TensorRT with `inputs`.
        """
        # Current code for lowering is place-holder, subject to future change
        # based on feeds model's actual status
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        # NOTE(review): passing `*inputs` positionally to `interp.run` is
        # inconsistent with trt_minimizer's `interp.run(max_batch_size=...)`
        # call convention — confirm TRTInterpreter.run's signature.
        engine, input_names, output_names = interp.run(*inputs)
        return TRTModule(engine, input_names, output_names)

    def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors):
        """
        This function serves the preview functionality in Splitter. When previewing
        splitting result, if something wrong happens during lowering model to TensorRT
        or running a TensorRT model, this function will be called to find any culprit
        that is responsible for the error.

        Returns a human-readable report string (never raises on "no culprit").
        """
        # Since we don't care about accuracy here, we pass in a dummy compare function.
        minimizer = TensorRTMinimizer(mod, inputs, lambda a, b, c: (1, True))
        minimizer.settings.traverse_method = "sequential"
        minimizer.settings.find_all = True
        culprits = minimizer.minimize()

        if len(culprits) == 0:
            reports = "Unable to find a culprit!\n"
        else:
            reports = "Found some problematic nodes:\n"
            for node in culprits:
                reports += f"{node.format_node()}\n"

        return reports