[fx2trt] Add acc_ops.sign and converter for it (#63876)
authorShiyan Deng <dsy842974287@fb.com>
Tue, 31 Aug 2021 18:29:07 +0000 (11:29 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 18:31:53 +0000 (11:31 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63876

Add `acc_ops.sign`, which maps from `torch.sign`.

Add a plugin (does not currently support dynamic shape) for `acc_ops.sign`. The plugin calls `at::sign` directly.

Test Plan: buck test mode/opt -c python.package_style=inplace -c fbcode.nvcc_arch=a100 caffe2/torch/fb/fx2trt:test_unary_ops

Reviewed By: yinghai

Differential Revision: D30518081

fbshipit-source-id: a0b9e6c30deac0b04b8cb09a162579e229985330

torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
torch/fx/experimental/fx2trt/fx2trt.py
torch/fx/experimental/fx_acc/acc_ops.py

index ba370b2..e101b6b 100644 (file)
@@ -1098,7 +1098,6 @@ def acc_ops_clamp(network, target, args, kwargs, name):
 
     return input_val
 
-
 @tensorrt_converter(acc_ops.tuple_construct)
 def acc_ops_tuple_construct(network, target, args, kwargs, name):
     return kwargs["tensors"]
index 72497a7..f1d17e7 100644 (file)
@@ -4,6 +4,7 @@ from typing import List, NamedTuple, Iterable, Any, Optional, Tuple
 import tensorrt as trt
 import torch
 import torch.fx
+from torch.fx.node import _get_qualified_name
 
 
 # Borrowed from torch2trt
@@ -226,14 +227,15 @@ class TRTInterpreter(torch.fx.Interpreter):
         else:
             self.network = self.builder.create_network()
 
+        missing_ops = self.validate_conversion()
+        if missing_ops:
+            warnings.warn("Interpretation will fail due to missing operations \n"
+                          + "\n".join(f"{i}" for i in missing_ops))
+
         self.optimization_profiles: Optional[List] = None
         self.input_specs = input_specs
         self.input_specs_iter = 0
         self.validate_input_specs()
-        missing_ops = self.validate_conversion
-        if not missing_ops:
-            warnings.warn("Interpretation may fail due to missing operations \n"
-                          + "\n".join(f"{i}" for i in missing_ops))
         self._cur_node_name: Optional[str] = None
         self._input_names: List[str] = []
         self._output_names: List[str] = []
@@ -299,13 +301,15 @@ class TRTInterpreter(torch.fx.Interpreter):
         missing_converter = set()
 
         for node in self.module.graph.nodes:
-            if node.op in ["call_function", "call_method"] and not CONVERTERS.get(node.target):
-                missing_converter.add(f"{node.op} {node.target}")
+            if node.op == "call_function" and not CONVERTERS.get(node.target):
+                missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}")
+            elif node.op == "call_method" and not CONVERTERS.get(node.target):
+                missing_converter.add(f"{node.op} torch.Tensor.{node.target}")
             elif node.op == "call_module":
                 submod = self.fetch_attr(node.target)
                 submod_type = getattr(submod, "_base_class_origin", type(submod))
                 if not CONVERTERS.get(submod_type):
-                    missing_converter.add(f"{node.op} {submod_type}")
+                    missing_converter.add(f"{node.op} {torch.typename(submod_type)}")
 
         return missing_converter
 
index 1b4b469..b10d35e 100644 (file)
@@ -95,6 +95,12 @@ def avg_pool2d(
     return nn.functional.avg_pool2d(**locals())
 
 
+@register_acc_op_mapping(op_and_target=("call_function", torch.sign))
+@register_acc_op
+def sign(*, input):
+    return torch.sign(input)
+
+
 @register_acc_op
 def size(*, input):
     return input.size()