From fd09e564d67796bfc9b866e6101123840a33ed9a Mon Sep 17 00:00:00 2001
From: Emad El-Haraty
Date: Mon, 13 Sep 2021 17:59:11 -0700
Subject: [PATCH] add acc_ops.max, acc_ops.maximum, consolidate acc_ops.min and
 acc_ops.minimum

Summary:
This diff adds `acc_ops.max` and `acc_ops.maximum` support. It further
consolidates the logic for `acc_ops.min` and `acc_ops.minimum` to match the
logic for max.

`torch.max` has three behaviors:
```
1. max(input)
2. max(input, dim, keepdim=False, *, out=None)
3. max(input, other, *, out=None)
```
Likewise, `torch.min` has the same three behaviors.

I've chosen to implement each behavior as its own acc_op and then map each
call to the appropriate one. The third max form is effectively
`torch.maximum`, so I've implemented it as that.

Reviewed By: yinghai, jfix71, 842974287

Differential Revision: D30551464

fbshipit-source-id: 0a2eec10e5185cbf7d9984eec3fd399b23528b2a
---
 .../fx2trt/converters/acc_ops_converters.py | 106 ++++++++++++--
 torch/fx/experimental/fx_acc/acc_ops.py     | 152 ++++++++++++++-------
 2 files changed, 204 insertions(+), 54 deletions(-)

diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index 8f93493..46a6a51 100644
--- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
+++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -124,7 +124,6 @@ def broadcast(network, a, b, a_name, b_name, preset_diff=0):
 
     return a, b
 
-
 def add_binary_elementwise_layer(network, lhs_val, rhs_val, op_type, name):
     lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs")
     rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs")
@@ -691,6 +690,104 @@ def acc_ops_sum(network, target, args, kwargs, name):
     return layer.get_output(0)
 
 
+def add_acc_ops_full_reduce(network, target, args, kwargs, name, reduce_op):
+    input_val = kwargs["input"]
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"max received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    assert (
+        not network.has_implicit_batch_dimension
+    ), "Do not support max over all the elements for implicit batch."
+
+    dim = range(len(input_val.shape))
+
+    layer = network.add_reduce(
+        input_val,
+        reduce_op,
+        get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension),
+        False,
+    )
+    layer.name = name
+    return layer.get_output(0)
+
+def add_acc_ops_dim_reduce(network, target, args, kwargs, name, reduce_op):
+    new_kwargs = kwargs.copy()
+    new_kwargs['k'] = 1
+
+    if reduce_op == trt.ReduceOperation.MAX:
+        new_kwargs['largest'] = True
+    elif reduce_op == trt.ReduceOperation.MIN:
+        new_kwargs['largest'] = False
+    new_kwargs['sorted'] = False
+
+    (topk_out0, topk_out1) = acc_ops_topk(network, target, args, new_kwargs, name + "_topk")
+
+    topk_out0.name = f"{name}_topk0"
+    topk_out1.name = f"{name}_topk1"
+
+    if 'keepdim' in new_kwargs and new_kwargs['keepdim']:
+        return (topk_out0, topk_out1)
+
+    dim = new_kwargs['dim']
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "can't reduce on dim == 0 when network has implicit batch dimension"
+        # we remove the first dim in the shape tuple when it is implicit
+        dim -= 1
+    input_val = topk_out0
+    shape = input_val.shape
+
+    output_shape = []
+    for i, s in enumerate(shape):
+        if i == dim and s == 1:
+            continue
+        output_shape.append(s)
+
+    shuffle_layer0 = network.add_shuffle(input_val)
+    shuffle_layer0.reshape_dims = tuple(output_shape)
+    shuffle_layer0.name = name + '_shuffle0'
+
+    input_val = topk_out1
+    shape = input_val.shape
+
+    shuffle_layer1 = network.add_shuffle(input_val)
+    shuffle_layer1.reshape_dims = tuple(output_shape)
+    shuffle_layer1.name = name + '_shuffle1'
+
+    return (shuffle_layer0.get_output(0), shuffle_layer1.get_output(0))
+
+@tensorrt_converter(acc_ops.max_full_reduce)
+def acc_ops_max_full_reduce(network, target, args, kwargs, name):
+    return add_acc_ops_full_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MAX)
+
+@tensorrt_converter(acc_ops.min_full_reduce)
+def acc_ops_min_full_reduce(network, target, args, kwargs, name):
+    return add_acc_ops_full_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MIN)
+
+@tensorrt_converter(acc_ops.max_dim_reduce)
+def acc_ops_max_dim_reduce(network, target, args, kwargs, name):
+    return add_acc_ops_dim_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MAX)
+
+@tensorrt_converter(acc_ops.min_dim_reduce)
+def acc_ops_min_dim_reduce(network, target, args, kwargs, name):
+    return add_acc_ops_dim_reduce(network, target, args, kwargs, name, trt.ReduceOperation.MIN)
+
+@tensorrt_converter(acc_ops.maximum)
+def acc_ops_maximum(network, target, args, kwargs, name):
+    return add_binary_elementwise_layer(
+        network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.MAX, name
+    )
+
+@tensorrt_converter(acc_ops.minimum)
+def acc_ops_minimum(network, target, args, kwargs, name):
+    return add_binary_elementwise_layer(
+        network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.MIN, name
+    )
+
 @tensorrt_converter(acc_ops.max_pool2d)
 def acc_ops_max_pool2d(network, target, args, kwargs, name):
     input_val = kwargs["input"]
@@ -794,13 +891,6 @@ def acc_ops_pow(network, target, args, kwargs, name):
         network, kwargs["input"], kwargs["exponent"], trt.ElementWiseOperation.POW, name
     )
 
-@tensorrt_converter(acc_ops.min_two_tensors_input)
-def acc_ops_min_two_tensors_input(network, target, args, kwargs, name):
-    return add_binary_elementwise_layer(
-        network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.MIN, name
-    )
-
-
 @tensorrt_converter(acc_ops.unsqueeze)
 def acc_ops_unsqueeze(network, target, args, kwargs, name):
     input_val = kwargs["input"]
kwargs["input"] diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py index 2ffc993..1aa0d42 100644 --- a/torch/fx/experimental/fx_acc/acc_ops.py +++ b/torch/fx/experimental/fx_acc/acc_ops.py @@ -355,52 +355,6 @@ def matmul(*, input, other): return torch.matmul(**locals()) -@register_custom_acc_mapper_fn( - op_and_target=("call_function", torch.min), - arg_replacement_tuples=[ - ("input", "input"), - ("other", "other", this_arg_is_optional), - ("dim", "dim", this_arg_is_optional), - ("keepdim", "keepdim", this_arg_is_optional), - ], -) -def custom_torch_min_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: - """ - Add custom mapping for torch.min because torch.min has three input types, where each yields a different output type: - 1. torch.min(input); Output: tensor wih the minimum number - 2. torch.min(input, other); Output: tensor with coordinate-wise min value - 3[Not Supported] torch.min(input, dim, keepdim); Output:(min_values,and min_indices) - """ - with node.graph.inserting_before(node): - # If dim is in kwargs, assert "Not Supported" - assert "dim" not in node.kwargs, "Currently not support dim in torch.min" - - if "other" in node.kwargs and node.kwargs["other"] is not None: - # If kwargs[other] is a valid tensor, call min_two_tensors_input, - op_func = min_two_tensors_input - else: - # Otherwise, call min_single_tensor_input - op_func = min_single_tensor_input - - new_node = node.graph.create_node( - "call_function", - op_func, - kwargs=node.kwargs, - name=node.name, - ) - new_node.meta = node.meta - return new_node - - -@register_acc_op -def min_single_tensor_input(*, input): - return torch.min(input) - - -@register_acc_op -def min_two_tensors_input(*, input, other): - return torch.min(input, other) - @register_acc_op_mapping( op_and_target=("call_function", torch.ops.quantized.add), @@ -573,6 +527,112 @@ def sum(*, input, dim=None, keepdim=False, dtype=None): return input.sum(dtype=dtype) + + +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "max"), + arg_replacement_tuples=[ + ("input", "input"), + (("dim", "other"), "dim_or_other", this_arg_is_optional), + ("keepdim", "keepdim", this_arg_is_optional), + ], +) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.max), + arg_replacement_tuples=[ + ("input", "input"), + (("dim", "other"), "dim_or_other", this_arg_is_optional), + ("keepdim", "keepdim", this_arg_is_optional), + ], +) +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "min"), + arg_replacement_tuples=[ + ("input", "input"), + (("dim", "other"), "dim_or_other", this_arg_is_optional), + ("keepdim", "keepdim", this_arg_is_optional), + ], +) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.min), + arg_replacement_tuples=[ + ("input", "input"), + (("dim", "other"), "dim_or_other", this_arg_is_optional), + ("keepdim", "keepdim", this_arg_is_optional), + ], +) +def add_maximum_minimum_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.Node: + # there are effectively three versions of torch.max / torch.min + # full reduce: torch.max(input) -> Tensor + # dimensional reduce: torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + # elementwise: torch.max(input, other, *, out=None) -> Tensor + + # the mapper function is remapping for both min and max situations + # this helper function makes the choices available clearer and provides an easier way + # to lookup the right function + def target_map(op, target): + if 
+        if (op, target) in (("call_method", "max"), ("call_function", torch.max)):
+            return dict(full_reduce=max_full_reduce,
+                        dim_reduce=max_dim_reduce,
+                        elementwise=maximum)
+        elif (op, target) in (("call_method", "min"), ("call_function", torch.min)):
+            return dict(full_reduce=min_full_reduce,
+                        dim_reduce=min_dim_reduce,
+                        elementwise=minimum)
+
+    with node.graph.inserting_before(node):
+        new_targets = target_map(node.op, node.target)
+        max_kwargs = dict()
+        max_kwargs['input'] = node.kwargs['input']
+        if ("dim_or_other" not in node.kwargs) or (node.kwargs['dim_or_other'] is None):
+            nt = new_targets["full_reduce"]
+            max_node = node.graph.call_function(nt, kwargs=max_kwargs)
+        elif isinstance(node.kwargs['dim_or_other'], int):
+            nt = new_targets["dim_reduce"]
+            dim = node.kwargs["dim_or_other"]
+            max_kwargs['dim'] = dim
+            max_kwargs['keepdim'] = node.kwargs.get('keepdim', False)
+            max_node = node.graph.call_function(nt, kwargs=max_kwargs)
+        else:
+            other = node.kwargs["dim_or_other"]
+            assert isinstance(other, torch.fx.Node)
+            # Lowering path for when provided "other", where we do elem-wise max
+            nt = new_targets["elementwise"]
+            max_kwargs['other'] = other
+            max_node = node.graph.call_function(nt, kwargs=max_kwargs)
+        max_node.meta = node.meta.copy()
+        return max_node
+
+
+@register_acc_op
+def max_full_reduce(*, input):
+    return torch.max(**locals())
+
+@register_acc_op
+def max_dim_reduce(*, input, dim=None, keepdim=False):
+    return torch.max(**locals())
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.maximum))
+@register_acc_op_mapping(op_and_target=("call_method", "maximum"))
+@register_acc_op
+def maximum(*, input, other):
+    return torch.maximum(**locals())
+
+@register_acc_op
+def min_full_reduce(*, input):
+    return torch.min(input)
+
+@register_acc_op
+def min_dim_reduce(*, input, dim=None, keepdim=False):
+    return torch.min(input, dim=dim, keepdim=keepdim)
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.minimum))
+@register_acc_op_mapping(op_and_target=("call_method", "minimum"))
+@register_acc_op
+def minimum(*, input, other):
+    return torch.minimum(**locals())
+
+
 @register_acc_op_mapping(op_and_target=("call_function", torch.sigmoid))
 @register_acc_op_mapping(op_and_target=("call_method", "sigmoid"))
 @register_acc_op
-- 
2.7.4
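
For reference, a minimal sketch (not part of the patch) of the three `torch.max` call forms the new mapper distinguishes and the acc op each one is lowered to under this change. The tensors `x` and `y` are illustrative only; only stock `torch` calls are used.

```python
import torch

x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
y = torch.tensor([[4.0, 0.0], [1.0, 6.0]])

# 1. Full reduce -> acc_ops.max_full_reduce (single Tensor result).
print(torch.max(x))                      # tensor(5.)

# 2. Dimensional reduce -> acc_ops.max_dim_reduce (values, indices);
#    in TensorRT this becomes a TopK(k=1) layer plus a shuffle that drops
#    the reduced dim when keepdim=False.
values, indices = torch.max(x, dim=1, keepdim=False)
print(values, indices)                   # tensor([5., 3.]) tensor([1, 0])

# 3. Elementwise -> acc_ops.maximum, i.e. the same thing as torch.maximum.
print(torch.max(x, y))                   # tensor([[4., 5.], [3., 6.]])
print(torch.maximum(x, y))               # identical result
```

The `torch.min` forms follow the same split, mapping to `min_full_reduce`, `min_dim_reduce`, and `minimum` respectively.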