HG: Commit message of changeset 6281661. (#5622)
authorhlu1 <14827759+hlu1@users.noreply.github.com>
Fri, 22 May 2020 16:38:30 +0000 (09:38 -0700)
committerGitHub <noreply@github.com>
Fri, 22 May 2020 16:38:30 +0000 (09:38 -0700)
[Relay] Move compiler_begin/end_op to local static objects

src/relay/analysis/util.cc
src/relay/transforms/annotate_target.cc
src/relay/transforms/merge_compiler_regions.cc
src/relay/transforms/partition_graph.cc
src/relay/transforms/pass_util.h
src/runtime/object.cc

index 6d246c0..a05bb8f 100644 (file)
@@ -338,13 +338,13 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
   return true;
 }
 
-// Cache the operators that are checked recursively to reduce lookup overhead.
-static const auto& expand_dims_op = Op::Get("expand_dims");
-static const auto& reshape_op = Op::Get("reshape");
-static const auto& transpose_op = Op::Get("transpose");
-static const auto& squeeze_op = Op::Get("squeeze");
-
 bool IsAllPositiveConstant(const Expr& expr) {
+  // Cache the operators that are checked recursively to reduce lookup overhead.
+  static const auto& expand_dims_op = Op::Get("expand_dims");
+  static const auto& reshape_op = Op::Get("reshape");
+  static const auto& transpose_op = Op::Get("transpose");
+  static const auto& squeeze_op = Op::Get("squeeze");
+
   // peel through a few common transform ops.
   if (const auto* constant = expr.as<ConstantNode>()) {
     const auto& tensor = constant->data;
index 0d97005..bf2788f 100644 (file)
 #include <tvm/relay/transform.h>
 #include <tvm/runtime/container.h>
 
+#include "pass_util.h"
+
 namespace tvm {
 namespace relay {
 namespace annotate_target {
 
-static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
-static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
-
 const PackedFunc* make_begin_op =
     runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
 const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
@@ -66,12 +65,12 @@ class AnnotateTargetRewriter : public ExprRewriter {
       std::string arg_target = "default";
       const CallNode* call = arg.as<CallNode>();
 
-      if (call && call->op == compiler_begin_op) {
+      if (call && call->op == CompilerBeginOp()) {
         // Argument is already compiler begin node meaning that this is not the first time
         // running this pass, so we simply remove it and will add a new one later.
         CHECK_EQ(call->args.size(), 1U);
         const CallNode* end = call->args[0].as<CallNode>();
-        if (end->op == compiler_end_op) {
+        if (end->op == CompilerEndOp()) {
           arg_target = end->attrs.as<CompilerAttrs>()->compiler;
         }
         compiler_ends.push_back(call->args[0]);
@@ -115,12 +114,12 @@ class AnnotateTargetRewriter : public ExprRewriter {
     auto op_node = pre->op.as<OpNode>();
 
     // This graph has annotations, meaning that this is not the first time running this pass.
-    if (op_node && pre->op == compiler_begin_op) {
+    if (op_node && pre->op == CompilerBeginOp()) {
       // Bypass compiler begin due to lack of target information. It will be processed
       // when the following op handling arguments.
       CHECK_EQ(pre->args.size(), 1U);
       return post.as<CallNode>()->args[0];
-    } else if (op_node && pre->op == compiler_end_op) {
+    } else if (op_node && pre->op == CompilerEndOp()) {
       // Override compiler end with the new target.
       CHECK_EQ(pre->args.size(), 1U);
       auto input_expr = post.as<CallNode>()->args[0];
@@ -131,7 +130,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
     // Peek the first argument. If it is compiler begin then this node had annotated by
     // another target before, so we also consider that target as a supported target.
     const CallNode* first_arg_call = pre->args[0].as<CallNode>();
-    if (first_arg_call && first_arg_call->op == compiler_begin_op) {
+    if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
       std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
       if (arg_target != "default") {
         supported_targets.push_back(arg_target);
index 6fbd0d5..b3a606e 100644 (file)
 #include <vector>
 
 #include "../analysis/annotated_region_set.h"
+#include "pass_util.h"
 
 namespace tvm {
 namespace relay {
 namespace merge_compiler_region {
 
-// Cache compiler_begin and compiler_end annotation ops for equivalence check to
-// reduce registry lookup overhead.
-static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
-static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
-
 class RegionMerger : public MixedModeVisitor {
  public:
   explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
 
   void VisitExpr_(const CallNode* call) final {
-    if (call->op == compiler_end_op) {
+    if (call->op == CompilerEndOp()) {
       auto region = regions_->GetRegion(GetRef<Call>(call));
 
       // Skip this region if it has been merged to the other region.
@@ -75,7 +71,7 @@ class RegionMerger : public MixedModeVisitor {
         // Region inputs must be begin annotation, and the region of
         // the begin annotation's argument is the parent region.
         auto begin = Downcast<Call>(arg);
-        CHECK_EQ(begin->op, compiler_begin_op);
+        CHECK_EQ(begin->op, CompilerBeginOp());
         auto parent_region = regions_->GetRegion(begin->args[0]);
 
         // Skip this region if it has been merged.
@@ -90,7 +86,7 @@ class RegionMerger : public MixedModeVisitor {
       std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
       for (const auto& arg : region->GetInputs()) {
         auto begin = Downcast<Call>(arg);
-        CHECK_EQ(begin->op, compiler_begin_op);
+        CHECK_EQ(begin->op, CompilerBeginOp());
         auto parent_region = regions_->GetRegion(begin->args[0]);
         if (parent_region.defined()) {
           mergeable_regions.insert(parent_region);
@@ -147,9 +143,9 @@ class MergeAnnotations : public ExprRewriter {
     // Merge annotations which are now internal to a region.
     // This happens if we see a compiler begin next to a
     // compiler end and they're both in the same region.
-    if (call->op == compiler_begin_op && call->args[0]->IsInstance<CallNode>()) {
+    if (call->op == CompilerBeginOp() && call->args[0]->IsInstance<CallNode>()) {
       auto arg = Downcast<Call>(call->args[0]);
-      if (arg->op == compiler_end_op) {
+      if (arg->op == CompilerEndOp()) {
         auto region1 = regions_->GetRegion(GetRef<Call>(call));
         auto region2 = regions_->GetRegion(arg);
         if (region1 == region2) {
@@ -167,7 +163,7 @@ class MergeAnnotations : public ExprRewriter {
 
 Expr MergeCompilerRegions(const Expr& expr) {
   // Create regions using the annotations.
-  AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
+  AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, CompilerBeginOp(), CompilerEndOp());
 
   // Analyze the graph to explore the opportunities of merging regions.
   RegionMerger merger(regions);
index 0d25e0a..9481e07 100644 (file)
 
 #include "../analysis/annotated_region_set.h"
 #include "../backend/utils.h"
+#include "pass_util.h"
 
 namespace tvm {
 namespace relay {
 namespace partitioning {
 
-// Cache compiler_begin and compiler_end annotation ops for equivalence check to
-// reduce registry lookup overhead.
-static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
-static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
-
 /*! \brief This struct maintains the required metadata for a region to generate a corresponding
  * global function and function call. Global function will be passed to the target specific codegen
  * and function call will be used in the transform Relay graph to invoke the function in runtime.
@@ -123,8 +119,7 @@ class Partitioner : public MixedModeMutator {
       BaseFunc f_func = f.second;
 
       // Creating regionset per function in the module.
-      auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
-                                                   partitioning::compiler_end_op);
+      auto region_set = AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp());
       regions_sets_[region_set] = f_func;
     }
   }
@@ -133,7 +128,7 @@ class Partitioner : public MixedModeMutator {
     auto op_node = call->op.as<OpNode>();
     if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
       return post;
-    } else if (call->op == compiler_begin_op) {
+    } else if (call->op == CompilerBeginOp()) {
       // The annotation node is inserted on edge so it must have only one argument.
       CHECK_EQ(call->args.size(), 1U);
 
@@ -143,7 +138,7 @@ class Partitioner : public MixedModeMutator {
 
       // Backtrace the parent to find the first ancestor node that is not a begin or end op
       while (const auto* parent_call = parent.as<CallNode>()) {
-        if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) {
+        if (parent_call->op == CompilerBeginOp() || parent_call->op == CompilerEndOp()) {
           parent = parent_call->args[0];
         } else {
           break;
@@ -174,7 +169,7 @@ class Partitioner : public MixedModeMutator {
         return std::move(var);
       }
     } else {
-      CHECK_EQ(call->op, compiler_end_op);
+      CHECK_EQ(call->op, CompilerEndOp());
       // The annotation node is inserted on edge so it must have only one
       // argument.
       CHECK_EQ(call->args.size(), 1U);
@@ -420,7 +415,7 @@ IRModule FlattenTupleOutputs(IRModule module) {
     TupleOutFlattener() = default;
 
     Expr Rewrite_(const CallNode* call, const Expr& post) final {
-      if (call->op == compiler_end_op) {
+      if (call->op == CompilerEndOp()) {
         std::string target = call->attrs.as<CompilerAttrs>()->compiler;
         // Arguments of annotation ops should be 1
         CHECK_EQ(call->args.size(), 1U);
index cbdd4b4..35bbb23 100644 (file)
@@ -115,6 +115,26 @@ inline bool IsAtomic(const Expr& e) {
   return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
 }
 
+/*!
+ * \brief Cache the compiler_begin annotation op to reduce registry lookup overhead
+ * \param void
+ * \return compiler_begin op
+ */
+inline const Op& CompilerBeginOp() {
+  static Op op = Op::Get("annotation.compiler_begin");
+  return op;
+}
+
+/*!
+ * \brief Cache the compiler_end annotation op to reduce registry lookup overhead
+ * \param void
+ * \return compiler_end op
+ */
+inline const Op& CompilerEndOp() {
+  static Op op = Op::Get("annotation.compiler_end");
+  return op;
+}
+
 template <typename ConditionObjectPtr>
 struct TreeNode {
   typedef std::shared_ptr<TreeNode<ConditionObjectPtr>> pointer;
index c8e6671..e5d5ca9 100644 (file)
@@ -24,6 +24,7 @@
 #include <tvm/runtime/object.h>
 #include <tvm/runtime/registry.h>
 
+#include <iostream>
 #include <mutex>
 #include <string>
 #include <unordered_map>