[Static Runtime] Manage temporary Tensors for aten::layer_norm (#64078)
authorDon Jang <djang@fb.com>
Fri, 27 Aug 2021 09:43:22 +0000 (02:43 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 09:44:43 +0000 (02:44 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64078

This change converts `aten::layer_norm -> output Tensor` to `static_runtime::layer_norm -> (output Tensor, temp1 Tensor, tmp2 Tensor)` to manage `tmp1` and `tmp2` Tensors by the static runtime.

Currently the out-variant of `aten::layer_norm` creates two temporary Tensors inside it:
```
    at::Tensor mean = create_empty_from({M}, *X);
    at::Tensor rstd = create_empty_from({M}, *X);
```
that the static runtime misses an opportunity to manage.

This change puts them into (unused) output Tensors of a new placeholder op `static_runtime::layer_norm` so that the static runtime can mange them since the static runtime as of now chooses to manage only output tensors.

Test Plan:
- Enhanced `StaticRuntime.LayerNorm` to ensure that `static_runtime::layer_norm` gets activated.

- Confirmed that the new op gets activated during testing:

```
V0825 12:51:50.017890 2265227 impl.cpp:1396] Switch to out variant for node: %8 : Tensor, %9 : Tensor, %10 : Tensor = static_runtime::layer_norm(%input.1, %normalized_shape.1, %4, %4, %5, %3)

```

Reviewed By: hlu1

Differential Revision: D30486475

fbshipit-source-id: 5121c44ab58c2d8a954aa0bbd9dfeb7468347a2d

benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/impl.cpp
torch/csrc/jit/runtime/static/ops.cpp
torch/csrc/jit/runtime/static/passes.cpp
torch/csrc/jit/runtime/static/passes.h

index 4441b7d..0d42024 100644 (file)
@@ -209,6 +209,13 @@ TEST(StaticRuntime, EmbeddingBag) {
 }
 
 TEST(StaticRuntime, LayerNorm) {
+#ifdef FBCODE_CAFFE2
+  script::Module module("module");
+  module.define(layer_norm_with_weights);
+  torch::jit::StaticModule smodule(module);
+  ASSERT_EQ(getNodeWithKind(smodule, "aten::layer_norm"), nullptr);
+  ASSERT_NE(getNodeWithKind(smodule, "static_runtime::layer_norm"), nullptr);
+#endif
   const auto a = torch::rand({1, 2, 2, 2});
   const auto b = torch::rand({3, 2, 2, 2});
   for (int normalized_size : {2, 3}) {
index b3e1eb1..643842a 100644 (file)
@@ -74,6 +74,7 @@ void OptimizeGraph(
   if (opts.enable_out_variant) {
     FuseListUnpack(graph);
     ReplaceWithCopy(graph);
+    EnableStaticRuntimeLayerNorm(graph);
   }
 #endif
   ConstantPropagation(graph);
index 54c0456..7e78b77 100644 (file)
@@ -1308,55 +1308,76 @@ REGISTER_OPERATOR_FUNCTOR(aten::argmin, aten_argmin, [](Node* n) -> SROperator {
   };
 });
 
-REGISTER_OPERATOR_FUNCTOR(aten::layer_norm, aten_layer_norm, [](Node* n) -> SROperator {
-  if (!n->matches(torch::schema(
-          "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"))) {
-    LogAndDumpSchema(n);
-    return nullptr;
-  }
-  return [](ProcessedNode* p_node) {
-    // ignore Input(5): `bool cudnn_enable=True`
-    const auto& input = p_node->Input(0).toTensor();
-    const auto normalized_shape = p_node->Input(1).toIntVector();
-    auto weight_opt = p_node->Input(2).toOptional<at::Tensor>();
-    auto bias_opt = p_node->Input(3).toOptional<at::Tensor>();
-    float eps = p_node->Input(4).toDouble();
-
-    c10::MaybeOwned<at::Tensor> weight_maybe_owned =
-        at::borrow_from_optional_tensor(weight_opt);
-    const at::Tensor& weight = *weight_maybe_owned;
-    c10::MaybeOwned<at::Tensor> bias_maybe_owned =
-        at::borrow_from_optional_tensor(bias_opt);
-    const at::Tensor& bias = *bias_maybe_owned;
-
-    auto M_N = at::native::_check_layer_norm_inputs(
-        input, normalized_shape, weight, bias);
-    auto M = M_N.first;
-    auto N = M_N.second;
-    auto X = input.expect_contiguous();
-    auto gamma = weight.expect_contiguous();
-    auto beta = bias.expect_contiguous();
-
-    if (p_node->Output(0).isNone()) {
-      p_node->Output(0) = at::native::empty_like(
-          *X,
-          c10::nullopt /* dtype */,
-          c10::nullopt /* layout */,
-          c10::nullopt /* device */,
-          c10::nullopt /* pin_memory */,
-          at::MemoryFormat::Contiguous);
-    } else {
-      at::native::resize_(
-          p_node->Output(0).toTensor(), X->sizes(), c10::nullopt);
-    }
-    at::Tensor& output = p_node->Output(0).toTensor();
-    at::Tensor mean = create_empty_from({M}, *X);
-    at::Tensor rstd = create_empty_from({M}, *X);
+REGISTER_OPERATOR_FUNCTOR(
+    static_runtime::layer_norm,
+    aten_layer_norm,
+    [](Node* n) -> SROperator {
+      if (!n->matches(torch::schema(
+              "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor,Tensor,Tensor)"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+      return [](ProcessedNode* p_node) {
+        // ignore Input(5): `bool cudnn_enable=True`
+        const auto& input = p_node->Input(0).toTensor();
+        const auto normalized_shape = p_node->Input(1).toIntVector();
+        auto weight_opt = p_node->Input(2).toOptional<at::Tensor>();
+        auto bias_opt = p_node->Input(3).toOptional<at::Tensor>();
+        float eps = p_node->Input(4).toDouble();
+
+        c10::MaybeOwned<at::Tensor> weight_maybe_owned =
+            at::borrow_from_optional_tensor(weight_opt);
+        const at::Tensor& weight = *weight_maybe_owned;
+        c10::MaybeOwned<at::Tensor> bias_maybe_owned =
+            at::borrow_from_optional_tensor(bias_opt);
+        const at::Tensor& bias = *bias_maybe_owned;
+
+        auto M_N = at::native::_check_layer_norm_inputs(
+            input, normalized_shape, weight, bias);
+        auto M = M_N.first;
+        auto N = M_N.second;
+        auto X = input.expect_contiguous();
+        auto gamma = weight.expect_contiguous();
+        auto beta = bias.expect_contiguous();
 
-    at::native::layer_norm_cpu_out(
-        output, mean, rstd, input, normalized_shape, *gamma, *beta, eps, M, N);
-  };
-});
+        if (p_node->Output(0).isNone()) {
+          p_node->Output(0) = at::native::empty_like(
+              *X,
+              c10::nullopt /* dtype */,
+              c10::nullopt /* layout */,
+              c10::nullopt /* device */,
+              c10::nullopt /* pin_memory */,
+              at::MemoryFormat::Contiguous);
+        } else {
+          at::native::resize_(
+              p_node->Output(0).toTensor(), X->sizes(), c10::nullopt);
+        }
+        if (p_node->Output(1).isNone()) {
+          p_node->Output(1) = create_empty_from({M}, *X);
+        } else {
+          at::native::resize_(p_node->Output(1).toTensor(), {M}, c10::nullopt);
+        }
+        if (p_node->Output(2).isNone()) {
+          p_node->Output(2) = create_empty_from({M}, *X);
+        } else {
+          at::native::resize_(p_node->Output(2).toTensor(), {M}, c10::nullopt);
+        }
+        at::Tensor& output = p_node->Output(0).toTensor();
+        at::Tensor mean = p_node->Output(1).toTensor();
+        at::Tensor rstd = p_node->Output(2).toTensor();
+        at::native::layer_norm_cpu_out(
+            output,
+            mean,
+            rstd,
+            input,
+            normalized_shape,
+            *gamma,
+            *beta,
+            eps,
+            M,
+            N);
+      };
+    });
 
 REGISTER_OPERATOR_FUNCTOR(aten::norm, aten_norm, [](Node* n) -> SROperator {
   if (!n->matches(torch::schema(
index 1133e39..5099dc1 100644 (file)
@@ -303,6 +303,9 @@ TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
       "static_runtime::to_copy.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
   m.def(
       "static_runtime::to_copy.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
+  m.def(torch::schema(
+      "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor, Tensor, Tensor)",
+      c10::AliasAnalysisKind::PURE_FUNCTION));
 }
 
 bool HasInplaceOp(std::shared_ptr<Graph>& graph, const AliasDb& alias_db) {
@@ -469,5 +472,35 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
 #endif
 }
 
+void EnableStaticRuntimeLayerNorm(std::shared_ptr<torch::jit::Graph>& graph) {
+  const c10::Symbol static_runtime_layer_norm_symbol =
+      c10::Symbol::fromQualString("static_runtime::layer_norm");
+  auto nodes = graph->nodes();
+  std::vector<std::pair<Node*, Node*>> replacement;
+  for (auto it = nodes.begin(); it != nodes.end(); ++it) {
+    Node* old_node = *it;
+    if (!old_node->matches(torch::schema(
+            "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"))) {
+      continue;
+    }
+    TORCH_CHECK(old_node->outputs().size() == 1);
+    auto* new_node = graph->create(
+        static_runtime_layer_norm_symbol,
+        /*layer_norm*/ 1 + /*mean*/ 1 + /*rst=*/1);
+    new_node->insertBefore(old_node);
+    for (auto* input : old_node->inputs()) {
+      new_node->addInput(input);
+    }
+    replacement.emplace_back(old_node, new_node);
+  }
+  for (const auto& p : replacement) {
+    auto* old_node = p.first;
+    auto* new_node = p.second;
+    new_node->output(0)->copyMetadata(old_node->output(0));
+    old_node->output(0)->replaceAllUsesWith(new_node->output(0));
+    old_node->destroy();
+  }
+}
+
 } // namespace jit
 } // namespace torch
index 11ab4bd..a42bc97 100644 (file)
@@ -13,6 +13,9 @@ TORCH_API void ReplaceWithCopy(
     std::shared_ptr<torch::jit::Graph>& graph,
     bool outputs_are_immutable = true);
 
+TORCH_API void EnableStaticRuntimeLayerNorm(
+    std::shared_ptr<torch::jit::Graph>& graph);
+
 TORCH_API bool HasInplaceOp(
     std::shared_ptr<Graph>& graph,
     const AliasDb& alias_db);