From 2bb898e039d90d3897cc0fbd1886f2b2fe4dbcfb Mon Sep 17 00:00:00 2001
From: Jordan Fix
Date: Wed, 15 Sep 2021 19:39:41 -0700
Subject: [PATCH] [acc_ops] Add support for torch variants of squeeze and mul
 (#65037)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65037

att

Test Plan: updated unit tests

Reviewed By: yuhc

Differential Revision: D30952224

fbshipit-source-id: aaf75b27b4fc6c0436ba7bfcf324f761b900171b
---
 torch/fx/experimental/fx_acc/acc_ops.py | 61 ++++++++++++++++++++++++---------
 1 file changed, 44 insertions(+), 17 deletions(-)

diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py
index 3d19a2c..fcd60c7 100644
--- a/torch/fx/experimental/fx_acc/acc_ops.py
+++ b/torch/fx/experimental/fx_acc/acc_ops.py
@@ -49,10 +49,14 @@ def flatten(*, input, start_dim=0, end_dim=-1):
 
 
 @register_acc_op_mapping(
-    op_and_target=(
-        "call_method",
-        "squeeze",
-    ),
+    op_and_target=("call_method", "squeeze"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim", this_arg_is_optional),
+    ],
+)
+@register_acc_op_mapping(
+    op_and_target=("call_function", torch.squeeze),
     arg_replacement_tuples=[
         ("input", "input"),
         ("dim", "dim", this_arg_is_optional),
@@ -435,7 +439,7 @@ def quantize_per_tensor(*, input, acc_out_ty=None):
 
 @register_acc_op
 def dequantize(*, input, input_tensor_meta):
-    """ `input_tensor_meta` contains extra argument of quantization
+    """`input_tensor_meta` contains extra argument of quantization
     parameters, e.g. scale/zero_point and will be using for lowring
     dequantize op to TensorRT
     """
@@ -448,6 +452,7 @@ def sub(*, input, other):
     return input - other
 
 
+@register_acc_op_mapping(op_and_target=("call_function", torch.mul))
 @register_acc_op_mapping(op_and_target=("call_function", operator.mul))
 @register_acc_op
 def mul(*, input, other):
@@ -479,6 +484,7 @@ def pow(*, input, exponent):
 def relu(*, input, inplace=False):
     return nn.functional.relu(**locals())
 
+
 @register_custom_acc_mapper_fn(
     op_and_target=("call_function", torch.log1p),
     arg_replacement_tuples=[
@@ -495,6 +501,7 @@ def torch_log1p_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node
     log_node.meta = node.meta.copy()
     return log_node
 
+
 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "sum"),
     arg_replacement_tuples=[
@@ -775,6 +782,7 @@ def batch_norm(
 def layer_norm(*, input, normalized_shape, weight, bias, eps):
     return nn.functional.layer_norm(**locals())
 
+
 def argmin_max_mapper_impl(node: torch.fx.Node, largest: bool) -> torch.fx.Node:
     """
     Map torch.argmin or torch.argmax to acc_ops.flatten (depend on dim) + acc_ops.topk
@@ -785,17 +793,29 @@ def argmin_max_mapper_impl(node: torch.fx.Node, largest: bool) -> torch.fx.Node:
     keepdim = node.kwargs["keepdim"]
 
     if dim is None and keepdim:
-        raise RuntimeError("We currently don't support argmin/argmax with dim=None and keepdim=True")
+        raise RuntimeError(
+            "We currently don't support argmin/argmax with dim=None and keepdim=True"
+        )
 
     with node.graph.inserting_before(node):
         if dim is None:
-            flatten_kwargs = {"input": node.kwargs["input"], "start_dim": 0, "end_dim": -1}
+            flatten_kwargs = {
+                "input": node.kwargs["input"],
+                "start_dim": 0,
+                "end_dim": -1,
+            }
             flatten_node = node.graph.call_function(flatten, kwargs=flatten_kwargs)
             flatten_node.meta["type"] = torch.Tensor
             input_node = flatten_node
             dim = -1
 
-        topk_kwargs = {"input": input_node, "k": 1, "dim": dim, "largest": largest, "sorted": False}
+        topk_kwargs = {
+            "input": input_node,
+            "k": 1,
+            "dim": dim,
+            "largest": largest,
+            "sorted": False,
+        }
         topk_node = node.graph.call_function(topk, kwargs=topk_kwargs)
         # It's actually more like NamedTuple but tuple here should be fine.
         topk_node.meta["type"] = tuple
@@ -812,6 +832,7 @@ def argmin_max_mapper_impl(node: torch.fx.Node, largest: bool) -> torch.fx.Node:
         output_node.meta = node.meta.copy()
         return output_node
 
+
 @register_custom_acc_mapper_fn(
     op_and_target=("call_function", torch.argmin),
     arg_replacement_tuples=[
@@ -827,6 +848,7 @@ def torch_argmin_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Nod
     """
     return argmin_max_mapper_impl(node, largest=False)
 
+
 @register_acc_op_mapping(op_and_target=("call_function", torch.linalg.norm))
 @register_acc_op
 def linalg_norm(*, input, ord, dim, keepdim):
@@ -974,6 +996,7 @@ def embedding_bag_byte_rowwise_offsets(
 ):
     return torch.ops.quantized.embedding_bag_byte_rowwise_offsets(**locals())
 
+
 @register_acc_op_mapping(
     op_and_target=(
         "call_function",
@@ -1236,12 +1259,16 @@ def packed_quantized_linear_mapper(
     with node.graph.inserting_before(node):
         # Insert get_attr nodes for weight and bias
        get_weight = node.graph.get_attr(weight_name)
-        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(linear_module.weight())
+        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(
+            linear_module.weight()
+        )
 
         get_bias = None
         if linear_module.bias() is not None:
             get_bias = node.graph.get_attr(bias_name)
-            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(linear_module.bias())
+            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(
+                linear_module.bias()
+            )
 
         # Create kwargs for acc_op.quantized_linear
         kwargs = {
@@ -1368,22 +1395,22 @@ def packed_quantized_convrelu2d_mapper(
         relu_node.meta = node.meta
         return relu_node
 
+
 @register_custom_acc_mapper_fn(
     op_and_target=("call_function", torch.dequantize),
-    arg_replacement_tuples=[
-        ("input", "input")
-    ]
+    arg_replacement_tuples=[("input", "input")],
 )
 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "dequantize"),
-    arg_replacement_tuples=[
-        ("input", "input")
-    ]
+    arg_replacement_tuples=[("input", "input")],
 )
 def custom_dequantize_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
     assert isinstance(node.kwargs["input"], torch.fx.Node)
     assert "tensor_meta" in node.kwargs["input"].meta
-    new_kwargs = {"input": node.kwargs["input"], "input_tensor_meta": node.kwargs["input"].meta["tensor_meta"]}
+    new_kwargs = {
+        "input": node.kwargs["input"],
+        "input_tensor_meta": node.kwargs["input"].meta["tensor_meta"],
+    }
    # `input_tensor_meta` contains quantization parameters that can be used to lower
    # acc_ops.dequantize to TensorRT ops
     with node.graph.inserting_before(node):
-- 
2.7.4