From 09e610e36d0106410e37e129fd0cd5749c74ad5f Mon Sep 17 00:00:00 2001
From: Ray Peng
Date: Tue, 31 Aug 2021 17:45:50 -0700
Subject: [PATCH] [Static Runtime] Out version for softmax (#64243)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64243

Test Plan:
```
> buck run //caffe2/benchmarks/static_runtime:static_runtime_cpptest -- --v=1
...
V0830 16:35:22.524479 613839 impl.cpp:1410] Switch to out variant for node: %5 : Tensor = aten::softmax(%a.1, %dim.1, %dtype.1)
...
[ OK ] StaticRuntime.IndividualOps_Softmax (803 ms)
```

Reviewed By: hlu1

Differential Revision: D30656149

fbshipit-source-id: 115b7b4a75448fd6a5c526808080ca9a4251302c
---
 benchmarks/static_runtime/test_scripts.h         | 10 ++++++++++
 benchmarks/static_runtime/test_static_runtime.cc | 16 ++++++++++++++++
 torch/csrc/jit/runtime/static/ops.cpp            | 24 ++++++++++++++++++++++++
 3 files changed, 50 insertions(+)

diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index 37bb222..99b73db 100644
--- a/benchmarks/static_runtime/test_scripts.h
+++ b/benchmarks/static_runtime/test_scripts.h
@@ -676,6 +676,16 @@ const auto argmin_with_keep_dim_script = R"JIT(
       return torch.argmin(a, dim, True).clone()
 )JIT";
 
+const auto softmax_script = R"JIT(
+  def forward(self, a: Tensor, dim: int):
+      return torch.softmax(a, dim).clone()
+)JIT";
+
+const auto softmax_script_with_dtype = R"JIT(
+  def forward(self, a: Tensor, dim: int, dtype: int):
+      return torch.softmax(a, dim, dtype=dtype).clone()
+)JIT";
+
 const auto getitem_dict_tensor_script = R"JIT(
   def forward(self, key: Tensor):
     d = {key: 1}
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 8e498db..16941da 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -1,3 +1,4 @@
+#include 
 #include 
 #include 
 #include 
@@ -1083,6 +1084,21 @@ TEST(StaticRuntime, IndividualOps_Argmin) {
   testStaticRuntime(argmin_with_keep_dim_script, args_a, args_b);
 }
 
+TEST(StaticRuntime, IndividualOps_Softmax) {
+  auto a = at::randn({2, 3});
+  auto b = at::randn({3, 3, 3});
+
+  testStaticRuntime(softmax_script, {a, 0});
+  testStaticRuntime(softmax_script, {a, 1});
+
+  testStaticRuntime(softmax_script, {b, 0});
+  testStaticRuntime(softmax_script, {b, 1});
+  testStaticRuntime(softmax_script, {b, 2});
+
+  testStaticRuntime(softmax_script_with_dtype, {a, 1, at::ScalarType::Float});
+  testStaticRuntime(softmax_script_with_dtype, {b, 1, at::ScalarType::Float});
+}
+
 TEST(StaticRuntime, IndividualOps_GetItem_Dict) {
   int int_key = 0;
   std::string str_key = "str";
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index 0cc38b0..7ede15c 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -15,6 +15,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -1338,6 +1339,29 @@ REGISTER_OPERATOR_FUNCTOR(aten::argmin, aten_argmin, [](Node* n) -> SROperator {
   };
 });
 
+REGISTER_OPERATOR_FUNCTOR(aten::softmax, aten_softmax, [](Node* n) -> SROperator {
+  if (!n->matches(torch::schema(
+          "aten::softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"))) {
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    const auto& in_t = p_node->Input(0).toTensor();
+    const auto& dim = p_node->Input(1).toInt();
+    const auto& dtype = p_node->Input(2).toOptional<at::ScalarType>();
+    if (p_node->Output(0).isNone()) {
+      p_node->Output(0) = at::native::softmax(in_t, dim, dtype);
+    } else {
+      auto& out_t = p_node->Output(0).toTensor();
+      fastResizeToZero(out_t);
+
+      auto half_to_float = in_t.scalar_type() == at::ScalarType::Half &&
+          dtype == at::ScalarType::Float;
+      at::cpu::_softmax_out(out_t, in_t, dim, half_to_float);
+    }
+  };
+});
+
 REGISTER_OPERATOR_FUNCTOR(
     static_runtime::layer_norm,
     aten_layer_norm,
-- 
2.7.4