[RELAY][FIX] Fix hang in MergeCompilerRegions (#5227)
authormbaret <55580676+mbaret@users.noreply.github.com>
Fri, 3 Apr 2020 16:33:15 +0000 (17:33 +0100)
committerGitHub <noreply@github.com>
Fri, 3 Apr 2020 16:33:15 +0000 (09:33 -0700)
For certain network topologies, MCR could hang.
This patch fixes that case.

Change-Id: I3edd8a8a6b452b2b838b777720adea22a3b995b4

src/relay/analysis/annotated_region_set.cc
src/relay/transforms/merge_compiler_regions.cc

index f7b9b42..ad2b9e1 100644 (file)
@@ -22,7 +22,6 @@
 #include <tvm/relay/expr.h>
 #include <tvm/ir/error.h>
 
-#include <algorithm>
 #include <unordered_map>
 #include <vector>
 
@@ -58,8 +57,8 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
   std::vector<Expr> ins_to_remove;
   for (const auto& input : dest->ins) {
     auto call = Downcast<Call>(input);
-    auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
-    if (it != src->outs.end()) {
+    auto it = src->nodes.find(call->args[0]);
+    if (it != src->nodes.end()) {
       dest->outs.remove(*it);
       ins_to_remove.push_back(input);
     }
index 4a8ff64..5253010 100644 (file)
@@ -263,6 +263,7 @@ class RegionMerger : public ExprVisitor {
   void VisitExpr_(const CallNode* call) final {
     if (call->op == compiler_end_op) {
       auto region = regions_->GetRegion(GetRef<Call>(call));
+      if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return;
       // set the region target
       auto compiler_attrs = call->attrs.as<CompilerAttrs>();
       region_targets_[region->GetID()] = compiler_attrs->compiler;
@@ -281,13 +282,13 @@ class RegionMerger : public ExprVisitor {
         }
       }
       // get the mergeable regions now all the parents have been visited
-      std::vector<AnnotatedRegion> mergeable_regions;
+      std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
       for (const auto& arg : region->GetInputs()) {
         auto begin = Downcast<Call>(arg);
         CHECK_EQ(begin->op, compiler_begin_op);
         auto parent_region = regions_->GetRegion(begin->args[0]);
         if (!parent_region.defined()) continue;
-        mergeable_regions.push_back(parent_region);
+        mergeable_regions.insert(parent_region);
       }
       auto& region_restrictions = region_restrictions_[region->GetID()];
       for (const auto& parent_region : mergeable_regions) {