[Static Runtime] Out version for torch.linalg.norm (#64070)
author Harut Movsisyan <harutm@fb.com>
Mon, 30 Aug 2021 03:58:45 +0000 (20:58 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 04:00:11 +0000 (21:00 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64070

Test Plan:
Confirm the out variant is called for both overloads (scalar and string `ord`):

```
> buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest -- --v=1
```
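
For context (not part of the commit message): `testStaticRuntime` compiles the TorchScript wrapper, runs it through the regular JIT interpreter to get a reference result, then runs it through Static Runtime and compares the two. A rough sketch of that check, assuming the `StaticModule` API from `torch/csrc/jit/runtime/static/impl.h` (the helper name `check_static_runtime` is hypothetical, not the actual harness):

```
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/script.h>

// Hypothetical helper approximating what testStaticRuntime verifies.
void check_static_runtime(const std::string& source, std::vector<c10::IValue> args) {
  torch::jit::Module mod("m");
  mod.define(source);                          // compile the TorchScript wrapper
  auto expect = mod.forward(args).toTensor();  // JIT interpreter reference
  torch::jit::StaticModule smod(mod);
  auto actual = smod(args, {}).toTensor();     // Static Runtime result
  TORCH_CHECK(expect.equal(actual));
}
```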

Reviewed By: d1jang

Differential Revision: D30595816

fbshipit-source-id: e88d88d4fc698774e83a98efce66b8fa4e281563

benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/ops.cpp

diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index bcc975b..004319c 100644
@@ -780,3 +780,13 @@ const std::string embedding_bag_byte_prepack_script = R"IR(
       %res: Tensor = aten::clone(%output, %none)
       return (%res)
 )IR";
+
+const auto linalg_norm_ord_scalar = R"JIT(
+  def forward(self, a: Tensor, ord: int, dim: List[int], keepdim: bool, dtype: int):
+      return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
+)JIT";
+
+const auto linalg_norm_ord_str = R"JIT(
+  def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int):
+      return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
+)JIT";
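
The two wrappers above correspond to the two `aten::linalg_norm` schemas (scalar `ord` and string `ord`). As an ATen-level illustration (an assumption for clarity, not part of the commit), the same two overloads can be reached directly from C++:

```
#include <ATen/ATen.h>

void linalg_norm_overloads() {
  at::Tensor a = at::randn({2, 3});
  std::vector<int64_t> dim{1};
  // Scalar ord: aten::linalg_norm(Tensor, Scalar?, int[1]?, bool, ScalarType?)
  at::Tensor n1 = at::linalg_norm(a, 4, dim, /*keepdim=*/true, at::kFloat);
  std::vector<int64_t> dims{0, 1};
  // String ord: aten::linalg_norm.ord_str(Tensor, str, int[1]?, bool, ScalarType?)
  at::Tensor n2 = at::linalg_norm(a, "fro", dims, /*keepdim=*/true, at::kFloat);
}
```
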
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 1e987a9..f6e3680 100644
@@ -1265,3 +1265,29 @@ TEST(StaticRuntime, QEmbeddingBagByteUnpack) {
   testStaticRuntime(embedding_bag_byte_prepack_script, {a});
   testStaticRuntime(embedding_bag_byte_prepack_script, {a},{b});
 }
+
+TEST(StaticRuntime, IndividualOps_LinalgNorm_ScalarOrd) {
+  auto a = at::randn({2, 3});
+  auto dim = std::vector<int64_t>({1});
+  auto dtype = at::ScalarType::Float;
+
+  std::vector<IValue> args0{a, 4, dim, true, dtype};
+  testStaticRuntime(linalg_norm_ord_scalar, args0);
+
+  auto b = at::randn({4, 5});
+  std::vector<IValue> args1{b, 4, dim, true, dtype};
+  testStaticRuntime(linalg_norm_ord_scalar, args0, args1);
+}
+
+TEST(StaticRuntime, IndividualOps_LinalgNorm_StringOrd) {
+  auto a = at::randn({2, 3});
+  auto dim = std::vector<int64_t>({0, 1});
+  auto dtype = at::ScalarType::Float;
+
+  std::vector<IValue> args0{a, "fro", dim, true, dtype};
+  testStaticRuntime(linalg_norm_ord_str, args0);
+
+  auto b = at::randn({4, 5});
+  std::vector<IValue> args1{b, "fro", dim, true, dtype};
+  testStaticRuntime(linalg_norm_ord_str, args0, args1);
+}
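
A note on the two-argument-set calls (an observation about the harness, not commit text): `testStaticRuntime(script, args0, args1)` runs the graph again with differently shaped inputs, so the later run exercises the out-variant path where the managed output tensor must be resized. A small sketch of why the output shapes differ between the two argument sets:

```
#include <ATen/ATen.h>

void differing_output_shapes() {
  std::vector<int64_t> dim{1};
  auto out0 = at::linalg_norm(at::randn({2, 3}), 4, dim, true, at::kFloat);
  auto out1 = at::linalg_norm(at::randn({4, 5}), 4, dim, true, at::kFloat);
  TORCH_CHECK(!out0.sizes().equals(out1.sizes()));  // (2, 1) vs. (4, 1)
}
```
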
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index 3b58668..1233930 100644
@@ -1666,6 +1666,53 @@ REGISTER_OPERATOR_FUNCTOR(aten::fmod, aten_fmod, [](Node* n) -> SROperator {
   };
 });
 
+REGISTER_OPERATOR_FUNCTOR(aten::linalg_norm, aten_linalg_norm, [](Node* n) -> SROperator {
+  if (!n->matches(torch::schema(
+          "aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor")) &&
+      !n->matches(torch::schema(
+          "aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    const auto& input = p_node->Input(0).toTensor();
+    const auto dim = p_node->Input(2).toIntVector();
+    const auto keepdim = p_node->Input(3).toBool();
+    const auto dtype = p_node->Input(4).toOptional<c10::ScalarType>();
+
+    if (p_node->Output(0).isNone()) {
+      if (p_node->Input(1).isScalar()) {
+        p_node->Output(0) = at::native::linalg_norm(
+            input,
+            p_node->Input(1).toOptional<at::Scalar>(),
+            dim,
+            keepdim,
+            dtype);
+      } else {
+        p_node->Output(0) = at::native::linalg_norm(
+            input, p_node->Input(1).toStringRef(), dim, keepdim, dtype);
+      }
+      return;
+    }
+
+    auto& output = p_node->Output(0).toTensor();
+    fastResizeToZero(output);
+
+    if (p_node->Input(1).isScalar()) {
+      at::native::linalg_norm_out(
+          input,
+          p_node->Input(1).toOptional<at::Scalar>(),
+          dim,
+          keepdim,
+          dtype,
+          output);
+    } else {
+      at::native::linalg_norm_out(
+          input, p_node->Input(1).toStringRef(), dim, keepdim, dtype, output);
+    }
+  };
+});
+
 namespace {
 
 void check_cat_no_zero_dim(const std::vector<at::Tensor>& tensors) {
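
The registered functor follows Static Runtime's usual out-variant pattern: on the first invocation `Output(0)` is `None`, so the functional `at::native::linalg_norm` allocates the result; on later invocations the managed output tensor is resized to zero (`fastResizeToZero`) and refilled in place through `at::native::linalg_norm_out`. A minimal standalone sketch of that pattern, assuming the native signatures used in the diff above (the helper `norm_reusing` and its fixed `ord`/`dim` are hypothetical):

```
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>

// Hypothetical helper mirroring the allocate-then-reuse pattern.
at::Tensor& norm_reusing(const at::Tensor& input, at::Tensor& out) {
  c10::optional<at::Scalar> ord = 4;
  std::vector<int64_t> dim{1};
  if (!out.defined()) {
    // First run: no managed output yet, allocate via the functional variant.
    out = at::native::linalg_norm(input, ord, dim, /*keepdim=*/true, at::kFloat);
  } else {
    // Later runs: empty the tensor (akin to fastResizeToZero), then let the
    // out variant write the result in place.
    out.resize_({0});
    at::native::linalg_norm_out(input, ord, dim, /*keepdim=*/true, at::kFloat, out);
  }
  return out;
}
```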