[fx_acc] [fx2trt] add acc op mapper for argmin and converter for topk (#63823)
authorShiyan Deng <dsy842974287@fb.com>
Thu, 26 Aug 2021 20:06:46 +0000 (13:06 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 20:16:22 +0000 (13:16 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63823

Add mapper for `torch.argmin` which maps it to `acc_ops.flatten` (optional) + `acc_ops.topk` + `acc_ops.getitem` + `acc_ops.squeeze` (optional). This diff doesn't allow mapping if `dim=None && keepdim=True` in `torch.argmin`.

Add fx2trt converter for `acc_ops.topk`.

Test Plan:
buck test mode/opt glow/fb/fx/oss_acc_tracer:test_acc_tracer -- test_argmin
buck run mode/opt caffe2/torch/fb/fx2trt:test_topk

Reviewed By: jfix71

Differential Revision: D30501771

fbshipit-source-id: 0babc45e69bac5e61ff0b9b4dfb98940398e3e57

torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
torch/fx/experimental/fx2trt/fx2trt.py
torch/fx/experimental/fx_acc/acc_ops.py

index 33a817d..ba370b2 100644 (file)
@@ -717,6 +717,7 @@ def acc_ops_squeeze(network, target, args, kwargs, name):
     # dim, which is a very rare case. For now we just claim not supporting dim=None.
     assert dim is not None, "We don't support dim=None right now."
 
+    dim = dim % (len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0))
     if network.has_implicit_batch_dimension:
         assert dim != 0, "We don't support squeeze batch dim when it's implicit."
         dim -= 1
@@ -796,6 +797,29 @@ def acc_ops_unsqueeze(network, target, args, kwargs, name):
     layer.name = name
     return layer.get_output(0)
 
+@tensorrt_converter(acc_ops.topk)
+def acc_ops_topk(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"topk received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
+    if kwargs["sorted"] and kwargs["k"] != 1:
+        raise RuntimeError("Currently we don't support sorted=True in topk.")
+
+    if not network.has_implicit_batch_dimension and len(input_val.shape) <= 1:
+        raise RuntimeError("At least 2 dimensions are required for input to topk.")
+
+    num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
+    k = kwargs["k"]
+    dim = (kwargs["dim"] if kwargs["dim"] else -1) % num_dims
+    operation = trt.TopKOperation.MAX if kwargs["largest"] else trt.TopKOperation.MIN
+    layer = network.add_topk(
+        input_val, operation, k, get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension)
+    )
+    layer.name = name
+    return (layer.get_output(0), layer.get_output(1))
 
 @tensorrt_converter(acc_ops.adaptive_avg_pool2d)
 def acc_ops_adaptive_avg_pool2d(network, target, args, kwargs, name):
index ede99fd..72497a7 100644 (file)
@@ -415,8 +415,6 @@ class TRTInterpreter(torch.fx.Interpreter):
             name = f"output{i}"
             output.name = name
             self.network.mark_output(output)
-            if self.fp16_mode:
+            if self.fp16_mode and output.dtype == trt.float32:
                 output.dtype = trt.float16
-            else:
-                output.dtype = trt.float32
             self._output_names.append(name)
index 95fffaa..692ca63 100644 (file)
@@ -705,6 +705,57 @@ def batch_norm(
 def layer_norm(*, input, normalized_shape, weight, bias, eps):
     return nn.functional.layer_norm(**locals())
 
+def argmin_max_mapper_impl(node: torch.fx.Node, largest: bool) -> torch.fx.Node:
+    """
+    Map torch.argmin or torch.argmax to acc_ops.flatten (depend on dim) + acc_ops.topk
+    + acc_ops.getitem + acc_ops.squeeze (depends on keepdim).
+    """
+    input_node = node.kwargs["input"]
+    dim = node.kwargs["dim"]
+    keepdim = node.kwargs["keepdim"]
+
+    if dim is None and keepdim:
+        raise RuntimeError("We currently don't support argmin/argmax with dim=None and keepdim=True")
+
+    with node.graph.inserting_before(node):
+        if dim is None:
+            flatten_kwargs = {"input": node.kwargs["input"], "start_dim": 0, "end_dim": -1}
+            flatten_node = node.graph.call_function(flatten, kwargs=flatten_kwargs)
+            flatten_node.meta["type"] = torch.Tensor
+            input_node = flatten_node
+            dim = -1
+
+        topk_kwargs = {"input": input_node, "k": 1, "dim": dim, "largest": largest, "sorted": False}
+        topk_node = node.graph.call_function(topk, kwargs=topk_kwargs)
+        # It's actually more like NamedTuple but tuple here should be fine.
+        topk_node.meta["type"] = tuple
+
+        getitem_kwargs = {"input": topk_node, "idx": 1}
+        getitem_node = node.graph.call_function(getitem, kwargs=getitem_kwargs)
+        getitem_node.meta["type"] = torch.Tensor
+        output_node = getitem_node
+
+        if not keepdim:
+            squeeze_kwargs = {"input": getitem_node, "dim": dim}
+            output_node = node.graph.call_function(squeeze, kwargs=squeeze_kwargs)
+
+        output_node.meta = node.meta.copy()
+        return output_node
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.argmin),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim"),
+        ("keepdim", "keepdim"),
+    ],
+)
+def torch_argmin_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node:
+    """
+    Map torch.argmin to acc_ops.flatten (depend on dim) + acc_ops.topk + acc_ops.getitem
+    + acc_ops.squeeze (depends on keepdim).
+    """
+    return argmin_max_mapper_impl(node, largest=False)
 
 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "split"),