From: Shiyan Deng Date: Tue, 31 Aug 2021 18:29:07 +0000 (-0700) Subject: [fx2trt] Add acc_ops.sign and converter for it (#63876) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~553 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=845bc89811f59822fe585cf44e774857adefcff7;p=platform%2Fupstream%2Fpytorch.git [fx2trt] Add acc_ops.sign and converter for it (#63876) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63876 Add `acc_ops.sign` which maps from `torch.sign`. Add a plugin (not support dynamic shape currently) for `acc_ops.sign`. The plugin calls `at::sign` directly. Test Plan: buck test mode/opt -c python.package_style=inplace -c fbcode.nvcc_arch=a100 caffe2/torch/fb/fx2trt:test_unary_ops Reviewed By: yinghai Differential Revision: D30518081 fbshipit-source-id: a0b9e6c30deac0b04b8cb09a162579e229985330 --- diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py index ba370b2..e101b6b 100644 --- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py +++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py @@ -1098,7 +1098,6 @@ def acc_ops_clamp(network, target, args, kwargs, name): return input_val - @tensorrt_converter(acc_ops.tuple_construct) def acc_ops_tuple_construct(network, target, args, kwargs, name): return kwargs["tensors"] diff --git a/torch/fx/experimental/fx2trt/fx2trt.py b/torch/fx/experimental/fx2trt/fx2trt.py index 72497a7..f1d17e7 100644 --- a/torch/fx/experimental/fx2trt/fx2trt.py +++ b/torch/fx/experimental/fx2trt/fx2trt.py @@ -4,6 +4,7 @@ from typing import List, NamedTuple, Iterable, Any, Optional, Tuple import tensorrt as trt import torch import torch.fx +from torch.fx.node import _get_qualified_name # Borrowed from torch2trt @@ -226,14 +227,15 @@ class TRTInterpreter(torch.fx.Interpreter): else: self.network = self.builder.create_network() + missing_ops = self.validate_conversion() + if missing_ops: + warnings.warn("Interpretation will fail due to missing operations \n" + + "\n".join(f"{i}" for i in missing_ops)) + self.optimization_profiles: Optional[List] = None self.input_specs = input_specs self.input_specs_iter = 0 self.validate_input_specs() - missing_ops = self.validate_conversion - if not missing_ops: - warnings.warn("Interpretation may fail due to missing operations \n" - + "\n".join(f"{i}" for i in missing_ops)) self._cur_node_name: Optional[str] = None self._input_names: List[str] = [] self._output_names: List[str] = [] @@ -299,13 +301,15 @@ class TRTInterpreter(torch.fx.Interpreter): missing_converter = set() for node in self.module.graph.nodes: - if node.op in ["call_function", "call_method"] and not CONVERTERS.get(node.target): - missing_converter.add(f"{node.op} {node.target}") + if node.op == "call_function" and not CONVERTERS.get(node.target): + missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}") + elif node.op == "call_method" and not CONVERTERS.get(node.target): + missing_converter.add(f"{node.op} torch.Tensor.{node.target}") elif node.op == "call_module": submod = self.fetch_attr(node.target) submod_type = getattr(submod, "_base_class_origin", type(submod)) if not CONVERTERS.get(submod_type): - missing_converter.add(f"{node.op} {submod_type}") + missing_converter.add(f"{node.op} {torch.typename(submod_type)}") return missing_converter diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py index 1b4b469..b10d35e 100644 --- a/torch/fx/experimental/fx_acc/acc_ops.py +++ b/torch/fx/experimental/fx_acc/acc_ops.py @@ -95,6 +95,12 @@ def avg_pool2d( return nn.functional.avg_pool2d(**locals()) +@register_acc_op_mapping(op_and_target=("call_function", torch.sign)) +@register_acc_op +def sign(*, input): + return torch.sign(input) + + @register_acc_op def size(*, input): return input.size()