From 7cfbc85821e8928db570a0730437b96484ac7b60 Mon Sep 17 00:00:00 2001
From: Shiyan Deng
Date: Thu, 26 Aug 2021 13:06:46 -0700
Subject: [PATCH] [fx_acc] [fx2trt] add acc op mapper for argmin and converter for topk (#63823)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63823

Add mapper for `torch.argmin` which maps it to `acc_ops.flatten` (optional) +
`acc_ops.topk` + `acc_ops.getitem` + `acc_ops.squeeze` (optional). This diff
doesn't allow mapping if `dim=None && keepdim=True` in `torch.argmin`.

Add fx2trt converter for `acc_ops.topk`.

Test Plan:
buck test mode/opt glow/fb/fx/oss_acc_tracer:test_acc_tracer -- test_argmin
buck run mode/opt caffe2/torch/fb/fx2trt:test_topk

Reviewed By: jfix71

Differential Revision: D30501771

fbshipit-source-id: 0babc45e69bac5e61ff0b9b4dfb98940398e3e57
---
 .../fx2trt/converters/acc_ops_converters.py | 24 ++++++++++
 torch/fx/experimental/fx2trt/fx2trt.py      |  4 +-
 torch/fx/experimental/fx_acc/acc_ops.py     | 51 ++++++++++++++++++++++
 3 files changed, 76 insertions(+), 3 deletions(-)

diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index 33a817d..ba370b2 100644
--- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
+++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -717,6 +717,7 @@ def acc_ops_squeeze(network, target, args, kwargs, name):
     # dim, which is a very rare case. For now we just claim not supporting dim=None.
     assert dim is not None, "We don't support dim=None right now."
 
+    dim = dim % (len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0))
     if network.has_implicit_batch_dimension:
         assert dim != 0, "We don't support squeeze batch dim when it's implicit."
         dim -= 1
@@ -796,6 +797,29 @@ def acc_ops_unsqueeze(network, target, args, kwargs, name):
     layer.name = name
     return layer.get_output(0)
 
+@tensorrt_converter(acc_ops.topk)
+def acc_ops_topk(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"topk received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
+    if kwargs["sorted"] and kwargs["k"] != 1:
+        raise RuntimeError("Currently we don't support sorted=True in topk.")
+
+    if not network.has_implicit_batch_dimension and len(input_val.shape) <= 1:
+        raise RuntimeError("At least 2 dimensions are required for input to topk.")
+
+    num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
+    k = kwargs["k"]
+    dim = (kwargs["dim"] if kwargs["dim"] else -1) % num_dims
+    operation = trt.TopKOperation.MAX if kwargs["largest"] else trt.TopKOperation.MIN
+    layer = network.add_topk(
+        input_val, operation, k, get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension)
+    )
+    layer.name = name
+    return (layer.get_output(0), layer.get_output(1))
 
 @tensorrt_converter(acc_ops.adaptive_avg_pool2d)
 def acc_ops_adaptive_avg_pool2d(network, target, args, kwargs, name):
diff --git a/torch/fx/experimental/fx2trt/fx2trt.py b/torch/fx/experimental/fx2trt/fx2trt.py
index ede99fd..72497a7 100644
--- a/torch/fx/experimental/fx2trt/fx2trt.py
+++ b/torch/fx/experimental/fx2trt/fx2trt.py
@@ -415,8 +415,6 @@ class TRTInterpreter(torch.fx.Interpreter):
             name = f"output{i}"
             output.name = name
             self.network.mark_output(output)
-            if self.fp16_mode:
+            if self.fp16_mode and output.dtype == trt.float32:
                 output.dtype = trt.float16
-            else:
-                output.dtype = trt.float32
             self._output_names.append(name)
diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py
index 95fffaa..692ca63 100644
--- a/torch/fx/experimental/fx_acc/acc_ops.py
+++ b/torch/fx/experimental/fx_acc/acc_ops.py
@@ -705,6 +705,57 @@ def batch_norm(
 def layer_norm(*, input, normalized_shape, weight, bias, eps):
     return nn.functional.layer_norm(**locals())
 
+def argmin_max_mapper_impl(node: torch.fx.Node, largest: bool) -> torch.fx.Node:
+    """
+    Map torch.argmin or torch.argmax to acc_ops.flatten (depend on dim) + acc_ops.topk
+    + acc_ops.getitem + acc_ops.squeeze (depends on keepdim).
+    """
+    input_node = node.kwargs["input"]
+    dim = node.kwargs["dim"]
+    keepdim = node.kwargs["keepdim"]
+
+    if dim is None and keepdim:
+        raise RuntimeError("We currently don't support argmin/argmax with dim=None and keepdim=True")
+
+    with node.graph.inserting_before(node):
+        if dim is None:
+            flatten_kwargs = {"input": node.kwargs["input"], "start_dim": 0, "end_dim": -1}
+            flatten_node = node.graph.call_function(flatten, kwargs=flatten_kwargs)
+            flatten_node.meta["type"] = torch.Tensor
+            input_node = flatten_node
+            dim = -1
+
+        topk_kwargs = {"input": input_node, "k": 1, "dim": dim, "largest": largest, "sorted": False}
+        topk_node = node.graph.call_function(topk, kwargs=topk_kwargs)
+        # It's actually more like NamedTuple but tuple here should be fine.
+        topk_node.meta["type"] = tuple
+
+        getitem_kwargs = {"input": topk_node, "idx": 1}
+        getitem_node = node.graph.call_function(getitem, kwargs=getitem_kwargs)
+        getitem_node.meta["type"] = torch.Tensor
+        output_node = getitem_node
+
+        if not keepdim:
+            squeeze_kwargs = {"input": getitem_node, "dim": dim}
+            output_node = node.graph.call_function(squeeze, kwargs=squeeze_kwargs)
+
+        output_node.meta = node.meta.copy()
+        return output_node
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.argmin),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim"),
+        ("keepdim", "keepdim"),
+    ],
+)
+def torch_argmin_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node:
+    """
+    Map torch.argmin to acc_ops.flatten (depend on dim) + acc_ops.topk + acc_ops.getitem
+    + acc_ops.squeeze (depends on keepdim).
+    """
+    return argmin_max_mapper_impl(node, largest=False)
 
 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "split"),
-- 
2.7.4
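
The decomposition described in the Summary can be sanity-checked outside of acc_tracer. The sketch below mirrors the flatten + topk + getitem + squeeze chain in eager PyTorch; the helper name argmin_via_topk is illustrative only and is not part of this patch, and ties between equal minima may be resolved differently than by torch.argmin.

    import torch

    def argmin_via_topk(x, dim=None, keepdim=False):
        # Mirrors argmin_max_mapper_impl with largest=False.
        if dim is None:
            # The mapper rejects dim=None together with keepdim=True.
            assert not keepdim, "dim=None with keepdim=True is not supported"
            x = torch.flatten(x, start_dim=0, end_dim=-1)  # acc_ops.flatten
            dim = -1
        # acc_ops.topk with k=1, largest=False, sorted=False
        values, indices = torch.topk(x, k=1, dim=dim, largest=False, sorted=False)
        out = indices                                      # acc_ops.getitem(..., idx=1)
        if not keepdim:
            out = out.squeeze(dim)                         # acc_ops.squeeze
        return out

    x = torch.randn(3, 4)
    assert torch.equal(argmin_via_topk(x, dim=1), torch.argmin(x, dim=1))
    assert torch.equal(argmin_via_topk(x), torch.argmin(x))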
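
On the converter side, TensorRT's TopK layer takes a reduction-axes bitmask rather than a single dim, so acc_ops_topk first normalizes the requested dim against the full rank (counting the implicit batch dimension when present) before handing it to get_axes_for_reduce_op. A small sketch of that arithmetic, assuming the usual convention that bit i of the mask selects dimension i of the layer input and that the implicit batch dimension is excluded from the mask; normalize_topk_dim_to_axes is a hypothetical stand-in, not the converter's code.

    def normalize_topk_dim_to_axes(dim, input_rank, has_implicit_batch_dimension):
        # input_rank is len(input_val.shape), i.e. without the implicit batch dim.
        num_dims = input_rank + (1 if has_implicit_batch_dimension else 0)
        # Mirrors the converter's `(kwargs["dim"] if kwargs["dim"] else -1) % num_dims`.
        dim = (dim if dim else -1) % num_dims
        if has_implicit_batch_dimension:
            dim -= 1  # assumption: the axes bitmask skips the implicit batch dimension
        return 1 << dim

    assert normalize_topk_dim_to_axes(-1, 4, False) == 0b1000  # last axis of an explicit-batch NCHW tensor
    assert normalize_topk_dim_to_axes(2, 3, True) == 0b0010    # dim=2 of NCHW maps to axis 1 once batch is implicit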