From: Don Jang Date: Fri, 27 Aug 2021 09:43:22 +0000 (-0700) Subject: [Static Runtime] Manage temporary Tensors for aten::layer_norm (#64078) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~657 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c90b3cb1dabe712aa07e082b3735f1f2a9134c9b;p=platform%2Fupstream%2Fpytorch.git [Static Runtime] Manage temporary Tensors for aten::layer_norm (#64078) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64078 This change converts `aten::layer_norm -> output Tensor` to `static_runtime::layer_norm -> (output Tensor, tmp1 Tensor, tmp2 Tensor)` so that the static runtime can manage the `tmp1` and `tmp2` Tensors. Currently the out-variant of `aten::layer_norm` creates two temporary Tensors inside it: ``` at::Tensor mean = create_empty_from({M}, *X); at::Tensor rstd = create_empty_from({M}, *X); ``` that the static runtime misses an opportunity to manage. This change puts them into (unused) output Tensors of a new placeholder op `static_runtime::layer_norm` so that the static runtime can manage them, since the static runtime currently chooses to manage only output tensors. Test Plan: - Enhanced `StaticRuntime.LayerNorm` to ensure that `static_runtime::layer_norm` gets activated. 
- Confirmed that the new op gets activated during testing: ``` V0825 12:51:50.017890 2265227 impl.cpp:1396] Switch to out variant for node: %8 : Tensor, %9 : Tensor, %10 : Tensor = static_runtime::layer_norm(%input.1, %normalized_shape.1, %4, %4, %5, %3) ``` Reviewed By: hlu1 Differential Revision: D30486475 fbshipit-source-id: 5121c44ab58c2d8a954aa0bbd9dfeb7468347a2d --- diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 4441b7d..0d42024 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -209,6 +209,13 @@ TEST(StaticRuntime, EmbeddingBag) { } TEST(StaticRuntime, LayerNorm) { +#ifdef FBCODE_CAFFE2 + script::Module module("module"); + module.define(layer_norm_with_weights); + torch::jit::StaticModule smodule(module); + ASSERT_EQ(getNodeWithKind(smodule, "aten::layer_norm"), nullptr); + ASSERT_NE(getNodeWithKind(smodule, "static_runtime::layer_norm"), nullptr); +#endif const auto a = torch::rand({1, 2, 2, 2}); const auto b = torch::rand({3, 2, 2, 2}); for (int normalized_size : {2, 3}) { diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index b3e1eb1..643842a 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -74,6 +74,7 @@ void OptimizeGraph( if (opts.enable_out_variant) { FuseListUnpack(graph); ReplaceWithCopy(graph); + EnableStaticRuntimeLayerNorm(graph); } #endif ConstantPropagation(graph); diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 54c0456..7e78b77 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1308,55 +1308,76 @@ REGISTER_OPERATOR_FUNCTOR(aten::argmin, aten_argmin, [](Node* n) -> SROperator { }; }); -REGISTER_OPERATOR_FUNCTOR(aten::layer_norm, aten_layer_norm, [](Node* n) -> SROperator { - if (!n->matches(torch::schema( - 
"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"))) { - LogAndDumpSchema(n); - return nullptr; - } - return [](ProcessedNode* p_node) { - // ignore Input(5): `bool cudnn_enable=True` - const auto& input = p_node->Input(0).toTensor(); - const auto normalized_shape = p_node->Input(1).toIntVector(); - auto weight_opt = p_node->Input(2).toOptional(); - auto bias_opt = p_node->Input(3).toOptional(); - float eps = p_node->Input(4).toDouble(); - - c10::MaybeOwned weight_maybe_owned = - at::borrow_from_optional_tensor(weight_opt); - const at::Tensor& weight = *weight_maybe_owned; - c10::MaybeOwned bias_maybe_owned = - at::borrow_from_optional_tensor(bias_opt); - const at::Tensor& bias = *bias_maybe_owned; - - auto M_N = at::native::_check_layer_norm_inputs( - input, normalized_shape, weight, bias); - auto M = M_N.first; - auto N = M_N.second; - auto X = input.expect_contiguous(); - auto gamma = weight.expect_contiguous(); - auto beta = bias.expect_contiguous(); - - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::empty_like( - *X, - c10::nullopt /* dtype */, - c10::nullopt /* layout */, - c10::nullopt /* device */, - c10::nullopt /* pin_memory */, - at::MemoryFormat::Contiguous); - } else { - at::native::resize_( - p_node->Output(0).toTensor(), X->sizes(), c10::nullopt); - } - at::Tensor& output = p_node->Output(0).toTensor(); - at::Tensor mean = create_empty_from({M}, *X); - at::Tensor rstd = create_empty_from({M}, *X); +REGISTER_OPERATOR_FUNCTOR( + static_runtime::layer_norm, + aten_layer_norm, + [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? 
bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor,Tensor,Tensor)"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + // ignore Input(5): `bool cudnn_enable=True` + const auto& input = p_node->Input(0).toTensor(); + const auto normalized_shape = p_node->Input(1).toIntVector(); + auto weight_opt = p_node->Input(2).toOptional(); + auto bias_opt = p_node->Input(3).toOptional(); + float eps = p_node->Input(4).toDouble(); + + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const at::Tensor& weight = *weight_maybe_owned; + c10::MaybeOwned bias_maybe_owned = + at::borrow_from_optional_tensor(bias_opt); + const at::Tensor& bias = *bias_maybe_owned; + + auto M_N = at::native::_check_layer_norm_inputs( + input, normalized_shape, weight, bias); + auto M = M_N.first; + auto N = M_N.second; + auto X = input.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + auto beta = bias.expect_contiguous(); - at::native::layer_norm_cpu_out( - output, mean, rstd, input, normalized_shape, *gamma, *beta, eps, M, N); - }; -}); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::native::empty_like( + *X, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + c10::nullopt /* device */, + c10::nullopt /* pin_memory */, + at::MemoryFormat::Contiguous); + } else { + at::native::resize_( + p_node->Output(0).toTensor(), X->sizes(), c10::nullopt); + } + if (p_node->Output(1).isNone()) { + p_node->Output(1) = create_empty_from({M}, *X); + } else { + at::native::resize_(p_node->Output(1).toTensor(), {M}, c10::nullopt); + } + if (p_node->Output(2).isNone()) { + p_node->Output(2) = create_empty_from({M}, *X); + } else { + at::native::resize_(p_node->Output(2).toTensor(), {M}, c10::nullopt); + } + at::Tensor& output = p_node->Output(0).toTensor(); + at::Tensor mean = p_node->Output(1).toTensor(); + at::Tensor rstd = p_node->Output(2).toTensor(); + at::native::layer_norm_cpu_out( + output, + 
mean, + rstd, + input, + normalized_shape, + *gamma, + *beta, + eps, + M, + N); + }; + }); REGISTER_OPERATOR_FUNCTOR(aten::norm, aten_norm, [](Node* n) -> SROperator { if (!n->matches(torch::schema( diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index 1133e39..5099dc1 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -303,6 +303,9 @@ TORCH_LIBRARY_FRAGMENT(static_runtime, m) { "static_runtime::to_copy.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor"); m.def( "static_runtime::to_copy.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor"); + m.def(torch::schema( + "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor, Tensor, Tensor)", + c10::AliasAnalysisKind::PURE_FUNCTION)); } bool HasInplaceOp(std::shared_ptr& graph, const AliasDb& alias_db) { @@ -469,5 +472,35 @@ void FuseListUnpack(std::shared_ptr& graph) { #endif } +void EnableStaticRuntimeLayerNorm(std::shared_ptr& graph) { + const c10::Symbol static_runtime_layer_norm_symbol = + c10::Symbol::fromQualString("static_runtime::layer_norm"); + auto nodes = graph->nodes(); + std::vector> replacement; + for (auto it = nodes.begin(); it != nodes.end(); ++it) { + Node* old_node = *it; + if (!old_node->matches(torch::schema( + "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? 
bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"))) { + continue; + } + TORCH_CHECK(old_node->outputs().size() == 1); + auto* new_node = graph->create( + static_runtime_layer_norm_symbol, + /*layer_norm*/ 1 + /*mean*/ 1 + /*rst=*/1); + new_node->insertBefore(old_node); + for (auto* input : old_node->inputs()) { + new_node->addInput(input); + } + replacement.emplace_back(old_node, new_node); + } + for (const auto& p : replacement) { + auto* old_node = p.first; + auto* new_node = p.second; + new_node->output(0)->copyMetadata(old_node->output(0)); + old_node->output(0)->replaceAllUsesWith(new_node->output(0)); + old_node->destroy(); + } +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h index 11ab4bd..a42bc97 100644 --- a/torch/csrc/jit/runtime/static/passes.h +++ b/torch/csrc/jit/runtime/static/passes.h @@ -13,6 +13,9 @@ TORCH_API void ReplaceWithCopy( std::shared_ptr& graph, bool outputs_are_immutable = true); +TORCH_API void EnableStaticRuntimeLayerNorm( + std::shared_ptr& graph); + TORCH_API bool HasInplaceOp( std::shared_ptr& graph, const AliasDb& alias_db);