From 195c60d84460d16311ad606e504eda17b795a820 Mon Sep 17 00:00:00 2001 From: Shiyan Deng Date: Mon, 23 Aug 2021 18:17:20 -0700 Subject: [PATCH] [fx2trt] Add acc op and converter for torch.pow (#63795) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63795 att Test Plan: buck run mode/opt caffe2/torch/fb/fx2trt:test_binary_ops Reviewed By: jackm321, wushirong Differential Revision: D30492488 fbshipit-source-id: 6d615770567b13720316f06fd2f866ea2fdc2995 --- torch/fx/experimental/fx2trt/converters/acc_ops_converters.py | 5 +++++ torch/fx/experimental/fx_acc/acc_ops.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py index eddb079..566359b 100644 --- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py +++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py @@ -764,6 +764,11 @@ def acc_ops_mul(network, target, args, kwargs, name): network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.PROD, name ) +@tensorrt_converter(acc_ops.pow) +def acc_ops_pow(network, target, args, kwargs, name): + return add_binary_elementwise_layer( + network, kwargs["input"], kwargs["exponent"], trt.ElementWiseOperation.POW, name + ) @tensorrt_converter(acc_ops.min_two_tensors_input) def acc_ops_min_two_tensors_input(network, target, args, kwargs, name): diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py index 0c0965a..95fffaa 100644 --- a/torch/fx/experimental/fx_acc/acc_ops.py +++ b/torch/fx/experimental/fx_acc/acc_ops.py @@ -496,6 +496,12 @@ def div(*, input, other): return input / other +@register_acc_op_mapping(op_and_target=("call_function", torch.pow)) +@register_acc_op +def pow(*, input, exponent): + return torch.pow(input, exponent) + + @register_acc_op_mapping(op_and_target=("call_function", nn.functional.relu)) @register_acc_op_mapping( op_and_target=("call_function", torch.relu), -- 2.7.4