[Static Runtime] Implement and enable variadic tuple unpack (#64934)
author Mike Iovine <mikeiovine@fb.com>
Mon, 20 Sep 2021 17:25:57 +0000 (10:25 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 20 Sep 2021 17:36:11 +0000 (10:36 -0700)
commit 99e4ab5d44870f09e9f036288d58609f6c11ea95
tree 888024db82d47ae6654bbe3fddbce4bd9a6bf85a
parent 14347d0dd544023b482629cb6a0d3b2a4ac2959d

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64934

Add a new op `static_runtime::VarTupleUnpack` and a graph pass that rewrites sequences of consecutive `prim::TupleUnpack` nodes such as:
```
%0, %1 = prim::TupleUnpack(%a)
%2, %3 = prim::TupleUnpack(%b)
```
into:
```
%0, %1, %2, %3 = static_runtime::VarTupleUnpack(%a, %b)
```
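
As a rough illustration of the op's semantics (not the code in `native_ops.cpp`), the outputs are the elements of each input tuple, concatenated in input order. The sketch below assumes a toy tuple type built from strings; `ToyTuple` and `varTupleUnpackSketch` are hypothetical names, and the real op operates on `IValue` tuples:
```cpp
#include <string>
#include <vector>

// Toy stand-in for a tuple value (assumption: the real op handles IValue tuples).
using ToyTuple = std::vector<std::string>;

// Flatten the elements of every input tuple, in input order, into one output list.
std::vector<std::string> varTupleUnpackSketch(const std::vector<ToyTuple>& tuples) {
  std::vector<std::string> outputs;
  for (const ToyTuple& t : tuples) {
    outputs.insert(outputs.end(), t.begin(), t.end());
  }
  return outputs;
}
```
With inputs `(a0, a1)` and `(b0, b1)`, the sketch yields four outputs in the same order as `%0, %1, %2, %3` above.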

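The pass is applied only to contiguous blocks of `prim::TupleUnpack` nodes, that is, runs with no other node in between. Because fusing such a run never reorders an unpack relative to any other node, this is the most straightforward way to guarantee correctness, and it is sufficient for the models we care about.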
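
The contiguity rule can be sketched on a simplified graph model. This is a minimal sketch and not the pass in `passes.cpp`: `ToyNode` and `fuseContiguousTupleUnpacks` are hypothetical names, the node list stands in for a `torch::jit::Graph`, and leaving a lone `TupleUnpack` unfused is an assumption of the sketch:
```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified node; this is not the torch::jit::Node API.
struct ToyNode {
  std::string kind;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

const char* const kTupleUnpack = "prim::TupleUnpack";
const char* const kVarTupleUnpack = "static_runtime::VarTupleUnpack";

// Replace each run of two or more adjacent TupleUnpack nodes with one
// variadic node. Nodes outside such runs are copied through unchanged.
std::vector<ToyNode> fuseContiguousTupleUnpacks(const std::vector<ToyNode>& nodes) {
  std::vector<ToyNode> result;
  std::size_t i = 0;
  while (i < nodes.size()) {
    if (nodes[i].kind != kTupleUnpack) {
      result.push_back(nodes[i]);
      ++i;
      continue;
    }
    // Find the end of the contiguous run [i, j).
    std::size_t j = i;
    while (j < nodes.size() && nodes[j].kind == kTupleUnpack) {
      ++j;
    }
    if (j - i < 2) {
      // Assumption of this sketch: a lone unpack is left as-is.
      result.push_back(nodes[i]);
    } else {
      ToyNode fused{kVarTupleUnpack, {}, {}};
      for (std::size_t k = i; k < j; ++k) {
        // Inputs and outputs keep their original relative order.
        fused.inputs.insert(fused.inputs.end(),
                            nodes[k].inputs.begin(), nodes[k].inputs.end());
        fused.outputs.insert(fused.outputs.end(),
                             nodes[k].outputs.begin(), nodes[k].outputs.end());
      }
      result.push_back(fused);
    }
    i = j;
  }
  return result;
}
```
In this model, two unpacks separated by any other node are never merged, so no node is moved across another and the original execution order is preserved.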

Test Plan: New unit tests: `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- VarTupleUnpack`

Reviewed By: d1jang

Differential Revision: D30872109

fbshipit-source-id: 1ed4a7e201c532da28f703a3a50241c392a6c7e9
benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/impl.cpp
torch/csrc/jit/runtime/static/native_ops.cpp
torch/csrc/jit/runtime/static/passes.cpp
torch/csrc/jit/runtime/static/passes.h