[Static Runtime] aten::cat out version when it is not being replaced by prim::VarConc...
authorHarut Movsisyan <harutm@fb.com>
Mon, 30 Aug 2021 16:36:46 +0000 (09:36 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 16:42:38 +0000 (09:42 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64157

UseVariadicCat optimization is not applied to aten::cat if list input to the op can not be moved to the position before op (https://fburl.com/diffusion/l6kweimu). For these cases we will need out version for SR.

Test Plan:
Confirm out variant is called:
```
> buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest -- --v=1
```

Reviewed By: d1jang

Differential Revision: D30598574

fbshipit-source-id: 74cfa8291dc8b5df4aef58adfb1ab2a16f10d90a

benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/ops.cpp

index 004319c..7fdb113 100644 (file)
@@ -790,3 +790,14 @@ const auto linalg_norm_ord_str = R"JIT(
   def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int):
       return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
 )JIT";
+
+const std::string cat_script = R"IR(
+  graph(%a: Tensor, %b: Tensor, %dim: int):
+      %ten_list: Tensor[] = prim::ListConstruct(%a, %b)
+      %1 : int = prim::Constant[value=0]()
+      %2 : int = prim::Constant[value=1]()
+      %3 : int = prim::Constant[value=1]()
+      %ten_list2 : Tensor[] = aten::slice(%ten_list, %1, %2, %3)
+      %ret: Tensor = aten::cat(%ten_list2, %dim)
+      return (%ret)
+)IR";
index f6e3680..b7201ba 100644 (file)
@@ -1,5 +1,6 @@
 #include <gtest/gtest.h>
 #include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/ir/irparser.h>
 #include <torch/csrc/jit/runtime/static/fusion.h>
 #include <torch/csrc/jit/runtime/static/impl.h>
 #include <torch/csrc/jit/runtime/static/passes.h>
@@ -1291,3 +1292,22 @@ TEST(StaticRuntime, IndividualOps_LinalgNorm_StringOrd) {
   std::vector<IValue> args1{b, "fro", dim, true, dtype};
   testStaticRuntime(linalg_norm_ord_str, args0, args1);
 }
+
+TEST(StaticRuntime, IndividualOps_Cat) {
+  auto graph = std::make_shared<Graph>();
+  std::unordered_map<std::string, Value*> vmap;
+  parseIR(cat_script, graph.get(), vmap);
+  torch::jit::StaticModule smodule(graph);
+  ASSERT_TRUE(getNodeWithKind(smodule, "aten::cat"));
+
+  auto a = at::randn({2, 4});
+  auto b = at::randn({3, 4});
+  std::vector<IValue> args0{a, b, 0};
+
+  testStaticRuntime(cat_script, args0);
+
+  auto c = at::randn({3, 4});
+  auto d = at::randn({3, 5});
+  std::vector<IValue> args1{c, d, 1};
+  testStaticRuntime(cat_script, args0, args1);
+}
index 1233930..cf91f33 100644 (file)
@@ -1713,6 +1713,26 @@ REGISTER_OPERATOR_FUNCTOR(aten::linalg_norm, aten_linalg_norm, [](Node* n) -> SR
   };
 });
 
+REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator {
+  if (!n->matches(
+          torch::schema("aten::cat(Tensor[] tensors, int dim=0) -> Tensor"))) {
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    const auto inputs = p_node->Input(0).toTensorVector();
+    const auto dim = p_node->Input(1).toInt();
+    if (p_node->Output(0).isNone()) {
+      p_node->Output(0) = at::native::_cat_cpu(inputs, dim);
+      return;
+    }
+
+    auto& output = p_node->Output(0).toTensor();
+    fastResizeToZero(output);
+    at::native::_cat_out_cpu(inputs, dim, output);
+  };
+});
+
 namespace {
 
 void check_cat_no_zero_dim(const std::vector<at::Tensor>& tensors) {