From e24c3644d87acfb0359cb14bde4afcd62a9255ba Mon Sep 17 00:00:00 2001
From: Harut Movsisyan
Date: Mon, 30 Aug 2021 09:36:46 -0700
Subject: [PATCH] [Static Runtime] aten::cat out version when it is not being
 replaced by prim::VarConcat (#64157)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64157

The UseVariadicCat optimization is not applied to aten::cat when the list
input to the op cannot be moved to a position before the op
(https://fburl.com/diffusion/l6kweimu). For these cases Static Runtime
needs an out variant of aten::cat.
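A minimal sketch of such a graph (mirroring the cat_script test case added
below; value names here are illustrative): the tensor list is consumed by
aten::slice before it reaches aten::cat, so prim::ListConstruct cannot be
hoisted directly in front of the cat node and prim::VarConcat is never
substituted.

```
graph(%a : Tensor, %b : Tensor, %dim : int):
  %zero : int = prim::Constant[value=0]()
  %one : int = prim::Constant[value=1]()
  %list : Tensor[] = prim::ListConstruct(%a, %b)
  # aten::slice consumes the list, so ListConstruct cannot be moved
  # next to aten::cat and UseVariadicCat leaves the node alone.
  %sliced : Tensor[] = aten::slice(%list, %zero, %one, %one)
  %ret : Tensor = aten::cat(%sliced, %dim)
  return (%ret)
```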
Test Plan:
Confirm the out variant is called:

```
> buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest -- --v=1
```

Reviewed By: d1jang

Differential Revision: D30598574

fbshipit-source-id: 74cfa8291dc8b5df4aef58adfb1ab2a16f10d90a
---
 benchmarks/static_runtime/test_scripts.h         | 11 +++++++++++
 benchmarks/static_runtime/test_static_runtime.cc | 20 ++++++++++++++++++++
 torch/csrc/jit/runtime/static/ops.cpp            | 20 ++++++++++++++++++++
 3 files changed, 51 insertions(+)

diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index 004319c..7fdb113 100644
--- a/benchmarks/static_runtime/test_scripts.h
+++ b/benchmarks/static_runtime/test_scripts.h
@@ -790,3 +790,14 @@ const auto linalg_norm_ord_str = R"JIT(
   def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int):
       return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
 )JIT";
+
+const std::string cat_script = R"IR(
+  graph(%a: Tensor, %b: Tensor, %dim: int):
+      %ten_list: Tensor[] = prim::ListConstruct(%a, %b)
+      %1 : int = prim::Constant[value=0]()
+      %2 : int = prim::Constant[value=1]()
+      %3 : int = prim::Constant[value=1]()
+      %ten_list2 : Tensor[] = aten::slice(%ten_list, %1, %2, %3)
+      %ret: Tensor = aten::cat(%ten_list2, %dim)
+      return (%ret)
+)IR";
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index f6e3680..b7201ba 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -1,5 +1,6 @@
 #include <c10/core/ScalarType.h>
 #include <gtest/gtest.h>
+#include <torch/csrc/jit/ir/irparser.h>
 #include <torch/csrc/jit/runtime/static/impl.h>
 #include <torch/csrc/jit/runtime/static/passes.h>
 #include <torch/torch.h>
@@ -1291,3 +1292,22 @@ TEST(StaticRuntime, IndividualOps_LinalgNorm_StringOrd) {
   std::vector<c10::IValue> args1{b, "fro", dim, true, dtype};
   testStaticRuntime(linalg_norm_ord_str, args0, args1);
 }
+
+TEST(StaticRuntime, IndividualOps_Cat) {
+  auto graph = std::make_shared<torch::jit::Graph>();
+  std::unordered_map<std::string, torch::jit::Value*> vmap;
+  parseIR(cat_script, graph.get(), vmap);
+  torch::jit::StaticModule smodule(graph);
+  ASSERT_TRUE(getNodeWithKind(smodule, "aten::cat"));
+
+  auto a = at::randn({2, 4});
+  auto b = at::randn({3, 4});
+  std::vector<c10::IValue> args0{a, b, 0};
+
+  testStaticRuntime(cat_script, args0);
+
+  auto c = at::randn({3, 4});
+  auto d = at::randn({3, 5});
+  std::vector<c10::IValue> args1{c, d, 1};
+  testStaticRuntime(cat_script, args0, args1);
+}
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index 1233930..cf91f33 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -1713,6 +1713,26 @@ REGISTER_OPERATOR_FUNCTOR(aten::linalg_norm, aten_linalg_norm, [](Node* n) -> SR
   };
 });
 
+REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator {
+  if (!n->matches(
+          torch::schema("aten::cat(Tensor[] tensors, int dim=0) -> Tensor"))) {
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    const auto inputs = p_node->Input(0).toTensorVector();
+    const auto dim = p_node->Input(1).toInt();
+    if (p_node->Output(0).isNone()) {
+      p_node->Output(0) = at::native::_cat_cpu(inputs, dim);
+      return;
+    }
+
+    auto& output = p_node->Output(0).toTensor();
+    fastResizeToZero(output);
+    at::native::_cat_out_cpu(inputs, dim, output);
+  };
+});
+
 namespace {
 void check_cat_no_zero_dim(const std::vector<at::Tensor>& tensors) {
--
2.7.4