From b45ec163106f5b801ebfe306ae659d85efc3432b Mon Sep 17 00:00:00 2001
From: Protonu Basu
Date: Mon, 20 Sep 2021 14:29:30 -0700
Subject: [PATCH] Add support to lower acc_ops.transpose (#65036)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65036

Reviewed By: jfix71, 842974287

Differential Revision: D30934503

fbshipit-source-id: 51880d3d36492f5206f77c9d1a994d8532597b62
---
 .../fx2trt/converters/acc_ops_converters.py |  17 ----
 torch/fx/experimental/fx_acc/acc_ops.py     | 104 +++++++++++++++------
 2 files changed, 75 insertions(+), 46 deletions(-)

diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index 47fe26d..c83cb37 100644
--- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
+++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -1403,23 +1403,6 @@ def acc_ops_cat(network, target, args, kwargs, name):
     return layer.get_output(0)
 
 
-@tensorrt_converter(acc_ops.transpose)
-def acc_ops_transpose(network, target, args, kwargs, name):
-    input_val, dim_0, dim_1 = kwargs["input"], kwargs["dim0"], kwargs["dim1"]
-
-    # TODO: Remove this after enabling const folding in fx_acc
-    if isinstance(input_val, torch.Tensor):
-        return input_val.transpose(dim_0, dim_1).contiguous()
-
-    if not isinstance(input_val, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"transpose received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    return add_transpose_layer(network, input_val, dim_0, dim_1, name)
-
-
 @tensorrt_converter(acc_ops.matmul)
 def acc_ops_matmul(network, target, args, kwargs, name):
     input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py
index c5a3b09..7590763 100644
--- a/torch/fx/experimental/fx_acc/acc_ops.py
+++ b/torch/fx/experimental/fx_acc/acc_ops.py
@@ -2,6 +2,8 @@
 import operator
 
 import torch  # isort:skip
+from typing import Sequence, Optional, List, cast
+
 import torch.fx.experimental.fx_acc.acc_utils as acc_utils
 import torch.nn as nn
 from torch.fx.experimental.fx_acc.acc_normalizer import (
@@ -11,8 +13,6 @@ from torch.fx.experimental.fx_acc.acc_normalizer import (
 )
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 
-from typing import Sequence, Optional, List
-
 this_arg_is_optional = True
 
 
@@ -227,13 +227,47 @@ def cat(*, tensors, dim):
     return torch.cat(**locals())
 
 
-@register_acc_op_mapping(op_and_target=("call_function", torch.transpose))
-@register_acc_op_mapping(op_and_target=("call_method", "transpose"))
-@register_acc_op
-def transpose(*, input, dim0, dim1):
-    if input.dim() < 2:
-        return input
-    return torch.transpose(**locals())
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.transpose),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim0", "dim0"),
+        ("dim1", "dim1"),
+    ],
+)
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_method", "transpose"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim0", "dim0"),
+        ("dim1", "dim1"),
+    ],
+)
+def transpose_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
+    # Get the dim-permutation/shuffle
+    shape_as_list = node.meta["tensor_meta"].shape
+    ranks = len(shape_as_list)
+    shuffle = list(i for i in range(ranks))
+    dim0 = cast(int, node.kwargs["dim0"])
+    dim1 = cast(int, node.kwargs["dim1"])
+    shuffle[dim0] = dim1
+    shuffle[dim1] = dim0
+
+    # Create the new acc_ops.permute node. Update all uses of the transpose
+    # node and then delete the transpose node.
+    with node.graph.inserting_after(node):
+        permute_node = node.graph.call_function(
+            the_function=permute,
+            kwargs={
+                "input": node.kwargs.get("input"),
+                "permutation": shuffle,
+            },
+        )
+        permute_node.meta = node.meta.copy()
+        node.replace_all_uses_with(permute_node)
+
+    permute_node.graph.erase_node(node)
+    return permute_node
 
 
 @register_acc_op_mapping(op_and_target=("call_method", "contiguous"))
@@ -311,12 +345,14 @@ def addmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
     ],
 )
 def t_mapper(node: torch.fx.Node, _: nn.Module):
+    ranks = len(node.meta["tensor_meta"].shape)
+    shuffle = [1, 0] if (ranks > 1) else [0]
+
     with node.graph.inserting_before(node):
         new_node = node.graph.create_node(
             "call_function",
-            transpose,
-            kwargs={"input": node.kwargs["input"], "dim0": 0, "dim1": 1},
-            name=node.name,
+            permute,
+            kwargs={"input": node.kwargs["input"], "permutation": shuffle},
         )
         new_node.meta = node.meta.copy()
         return new_node
@@ -371,7 +407,6 @@ def dropout_mapper(node: torch.fx.Node, mod: nn.Module):
     """
     return node.kwargs["input"]
 
-
 @register_acc_op_mapping(
     op_and_target=("call_function", torch.ops.quantized.add),
     arg_replacement_tuples=[
@@ -546,8 +581,6 @@ def sum(*, input, dim=None, keepdim=False, dtype=None):
     return input.sum(dtype=dtype)
 
 
-
-
 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "max"),
     arg_replacement_tuples=[
@@ -580,7 +613,9 @@
         ("keepdim", "keepdim", this_arg_is_optional),
     ],
 )
-def add_maximum_minimum_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.Node:
+def add_maximum_minimum_mapper(
+    node: torch.fx.Node, mod: torch.fx.GraphModule
+) -> torch.fx.Node:
     # there are effectively three versions of torch.max / torch.min
     # full reduce: torch.max(input) -> Tensor
     # dimensional reduce: torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)
@@ -591,33 +626,37 @@ def add_maximum_minimum_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -
     # to lookup the right function
     def target_map(op, target):
         if (op, target) in (("call_method", "max"), ("call_function", torch.max)):
-            return dict(full_reduce=max_full_reduce,
-                        dim_reduce=max_dim_reduce,
-                        elementwise=maximum)
+            return dict(
+                full_reduce=max_full_reduce,
+                dim_reduce=max_dim_reduce,
+                elementwise=maximum,
+            )
         elif (op, target) in (("call_method", "min"), ("call_function", torch.min)):
-            return dict(full_reduce=min_full_reduce,
-                        dim_reduce=min_dim_reduce,
-                        elementwise=minimum)
+            return dict(
+                full_reduce=min_full_reduce,
+                dim_reduce=min_dim_reduce,
+                elementwise=minimum,
+            )
 
     with node.graph.inserting_before(node):
         new_targets = target_map(node.op, node.target)
         max_kwargs = dict()
-        max_kwargs['input'] = node.kwargs['input']
-        if ("dim_or_other" not in node.kwargs) or (node.kwargs['dim_or_other'] is None):
+        max_kwargs["input"] = node.kwargs["input"]
+        if ("dim_or_other" not in node.kwargs) or (node.kwargs["dim_or_other"] is None):
             nt = new_targets["full_reduce"]
             max_node = node.graph.call_function(nt, kwargs=max_kwargs)
-        elif isinstance(node.kwargs['dim_or_other'], int):
+        elif isinstance(node.kwargs["dim_or_other"], int):
             nt = new_targets["dim_reduce"]
             dim = node.kwargs["dim_or_other"]
-            max_kwargs['dim'] = dim
-            max_kwargs['keepdim'] = node.kwargs.get('keepdim', False)
+            max_kwargs["dim"] = dim
+            max_kwargs["keepdim"] = node.kwargs.get("keepdim", False)
             max_node = node.graph.call_function(nt, kwargs=max_kwargs)
         else:
             other = node.kwargs["dim_or_other"]
             assert isinstance(other, torch.fx.Node)
             # Lowering path for when provided "other", where we do elem-wise max
             nt = new_targets["elementwise"]
-            max_kwargs['other'] = other
+            max_kwargs["other"] = other
             max_node = node.graph.call_function(nt, kwargs=max_kwargs)
         max_node.meta = node.meta.copy()
         return max_node
@@ -627,24 +666,29 @@ def add_maximum_minimum_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -
 def max_full_reduce(*, input):
     return torch.max(**locals())
 
+
 @register_acc_op
 def max_dim_reduce(*, input, dim=None, keepdim=False):
     return torch.max(**locals())
 
+
 @register_acc_op_mapping(op_and_target=("call_function", torch.maximum))
 @register_acc_op_mapping(op_and_target=("call_method", "maximum"))
 @register_acc_op
 def maximum(*, input, other):
     return torch.maximum(**locals())
 
+
 @register_acc_op
 def min_full_reduce(*, input):
     return torch.min(input)
 
+
 @register_acc_op
 def min_dim_reduce(*, input, dim=None, keepdim=False):
     return torch.min(input, dim=dim, keepdim=keepdim)
 
+
 @register_acc_op_mapping(op_and_target=("call_function", torch.minimum))
 @register_acc_op_mapping(op_and_target=("call_method", "minimum"))
 @register_acc_op
@@ -1092,7 +1136,9 @@ def slice_tensor(*, input, dims, starts, stops, steps):
     ],
 )
 def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
-    assert isinstance(node.kwargs["start"], int) and isinstance(node.kwargs["length"], int)
+    assert isinstance(node.kwargs["start"], int) and isinstance(
+        node.kwargs["length"], int
+    )
     kwargs = {
         "input": node.kwargs["input"],
         "dims": (node.kwargs["dim"],),
-- 
2.7.4
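Note: the transpose_mapper added above rewrites acc_ops.transpose into acc_ops.permute by building a permutation that starts as the identity and swaps dim0 with dim1, and t_mapper maps torch.t() to the fixed permutation [1, 0] (or [0] for 1-D inputs). Below is a minimal standalone sketch of that index arithmetic, outside the FX graph machinery; the helper name transpose_to_permutation is made up for illustration and is not part of the patch.

    import torch

    def transpose_to_permutation(rank: int, dim0: int, dim1: int) -> list:
        # Identity permutation [0, 1, ..., rank-1] with dim0 and dim1 swapped,
        # mirroring the shuffle computed in transpose_mapper.
        shuffle = list(range(rank))
        shuffle[dim0], shuffle[dim1] = shuffle[dim1], shuffle[dim0]
        return shuffle

    x = torch.randn(2, 3, 4)
    perm = transpose_to_permutation(x.dim(), 0, 2)  # [2, 1, 0]
    assert torch.equal(x.transpose(0, 2), x.permute(*perm))

    # torch.t() is the 2-D special case: permutation [1, 0].
    y = torch.randn(3, 5)
    assert torch.equal(y.t(), y.permute(1, 0))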