From f2c47cf4dbbdd0cafc1bd2118121c6eda3947f3f Mon Sep 17 00:00:00 2001
From: Harut Movsisyan
Date: Fri, 27 Aug 2021 03:03:32 -0700
Subject: [PATCH] [Static Runtime] Out version for fmod (#64046)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64046

Test Plan: Confirm out variant is used:

```
> //caffe2/benchmarks/static_runtime:static_runtime_cpptest -- --v=1

V0826 23:31:30.321382 193428 impl.cpp:1395] Switch to out variant for node: %4 : Tensor = aten::fmod(%a.1, %b.1)
```

Reviewed By: mikeiovine

Differential Revision: D30581228

fbshipit-source-id: dfab9a16ff8afd40b29338037769f938f154bf74
---
 benchmarks/static_runtime/test_scripts.h         | 10 ++++++++++
 benchmarks/static_runtime/test_static_runtime.cc | 27 +++++++++++++++++++++++
 torch/csrc/jit/runtime/static/ops.cpp            | 25 +++++++++++++++++++++
 3 files changed, 62 insertions(+)

diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index ecdd491..477b191 100644
--- a/benchmarks/static_runtime/test_scripts.h
+++ b/benchmarks/static_runtime/test_scripts.h
@@ -762,3 +762,13 @@
       %1249: Tensor = aten::dequantize(%1254)
       return (%1249)
 )IR";
+
+const auto fmod_tensor = R"JIT(
+  def forward(self, a: Tensor, b: Tensor):
+      return torch.fmod(a, b).clone()
+)JIT";
+
+const auto fmod_scalar = R"JIT(
+  def forward(self, a: Tensor, b: int):
+      return torch.fmod(a, b).clone()
+)JIT";
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 0d42024..bd213c7 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -1230,3 +1230,30 @@ TEST(StaticRuntime, IndividualOps_VarStack) {
 
   testStaticRuntime(var_stack_script, args1, args2);
 }
+
+TEST(StaticRuntime, IndividualOps_FmodTensor) {
+  // fmod tensor version
+  auto a = at::randn({2, 3});
+  auto b = at::randn({2, 3});
+  std::vector<IValue> args0{a, b};
+  testStaticRuntime(fmod_tensor, args0);
+
+  // check for dynamic shapes
+  auto c = at::randn({4, 3, 2});
+  auto d = at::randn({4, 3, 2});
+  std::vector<IValue> args1{c, d};
+  testStaticRuntime(fmod_tensor, args0, args1);
+}
+
+TEST(StaticRuntime, IndividualOps_FmodScalar) {
+  auto a = at::randn({2, 3});
+
+  // fmod scalar version
+  std::vector<IValue> args2{a, 3};
+  testStaticRuntime(fmod_scalar, args2);
+
+  // check for dynamic shapes
+  auto c = at::randn({4, 3, 2});
+  std::vector<IValue> args3{c, 4};
+  testStaticRuntime(fmod_scalar, args2, args3);
+}
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index 7e78b77..36f796f 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -1611,6 +1611,31 @@ REGISTER_OPERATOR_FUNCTOR(aten::linear, aten_linear, [](Node* n) -> SROperator {
   };
 });
 
+REGISTER_OPERATOR_FUNCTOR(aten::fmod, aten_fmod, [](Node* n) -> SROperator {
+  if (!n->matches(torch::schema(
+          "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor")) &&
+      !n->matches(torch::schema(
+          "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor"))) {
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    const auto& in0_t = p_node->Input(0).toTensor();
+    const auto& in1_t = p_node->Input(1).isTensor()
+        ? p_node->Input(1).toTensor()
+        : at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar());
+
+    if (p_node->Output(0).isNone()) {
+      p_node->Output(0) = at::cpu::fmod(in0_t, in1_t);
+    } else {
+      auto& out_t = p_node->Output(0).toTensor();
+      fastResizeToZero(out_t);
+
+      at::cpu::fmod_out(out_t, in0_t, in1_t);
+    }
+  };
+});
+
 namespace {
 
 void check_cat_no_zero_dim(const std::vector<at::Tensor>& tensors) {
--
2.7.4
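
Note on the out-variant pattern in the ops.cpp hunk above: on the first Static Runtime iteration `p_node->Output(0)` is `None`, so `at::cpu::fmod` allocates the result; on later iterations the previously allocated output tensor is resized to zero and refilled via `at::cpu::fmod_out`, so steady-state runs avoid output allocation. Below is a minimal standalone sketch of that call pattern, assuming a program linked against libtorch; it uses the public dispatching `at::fmod_out` rather than the CPU-direct `at::cpu::fmod_out` the patch calls internally.

```
#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::randn({2, 3});
  at::Tensor b = at::randn({2, 3});

  // Preallocated output, analogous to the tensor Static Runtime keeps
  // in p_node->Output(0) after the first iteration.
  at::Tensor out = at::empty({0});

  for (int i = 0; i < 3; ++i) {
    out.resize_({0});         // analogous to fastResizeToZero(out_t)
    at::fmod_out(out, a, b);  // refills `out`; once its storage is large
                              // enough, no fresh allocation occurs
  }
  return 0;
}
```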