From 8af1407eab140a3abf12ea99883fea529791883e Mon Sep 17 00:00:00 2001 From: Harut Movsisyan Date: Sun, 29 Aug 2021 20:58:45 -0700 Subject: [PATCH] [Static Runtime] Out version for torch.linalg.norm (#64070) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64070 Test Plan: Confirm out variant is called for both versions: ``` > buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest -- --v=1 ``` Reviewed By: d1jang Differential Revision: D30595816 fbshipit-source-id: e88d88d4fc698774e83a98efce66b8fa4e281563 --- benchmarks/static_runtime/test_scripts.h | 10 +++++ benchmarks/static_runtime/test_static_runtime.cc | 26 +++++++++++++ torch/csrc/jit/runtime/static/ops.cpp | 47 ++++++++++++++++++++++++ 3 files changed, 83 insertions(+) diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h index bcc975b..004319c 100644 --- a/benchmarks/static_runtime/test_scripts.h +++ b/benchmarks/static_runtime/test_scripts.h @@ -780,3 +780,13 @@ const std::string embedding_bag_byte_prepack_script = R"IR( %res: Tensor = aten::clone(%output, %none) return (%res) )IR"; + +const auto linalg_norm_ord_scalar = R"JIT( + def forward(self, a: Tensor, ord: int, dim: List[int], keepdim: bool, dtype: int): + return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone() +)JIT"; + +const auto linalg_norm_ord_str = R"JIT( + def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int): + return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone() +)JIT"; diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 1e987a9..f6e3680 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -1265,3 +1265,29 @@ TEST(StaticRuntime, QEmbeddingBagByteUnpack) { testStaticRuntime(embedding_bag_byte_prepack_script, {a}); testStaticRuntime(embedding_bag_byte_prepack_script, {a},{b}); } + +TEST(StaticRuntime, IndividualOps_LinalgNorm_ScalarOrd) { + auto a = at::randn({2, 3}); + auto dim = std::vector({1}); + auto dtype = at::ScalarType::Float; + + std::vector args0{a, 4, dim, true, dtype}; + testStaticRuntime(linalg_norm_ord_scalar, args0); + + auto b = at::randn({4, 5}); + std::vector args1{b, 4, dim, true, dtype}; + testStaticRuntime(linalg_norm_ord_scalar, args0, args1); +} + +TEST(StaticRuntime, IndividualOps_LinalgNorm_StringOrd) { + auto a = at::randn({2, 3}); + auto dim = std::vector({0, 1}); + auto dtype = at::ScalarType::Float; + + std::vector args0{a, "fro", dim, true, dtype}; + testStaticRuntime(linalg_norm_ord_str, args0); + + auto b = at::randn({4, 5}); + std::vector args1{b, "fro", dim, true, dtype}; + testStaticRuntime(linalg_norm_ord_str, args0, args1); +} diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 3b58668..1233930 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1666,6 +1666,53 @@ REGISTER_OPERATOR_FUNCTOR(aten::fmod, aten_fmod, [](Node* n) -> SROperator { }; }); +REGISTER_OPERATOR_FUNCTOR(aten::linalg_norm, aten_linalg_norm, [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor")) && + !n->matches(torch::schema( + "aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& input = p_node->Input(0).toTensor(); + const auto dim = p_node->Input(2).toIntVector(); + const auto keepdim = p_node->Input(3).toBool(); + const auto dtype = p_node->Input(4).toOptional(); + + if (p_node->Output(0).isNone()) { + if (p_node->Input(1).isScalar()) { + p_node->Output(0) = at::native::linalg_norm( + input, + p_node->Input(1).toOptional(), + dim, + keepdim, + dtype); + } else { + p_node->Output(0) = at::native::linalg_norm( + input, p_node->Input(1).toStringView(), dim, keepdim, dtype); + } + return; + } + + auto& output = p_node->Output(0).toTensor(); + fastResizeToZero(output); + + if (p_node->Input(1).isScalar()) { + at::native::linalg_norm_out( + input, + p_node->Input(1).toOptional(), + dim, + keepdim, + dtype, + output); + } else { + at::native::linalg_norm_out( + input, p_node->Input(1).toStringRef(), dim, keepdim, dtype, output); + } + }; +}); + namespace { void check_cat_no_zero_dim(const std::vector& tensors) { -- 2.7.4