[StaticRuntime] Fix bug in HasInplaceOp (#63842)
authorHao Lu <hlu@fb.com>
Wed, 25 Aug 2021 00:06:18 +0000 (17:06 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 25 Aug 2021 00:07:45 +0000 (17:07 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63842

Reviewed By: mikeiovine

Differential Revision: D30506914

fbshipit-source-id: b2e358cfb991dacdb295b61bbc37beb36b73b852

benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/passes.cpp

index c82dd57..90f93b2 100644 (file)
@@ -138,6 +138,22 @@ const auto reshape_inplace_script = R"JIT(
       return (d, e, f)
 )JIT";
 
+const auto reshape_inplace_script_1 = R"JIT(
+  def forward(self, inp: Tensor, shape: List[int], flag: bool):
+    if flag:
+      a = inp + inp
+      b = a.reshape(shape)
+      c = b.sigmoid()
+    else:
+      a = inp * inp
+      b = a.sigmoid_()
+      c = b.reshape(shape)
+    d = c + c
+    e = a + a
+    f = b + b
+    return (d, e, f)
+)JIT";
+
 const auto sigmoid_inplace_script = R"JIT(
   def forward(self, inp: Tensor):
       a = torch.sigmoid(inp, out=inp).clone()
index 701231e..f6ec677 100644 (file)
@@ -69,6 +69,7 @@ Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind) {
 
 TEST(StaticRuntime, InPlace) {
   EXPECT_TRUE(testHasInplaceOp(reshape_inplace_script));
+  EXPECT_TRUE(testHasInplaceOp(reshape_inplace_script_1));
   EXPECT_TRUE(testHasInplaceOp(sigmoid_inplace_script));
   EXPECT_FALSE(testHasInplaceOp(sigmoid_out_script));
 }
index 2e9eb57..c8e1107 100644 (file)
@@ -12,7 +12,9 @@ namespace {
 bool HasInplaceOp(Block* block, const AliasDb& alias_db) {
   for (auto* node : block->nodes()) {
     for (Block* sub_block : node->blocks()) {
-      return HasInplaceOp(sub_block, alias_db);
+      if (HasInplaceOp(sub_block, alias_db)) {
+        return true;
+      }
     }
     auto inputs = node->inputs();
     // check if node modifies inputs (both inplace ops and certain out variants