[Static Runtime] Check if outputs of a node do not overlap with each other (#63013)
author Don Jang <djang@fb.com>
Wed, 15 Sep 2021 15:35:57 +0000 (08:35 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 15 Sep 2021 15:38:05 +0000 (08:38 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63013

This change extends the existing memory overlap check to also cover outputs: it enforces the constraint that the outputs of a node must NOT overlap with each other, since a node writes all of its outputs at the same time and each output therefore needs its own storage.
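For context, the check is built on `at::get_overlap_status`, which classifies a pair of tensors as `NO`, `PARTIAL`, `FULL`, or `TOO_HARD` (layouts too complex to analyze). A minimal standalone sketch of its behavior (illustration only, not part of this diff):

```cpp
#include <ATen/ATen.h>
#include <ATen/MemoryOverlap.h>
#include <iostream>

int main() {
  at::Tensor a = at::randn({2, 3});
  at::Tensor b = at::randn({2, 3});
  // Independent allocations never share bytes: NO overlap.
  std::cout << (at::get_overlap_status(a, b) == at::MemOverlapStatus::NO)
            << std::endl; // 1

  // A view reinterprets the same bytes: FULL overlap.
  at::Tensor v = a.view({3, 2});
  std::cout << (at::get_overlap_status(a, v) == at::MemOverlapStatus::FULL)
            << std::endl; // 1
}
```

As in the new `checkNoMemoryOverlap` helper below, `FULL` and `PARTIAL` fail the check, `NO` passes, and `TOO_HARD` passes with a warning.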

This check detects a problem like T97393697 immediately in debug mode (it runs under `DCHECK`, so release builds pay no extra cost).
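To illustrate the class of bug being caught, here is a hypothetical repro (not the actual T97393697 case) in which two "outputs" silently alias one buffer, so writing the second corrupts the first:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor buf = at::zeros({6});
  at::Tensor out0 = buf.slice(/*dim=*/0, /*start=*/0, /*end=*/4); // elements [0, 4)
  at::Tensor out1 = buf.slice(/*dim=*/0, /*start=*/2, /*end=*/6); // elements [2, 6)

  out0.fill_(1.0);
  out1.fill_(2.0); // clobbers the tail of out0

  std::cout << out0 << std::endl; // prints 1 1 2 2 -- out0 was corrupted
}
```

`at::get_overlap_status(out0, out1)` reports `PARTIAL` here, so `verify_no_memory_overlap()` would flag this pair.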

Test Plan:
- Added a unit test `ProcessedNode.VerifyNoMemoryOverlapWithOverlappingOutputs`

- Ran `inline_cvr` on `./buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench` with this diff and confirmed that the check held throughout the run.

Reviewed By: hlu1

Differential Revision: D30211705

fbshipit-source-id: 994d8dace2422e2498e504eb61452a55739238c0

benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/impl.cpp
torch/csrc/jit/runtime/static/impl.h

diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 5eb3dfe..d39facb 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -939,7 +939,7 @@ TEST(StaticRuntime, FusionPass) {
 
 TEST(
     ProcessedNode,
-    VerifyOutputsNotOverlappingWithImmutableInputsWithImmutableArguments) {
+    VerifyNoMemoryOverlapWithImmutableInputsWithImmutableArguments) {
   script::Module module("module");
   // Not using out= variant.
   module.define(sigmoid_script);
@@ -951,15 +951,15 @@ TEST(
   ProcessedNode pnode(sigmoid_node, std::move(ivalue_inputs), true);
 
   pnode.Output(0) = b;
-  EXPECT_TRUE(pnode.verify_outputs_not_overlapping_with_immutable_inputs());
+  EXPECT_TRUE(pnode.verify_no_memory_overlap());
 
   pnode.Output(0) = a;
-  EXPECT_FALSE(pnode.verify_outputs_not_overlapping_with_immutable_inputs());
+  EXPECT_FALSE(pnode.verify_no_memory_overlap());
 }
 
 TEST(
     ProcessedNode,
-    VerifyOutputsNotOverlappingWithImmutableInputsWithMutableArguments) {
+    VerifyNoMemoryOverlapWithImmutableInputsWithMutableArguments) {
   script::Module module("module");
   // Using out= variant.
   module.define(sigmoid_inplace_script);
@@ -971,10 +971,40 @@ TEST(
   ProcessedNode pnode(sigmoid_node, std::move(ivalue_inputs), true);
 
   pnode.Output(0) = b;
-  EXPECT_TRUE(pnode.verify_outputs_not_overlapping_with_immutable_inputs());
+  EXPECT_TRUE(pnode.verify_no_memory_overlap());
 
   pnode.Output(0) = a;
-  EXPECT_TRUE(pnode.verify_outputs_not_overlapping_with_immutable_inputs());
+  EXPECT_TRUE(pnode.verify_no_memory_overlap());
+}
+
+TEST(ProcessedNode, VerifyNoMemoryOverlapWithOverlappingOutputs) {
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(
+      R"IR(
+    graph(%0):
+      %1 : Tensor, %2 : Tensor = prim::ListUnpack(%0)
+      return (%1, %2))IR",
+      g.get());
+  torch::jit::StaticModule smodule(g);
+  Node* list_unpack_node = getNodeWithKind(smodule, "prim::ListUnpack");
+  {
+    auto a = at::randn({2, 3});
+    IValue ivalue(a);
+    std::vector<const IValue*> inputs{&ivalue};
+    ProcessedNode list_unpack_pnode(list_unpack_node, std::move(inputs), /*enable_out_variant=*/true);
+    ASSERT_EQ(list_unpack_pnode.outputs().size(), 2);
+    EXPECT_TRUE(list_unpack_pnode.verify_no_memory_overlap());
+  }
+  {
+    auto a = at::randn({2, 3});
+    IValue ivalue(a);
+    std::vector<const IValue*> inputs{&ivalue};
+    ProcessedNode list_unpack_pnode(list_unpack_node, std::move(inputs), /*enable_out_variant=*/true);
+    auto b = at::randn({2, 3});
+    list_unpack_pnode.Output(0) = b;
+    list_unpack_pnode.Output(1) = b;
+    EXPECT_FALSE(list_unpack_pnode.verify_no_memory_overlap());
+  }
 }
 
 TEST(StaticRuntime, IndividualOps_isinstance) {
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index e001721..19d25d5 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -1462,7 +1462,7 @@ ProcessedNode::ProcessedNode(
 }
 
 void ProcessedNode::run() {
-  DCHECK(verify_outputs_not_overlapping_with_immutable_inputs());
+  DCHECK(verify_no_memory_overlap());
   if (fn_) {
     fn_(this);
   } else if (native_fn_) {
@@ -1489,8 +1489,35 @@ void ProcessedNode::run() {
   }
 }
 
-bool ProcessedNode::verify_outputs_not_overlapping_with_immutable_inputs()
-    const {
+static bool checkNoMemoryOverlap(const at::Tensor& a, const at::Tensor& b) {
+  at::MemOverlapStatus status = at::get_overlap_status(a, b);
+  if (status == at::MemOverlapStatus::FULL ||
+      status == at::MemOverlapStatus::PARTIAL) {
+    return false;
+  }
+  if (status == at::MemOverlapStatus::TOO_HARD) {
+    LOG(WARNING) << "Detected TOO_HARD memory overlap status";
+  }
+  return true;
+}
+
+bool ProcessedNode::verify_no_memory_overlap() const {
+  for (size_t i = 0; i < outputs_.size(); ++i) {
+    if (!outputs_[i].isTensor()) {
+      continue;
+    }
+    const auto& out0_t = outputs_[i].toTensor();
+    for (size_t j = i + 1; j < outputs_.size(); ++j) {
+      if (!outputs_[j].isTensor()) {
+        continue;
+      }
+      const auto& out1_t = outputs_[j].toTensor();
+      if (!checkNoMemoryOverlap(out0_t, out1_t)) {
+        return false;
+      }
+    }
+  }
+
   auto schema = node()->maybeSchema();
   if (!schema || schema->is_mutable()) {
     return true;
@@ -1505,8 +1532,7 @@ bool ProcessedNode::verify_outputs_not_overlapping_with_immutable_inputs()
         continue;
       }
       const auto& out_t = out.toTensor();
-      at::MemOverlapStatus status = at::get_overlap_status(in_t, out_t);
-      if (status != at::MemOverlapStatus::NO) {
+      if (!checkNoMemoryOverlap(in_t, out_t)) {
         return false;
       }
     }
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index 4aa3608..4b5560f 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -445,7 +445,7 @@ class TORCH_API ProcessedNode {
     return static_cast<bool>(native_fn_);
   }
 
-  bool verify_outputs_not_overlapping_with_immutable_inputs() const;
+  bool verify_no_memory_overlap() const;
 
  private:
   Node* node_;