[Static Runtime] Manage temporary Tensors for aten::layer_norm (#64078)
author Don Jang <djang@fb.com>
Fri, 27 Aug 2021 09:43:22 +0000 (02:43 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 09:44:43 +0000 (02:44 -0700)
commit c90b3cb1dabe712aa07e082b3735f1f2a9134c9b
tree 43e4873ce6ae61d6745621a3e88da18694a55c1a
parent 3c3bba4169067a7340ff1d786a6b61282cf26820

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64078

This change converts `aten::layer_norm -> output Tensor` to `static_runtime::layer_norm -> (output Tensor, tmp1 Tensor, tmp2 Tensor)` so that the static runtime can manage the `tmp1` and `tmp2` Tensors.

Currently the out-variant of `aten::layer_norm` creates two temporary Tensors inside it:
```
    at::Tensor mean = create_empty_from({M}, *X);
    at::Tensor rstd = create_empty_from({M}, *X);
```
These allocations happen inside the op, so the static runtime misses the opportunity to manage them.

This change exposes them as (unused) output Tensors of a new placeholder op, `static_runtime::layer_norm`, so that the static runtime can manage them. As of now, the static runtime manages only output Tensors.
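
The idea of turning internal temporaries into extra outputs can be sketched outside of PyTorch. The following is a minimal, hypothetical illustration and not the code in `ops.cpp`: `Buffer` and `layer_norm_out` are made-up names, a `std::vector<float>` stands in for a Tensor, and the affine weight and bias are omitted. The point is only that `mean` and `rstd` are owned by the caller (which plays the role of the static runtime), so their storage can be reused across invocations.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a Tensor: a flat float buffer.
using Buffer = std::vector<float>;

// Row-wise layer norm over an M x N row-major input, without weight or bias.
// `mean` and `rstd` are scratch buffers of M elements each. They are passed
// in as extra outputs, so the caller owns their storage and can reuse it;
// once they have the required size, this function allocates nothing for them.
void layer_norm_out(const Buffer& x, std::size_t M, std::size_t N, float eps,
                    Buffer& out, Buffer& mean, Buffer& rstd) {
  assert(x.size() == M * N);
  out.resize(M * N);
  mean.resize(M);
  rstd.resize(M);
  const float n = static_cast<float>(N);
  for (std::size_t i = 0; i < M; ++i) {
    const float* row = x.data() + i * N;
    // Mean of the row.
    float sum = 0.0f;
    for (std::size_t j = 0; j < N; ++j) {
      sum += row[j];
    }
    const float mu = sum / n;
    // Biased variance of the row.
    float sq = 0.0f;
    for (std::size_t j = 0; j < N; ++j) {
      const float d = row[j] - mu;
      sq += d * d;
    }
    mean[i] = mu;
    rstd[i] = 1.0f / std::sqrt(sq / n + eps);
    // Normalized output.
    for (std::size_t j = 0; j < N; ++j) {
      out[i * N + j] = (row[j] - mu) * rstd[i];
    }
  }
}
```

In this sketch, calling `layer_norm_out` repeatedly with the same `mean` and `rstd` buffers reuses their storage, which mirrors what exposing the temporaries as op outputs is meant to allow the static runtime to do.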

Test Plan:
- Enhanced `StaticRuntime.LayerNorm` to ensure that `static_runtime::layer_norm` gets activated.

- Confirmed that the new op gets activated during testing:

```
V0825 12:51:50.017890 2265227 impl.cpp:1396] Switch to out variant for node: %8 : Tensor, %9 : Tensor, %10 : Tensor = static_runtime::layer_norm(%input.1, %normalized_shape.1, %4, %4, %5, %3)

```

Reviewed By: hlu1

Differential Revision: D30486475

fbshipit-source-id: 5121c44ab58c2d8a954aa0bbd9dfeb7468347a2d
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/impl.cpp
torch/csrc/jit/runtime/static/ops.cpp
torch/csrc/jit/runtime/static/passes.cpp
torch/csrc/jit/runtime/static/passes.h