[Static Runtime] Out version for fmod (#64046)
authorHarut Movsisyan <harutm@fb.com>
Fri, 27 Aug 2021 10:03:32 +0000 (03:03 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 10:05:06 +0000 (03:05 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64046

Test Plan:
Confirm out variant is used:
```
> //caffe2/benchmarks/static_runtime:static_runtime_cpptest -- --v=1

V0826 23:31:30.321382 193428 impl.cpp:1395] Switch to out variant for node: %4 : Tensor = aten::fmod(%a.1, %b.1)
```

Reviewed By: mikeiovine

Differential Revision: D30581228

fbshipit-source-id: dfab9a16ff8afd40b29338037769f938f154bf74

benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/ops.cpp

index ecdd491..477b191 100644 (file)
@@ -762,3 +762,13 @@ const std::string quantize_script = R"IR(
       %1249: Tensor = aten::dequantize(%1254)
       return (%1249)
 )IR";
+
+const auto fmod_tensor = R"JIT(
+  def forward(self, a: Tensor, b: Tensor):
+      return torch.fmod(a, b).clone()
+)JIT";
+
+const auto fmod_scalar = R"JIT(
+  def forward(self, a: Tensor, b: int):
+      return torch.fmod(a, b).clone()
+)JIT";
index 0d42024..bd213c7 100644 (file)
@@ -1230,3 +1230,30 @@ TEST(StaticRuntime, IndividualOps_VarStack) {
 
   testStaticRuntime(var_stack_script, args1, args2);
 }
+
+TEST(StaticRuntime, IndividualOps_FmodTensor) {
+  // fmod tensor version
+  auto a = at::randn({2, 3});
+  auto b = at::randn({2, 3});
+  std::vector<IValue> args0{a, b};
+  testStaticRuntime(fmod_tensor, args0);
+
+  // check for dynamic shapes
+  auto c = at::randn({4, 3, 2});
+  auto d = at::randn({4, 3, 2});
+  std::vector<IValue> args1{c, d};
+  testStaticRuntime(fmod_tensor, args0, args1);
+}
+
+TEST(StaticRuntime, IndividualOps_FmodScalar) {
+  auto a = at::randn({2, 3});
+
+  // fmod scalar version
+  std::vector<IValue> args2{a, 3};
+  testStaticRuntime(fmod_scalar, args2);
+
+  // check for dynamic shapes
+  auto c = at::randn({4, 3, 2});
+  std::vector<IValue> args3{c, 4};
+  testStaticRuntime(fmod_scalar, args2, args3);
+}
index 7e78b77..36f796f 100644 (file)
@@ -1611,6 +1611,31 @@ REGISTER_OPERATOR_FUNCTOR(aten::linear, aten_linear, [](Node* n) -> SROperator {
   };
 });
 
+REGISTER_OPERATOR_FUNCTOR(aten::fmod, aten_fmod, [](Node* n) -> SROperator {
+  if (!n->matches(torch::schema(
+          "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor")) &&
+      !n->matches(torch::schema(
+          "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor"))) {
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    const auto& in0_t = p_node->Input(0).toTensor();
+    const auto& in1_t = p_node->Input(1).isTensor()
+        ? p_node->Input(1).toTensor()
+        : at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar());
+
+    if (p_node->Output(0).isNone()) {
+      p_node->Output(0) = at::cpu::fmod(in0_t, in1_t);
+    } else {
+      auto& out_t = p_node->Output(0).toTensor();
+      fastResizeToZero(out_t);
+
+      at::cpu::fmod_out(out_t, in0_t, in1_t);
+    }
+  };
+});
+
 namespace {
 
 void check_cat_no_zero_dim(const std::vector<at::Tensor>& tensors) {