[relay][refactor] Cache Op::Get in passes to reduce lookup overhead (#4594)
authorZhi <5145158+zhiics@users.noreply.github.com>
Tue, 31 Dec 2019 00:33:50 +0000 (16:33 -0800)
committermasahi <masahi129@gmail.com>
Tue, 31 Dec 2019 00:33:50 +0000 (09:33 +0900)
* Refactor to use IsOp utility

* retrigger CI

13 files changed:
include/tvm/relay/op.h
src/relay/backend/compile_engine.cc
src/relay/backend/interpreter.cc
src/relay/pass/canonicalize_cast.cc
src/relay/pass/canonicalize_ops.cc
src/relay/pass/combine_parallel_op.cc
src/relay/pass/combine_parallel_op.h
src/relay/pass/fold_constant.cc
src/relay/pass/fuse_ops.cc
src/relay/pass/partial_eval.cc
src/relay/pass/quantize/calibrate.cc
src/relay/pass/simplify_inference.cc
src/relay/pass/util.cc

index 7f1ef45..90f2937 100644 (file)
@@ -594,12 +594,11 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
   return map_.get<ValueType>(expr, def_value);
 }
 
-
 /*!
- * \brief Check that an expression is a "primtive operator".
+ * \brief Check that an expression is a "primitive operator".
  *
  * Will return true if the expression is an operator which
- * matches the form of primtive operators registered directly
+ * matches the form of primitive operators registered directly
  * by the Relay codebase.
  *
  * That is the arguments are all type variables, and there is a single
index fad1cc6..7c33ac9 100644 (file)
@@ -21,6 +21,8 @@
  * \file relay/backend/compile_engine.cc
  * \brief Internal compialtion engine.
  */
+#include "compile_engine.h"
+
 #include <tvm/schedule.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/operation.h>
@@ -29,6 +31,7 @@
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <topi/tags.h>
 #include <utility>
@@ -38,7 +41,6 @@
 #include <vector>
 #include <unordered_map>
 #include "../ir/type_functor.h"
-#include "compile_engine.h"
 
 namespace tvm {
 namespace relay {
@@ -102,7 +104,7 @@ class ScheduleGetter :
       public ExprFunctor<Array<Tensor>(const Expr&)> {
  public:
   explicit ScheduleGetter(Target target)
-      : target_(target) {}
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {}
 
   std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
     static auto fschedule =
@@ -250,11 +252,9 @@ class ScheduleGetter :
     CHECK(call_node->op.as<OpNode>())
         << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
-    // Check if the op is a device copy op.
-    bool is_copy_op = op.same_as(Op::Get("device_copy"));
     Array<Tensor> outputs;
     // Skip fcompute for device copy operators as it is not registered.
-    if (is_copy_op) {
+    if (op == device_copy_op_) {
       const auto* copy_input = inputs[0].operator->();
       outputs.push_back(TensorNode::make(copy_input->shape, copy_input->dtype,
                                          Operation(), 0));
@@ -282,7 +282,7 @@ class ScheduleGetter :
     }
     // Set the name to `__copy`. It will be detected in graph runtime to perform
     // data copy across devices.
-    if (is_copy_op) {
+    if (op == device_copy_op_) {
       readable_name_stream_.str(std::string());
       readable_name_stream_ << "__copy";
     } else {
@@ -332,6 +332,9 @@ class ScheduleGetter :
   std::ostringstream readable_name_stream_;
   std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
   Array<Operation> scalars_;
+  // Cache device copy op for equivalence checking to reduce registry lookup
+  // overhead for each invocation of call node when retrieving schedules.
+  const Op& device_copy_op_;
 };
 
 // Creates shape function from functor.
index 1c57e9d..b5fd0c9 100644 (file)
@@ -246,10 +246,12 @@ class Interpreter :
       public ExprFunctor<Value(const Expr& n)>,
              PatternFunctor<bool(const Pattern& p, const Value& v)> {
  public:
-  Interpreter(Module mod,
-              DLContext context,
-              Target target)
-      : mod_(mod), context_(context), target_(target) {
+  Interpreter(Module mod, DLContext context, Target target)
+      : mod_(mod),
+        context_(context),
+        target_(target),
+        debug_op_(Op::Get("debug")),
+        shape_of_op_(Op::Get("shape_of")) {
     engine_ = CompileEngine::Global();
   }
 
@@ -263,7 +265,7 @@ class Interpreter :
     stack_.current_frame().locals.Set(id, v);
   }
 
-  inline Value Lookup(const Var& local) {
+  Value Lookup(const Var& local) {
     return stack_.Lookup(local);
   }
 
@@ -307,7 +309,7 @@ class Interpreter :
     return TupleValueNode::make(values);
   }
 
-  inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
+  Value MakeClosure(const Function& func, Var letrec_name = Var()) {
     tvm::Map<Var, Value> captured_mod;
     Array<Var> free_vars = FreeVars(func);
 
@@ -454,9 +456,9 @@ class Interpreter :
 
   Value InvokePrimitiveOp(const Function& func,
                           const Array<Value>& args) {
-    auto call_node = func->body.as<CallNode>();
+    const auto* call_node = func->body.as<CallNode>();
 
-    if (call_node && call_node->op == Op::Get("debug")) {
+    if (call_node && call_node->op == debug_op_) {
       auto dattrs = call_node->attrs.as<DebugAttrs>();
       auto interp_state = this->get_state(call_node->args[0]);
 
@@ -540,7 +542,7 @@ class Interpreter :
     Array<Shape> out_shapes;
     auto ret_type = func->body->checked_type();
     bool is_dyn = IsDynamic(func->checked_type());
-    if (call_node->op == Op::Get("shape_of")) {
+    if (call_node->op == shape_of_op_) {
       // The output shape of shape_of must be static since Relay doesn't support
       // dynamic rank tensors.
       is_dyn = false;
@@ -782,6 +784,9 @@ class Interpreter :
   Stack stack_;
   // Backend compile engine.
   CompileEngine engine_;
+  // Cache ops that need to be frequently used later to reduce lookup overhead.
+  const Op& debug_op_;
+  const Op& shape_of_op_;
 };
 
 
index fa306ea..6913eb2 100644 (file)
@@ -62,6 +62,8 @@ namespace relay {
 // \endcode
 class CastCanonicalizer : public ExprMutator {
  public:
+  CastCanonicalizer() : cast_op_(Op::Get("cast")) {}
+
   Expr VisitExpr_(const CallNode* call) {
     static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
 
@@ -91,15 +93,17 @@ class CastCanonicalizer : public ExprMutator {
 
  private:
   std::unordered_map<const Node*, size_t> ref_counter_;
+  // cast op is frequently checked for equivalence. Therefore, we cache it to
+  // reduce lookup overhead.
+  const Op& cast_op_;
+
 
   Expr GetNewCallArg(const Expr& e) {
     // if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor
-
-    static auto& cast = Op::Get("cast");
     Expr new_expr = this->VisitExpr(e);
 
     if (const CallNode* call = e.as<CallNode>()) {
-      if (call->op.same_as(cast)) {
+      if (call->op == cast_op_) {
         auto attrs = call->attrs.as<CastAttrs>();
         const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
         CHECK(from_type);
@@ -108,7 +112,7 @@ class CastCanonicalizer : public ExprMutator {
           if (++ref_counter_[call] > 1) {
             const CallNode* new_call = new_expr.as<CallNode>();
             CHECK(new_call);
-            CHECK(new_call->op.same_as(cast));
+            CHECK(new_call->op == cast_op_);
             return CallNode::make(new_call->op, new_call->args, new_call->attrs,
                  new_call->type_args);
           }
index 9755154..64b702c 100644 (file)
@@ -24,6 +24,7 @@
  */
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/transform.h>
 #include "pattern_util.h"
@@ -33,10 +34,11 @@ namespace relay {
 
 class BiasAddSimplifier : public ExprMutator {
  public:
+  BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {}
+
   Expr VisitExpr_(const CallNode* n) {
-    static const Op& bias_add = Op::Get("nn.bias_add");
     auto new_n = ExprMutator::VisitExpr_(n);
-    if (n->op.same_as(bias_add)) {
+    if (n->op == bias_add_op_) {
       Call call = Downcast<Call>(new_n);
       CHECK_EQ(call->args.size(), 2);
       const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>();
@@ -54,6 +56,10 @@ class BiasAddSimplifier : public ExprMutator {
     }
     return new_n;
   }
+
+ private:
+  // Cache the bias_add for equivalence checking.
+  const Op& bias_add_op_;
 };
 
 Expr CanonicalizeOps(const Expr& e) {
index 081216c..6b9926c 100644 (file)
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/transform.h>
+#include <algorithm>
+#include <utility>
 #include <unordered_map>
 #include <unordered_set>
-#include "./expr_subst.h"
-#include "./pattern_util.h"
-#include "./combine_parallel_op.h"
+#include "expr_subst.h"
+#include "pattern_util.h"
+#include "combine_parallel_op.h"
 
 
 namespace tvm {
 namespace relay {
 
-BranchGroupFinder::BranchGroupFinder(const std::string& op_name,
+BranchGroupFinder::BranchGroupFinder(const Op& op,
                                      FIsSupportedOp fis_supported_op,
                                      FAreCompatibleOps fare_compatible_ops)
-  : op_name_(op_name),
+  : cached_op_(op),
     fis_supported_op_(fis_supported_op),
     fare_compatible_ops_(fare_compatible_ops) {
 }
 
 std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
-  const Op& op = Op::Get(op_name_);
-
   this->VisitExpr(expr);
 
   std::vector<Group> groups;
@@ -57,7 +58,7 @@ std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
     const auto& children = children_map_.at(root);
     size_t ngroups = groups.size();
     for (const CallNode* child : children) {
-      if (!child->op.same_as(op)) continue;
+      if (child->op != cached_op_) continue;
 
       auto&& branch = CreateBranch(child);
       // add the branch to a group, or create a new group
@@ -97,9 +98,8 @@ Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
 }
 
 void BranchGroupFinder::VisitExpr_(const CallNode* n) {
-  const Op& op = Op::Get(op_name_);
   ExprVisitor::VisitExpr_(n);
-  if (n->op.same_as(op) && fis_supported_op_(n)) {
+  if (n->op == cached_op_ && fis_supported_op_(n)) {
     op_roots_.insert(n->args[0]);
     children_map_[n->args[0]].push_back(n);
   } else {
@@ -110,12 +110,12 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) {
 }
 
 ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches)
-  : op_name_(op_name),
+  : cached_op_(Op::Get(op_name)),
     min_num_branches_(min_num_branches) {
 }
 
 Expr ParallelOpCombiner::Combine(const Expr& expr) {
-  auto groups = BranchGroupFinder(op_name_,
+  auto groups = BranchGroupFinder(cached_op_,
                                   [&](const CallNode* n) {
                                     return IsSupportedOp(n);
                                   },
index 9004b04..858926e 100644 (file)
@@ -68,13 +68,13 @@ class BranchGroupFinder : private ExprVisitor {
  public:
   /*
    * \brief Constructor
-   * \param op_name name of op to start each group
+   * \param op The op that indicates the start of each group
    * \param fis_supported_op function that returns true if op
    *                         is supported for combining
    * \param fare_compatible_ops function that returns true if
    *                            two ops are compatible for combining
    */
-  BranchGroupFinder(const std::string& op_name,
+  BranchGroupFinder(const Op& op,
                     FIsSupportedOp fis_supported_op,
                     FAreCompatibleOps fare_compatible_ops);
 
@@ -87,8 +87,8 @@ class BranchGroupFinder : private ExprVisitor {
   std::vector<Group> Find(const Expr& expr);
 
  private:
-  /* \brief name of op to find parallel branches for */
-  std::string op_name_;
+  /* \brief Cache the op for finding parallel branches */
+  const Op& cached_op_;
 
   /* \brief function to return true if op is eligible to be combined,
    *         false otherwise 
@@ -205,8 +205,8 @@ class ParallelOpCombiner {
                                  ExprSubstMap* subst_map) = 0;
 
  private:
-  /* \brief name of op to be combined */
-  std::string op_name_;
+  /* \brief Cache the op to be combined */
+  const Op& cached_op_;
 
   /* \brief minimum number of parallel branches to combine */
   uint64_t min_num_branches_;
index c37fdae..1e22571 100644 (file)
@@ -22,6 +22,7 @@
  */
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/attrs/transform.h>
@@ -33,7 +34,6 @@ namespace relay {
 
 using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
 
-
 class ConstantChecker : private ExprVisitor {
  public:
   // Check whether an expression is constant. The results are memoized.
@@ -78,8 +78,14 @@ TVM_REGISTER_API("relay._analysis.check_constant")
 class ConstantFolder : public ExprMutator {
  public:
   explicit ConstantFolder(FInterpreter executor, Module module)
-      : executor_(executor), module_(module) {
-  }
+      : executor_(executor),
+        module_(module),
+        shape_of_op_(Op::Get("shape_of")),
+        invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")),
+        shape_func_op_(Op::Get("memory.shape_func")),
+        alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
+        alloc_storage_op_(Op::Get("memory.alloc_storage")),
+        cast_op_(Op::Get("cast")) {}
 
   Expr VisitExpr_(const LetNode* op) final {
     Expr value = this->Mutate(op->value);
@@ -119,15 +125,15 @@ class ConstantFolder : public ExprMutator {
     // skip stateful ops.
     if (op_stateful.get(GetRef<Op>(op), false)) return res;
     // Try to evaluate shape_of op
-    if (call->op.same_as(Op::Get("shape_of"))) {
+    if (call->op == shape_of_op_) {
       return EvaluateShapeOf(res, origin_args, call->attrs);
     }
 
     // We should think about potentially constant evaluation over these ops too.
-    if (call->op.same_as(Op::Get("memory.invoke_tvm_op")) ||
-        call->op.same_as(Op::Get("memory.shape_func")) ||
-        call->op.same_as(Op::Get("memory.alloc_tensor")) ||
-        call->op.same_as(Op::Get("memory.alloc_storage"))) {
+    if (call->op == invoke_tvm_op_ ||
+        call->op == shape_func_op_ ||
+        call->op == alloc_tensor_op_ ||
+        call->op == alloc_storage_op_) {
       return GetRef<Call>(call);
     }
 
@@ -162,6 +168,14 @@ class ConstantFolder : public ExprMutator {
   // Module
   Module module_;
 
+  // Cache the following ops for equivalence checking in this pass.
+  const Op& shape_of_op_;
+  const Op& invoke_tvm_op_;
+  const Op& shape_func_op_;
+  const Op& alloc_tensor_op_;
+  const Op& alloc_storage_op_;
+  const Op& cast_op_;
+
   // Convert value to expression.
   Expr ValueToExpr(Value value) {
     if (const auto* val = value.as<TensorValueNode>()) {
@@ -254,8 +268,7 @@ class ConstantFolder : public ExprMutator {
     // Cast the constant into correct dtype
     auto cast_attrs = make_node<CastAttrs>();
     cast_attrs->dtype = param->dtype;
-    static const Op& cast_op = Op::Get("cast");
-    Expr ret = CallNode::make(cast_op, { shape }, Attrs(cast_attrs), {});
+    Expr ret = CallNode::make(cast_op_, { shape }, Attrs(cast_attrs), {});
     return ConstEvaluate(ret);
   }
 };
index ba7feed..8209a80 100644 (file)
@@ -78,6 +78,8 @@ using common::LinkedList;
 
 constexpr uint32_t kMaxFusedOps = 256;
 
+static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion");
+
 /*!
  * \brief Indexed data flow graph in forward direction.
  *  This is a temporary data structure used for operator fusion analysis.
@@ -860,7 +862,6 @@ class FuseMutator : private ExprMutator {
 
   // Transform calls.
   Expr VisitExpr_(const CallNode* call) {
-    static const Op& stop_fusion = Op::Get("annotation.stop_fusion");
     if (call->op.as<OpNode>()) {
       static auto fnoncomputational =
         Op::GetAttr<TNonComputational>("TNonComputational");
@@ -872,7 +873,7 @@ class FuseMutator : private ExprMutator {
       // If it is a primitive op call
       // then we must have a group assignment for it already.
       CHECK(gmap_.count(call));
-      if (call->op.same_as(stop_fusion)) {
+      if (call->op == stop_fusion_op) {
         return ExprMutator::VisitExpr(call->args[0]);
       }
       auto* ret_group = gmap_.at(call)->FindRoot();
index a44568a..afcc493 100644 (file)
@@ -559,30 +559,28 @@ struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
 
 TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);
 
-Op WithFuncIdOp() {
-  static const Op& op = Op::Get("annotation.with_funcid");
-  return op;
-}
-
-Expr MkWithFuncId(const Expr& expr, FuncId fid) {
-  auto attrs = make_node<WithFuncIdAttrs>();
-  attrs->fid = fid;
-  return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {});
-}
-
 RELAY_REGISTER_OP("annotation.with_funcid")
 .describe(R"code(Annotate a function with a funcid.)code"
 TVM_ADD_FILELINE)
 .set_num_inputs(1)
 .add_argument("func", "Function", "The input data.");
 
+// Cache with_funcid op to reduce lookup overhead during traversal.
+static const Op& with_funcid_op = Op::Get("annotation.with_funcid");
+
+Expr MkWithFuncId(const Expr& expr, FuncId fid) {
+  auto attrs = make_node<WithFuncIdAttrs>();
+  attrs->fid = fid;
+  return CallNode::make(with_funcid_op, {expr}, Attrs(attrs), {});
+}
+
 Expr StripWithFuncId(const Expr& e);
 
 Function AsFunc(const Expr& e) {
   if (e.as<FunctionNode>()) {
     return Downcast<Function>(e);
   } else if (const CallNode* c = e.as<CallNode>()) {
-    CHECK(c->op.same_as(WithFuncIdOp()));
+    CHECK(c->op == with_funcid_op);
     CHECK_EQ(c->args.size(), 1);
     return AsFunc(c->args[0]);
   } else {
@@ -604,7 +602,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
 
   PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) {
     if (const CallNode* c = e.as<CallNode>()) {
-      if (c->op.same_as(WithFuncIdOp())) {
+      if (c->op == with_funcid_op) {
         CHECK_EQ(c->args.size(), 1);
         return VisitExpr(c->args[0], ll, name);
       }
@@ -722,7 +720,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   }
 
   PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
-    if (op->op.same_as(WithFuncIdOp())) {
+    if (op->op == with_funcid_op) {
       CHECK_EQ(op->args.size(), 1);
       return VisitExpr(op->args[0], ll);
     }
@@ -1096,7 +1094,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
       explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
 
       void VisitExpr_(const CallNode* op) final {
-        if (op->op.same_as(WithFuncIdOp())) {
+        if (op->op == with_funcid_op) {
           CHECK_EQ(op->args.size(), 1);
           CHECK(op->attrs.defined());
           CHECK(op->attrs.as<WithFuncIdAttrs>());
@@ -1194,7 +1192,7 @@ Expr Remap(const Expr& e) {
 Expr StripWithFuncId(const Expr& e) {
   struct StripWithFuncIdMutator : ExprMutator, PatternMutator {
     Expr VisitExpr_(const CallNode* op) final {
-      if (op->op.same_as(WithFuncIdOp())) {
+      if (op->op == with_funcid_op) {
         CHECK_EQ(op->args.size(), 1);
         return VisitExpr(op->args[0]);
       } else {
index 3ecf449..e78abbf 100644 (file)
  */
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include "./quantize.h"
 
-
 namespace tvm {
 namespace relay {
 namespace quantize {
 
 class StatsCollector : private ExprMutator {
  public:
+  StatsCollector() : simulated_quantize_op_(Op::Get("relay.op.annotation.simulated_quantize")) {}
+
   Expr Collect(const Expr& expr) {
     auto new_e = this->Mutate(expr);
     const FunctionNode* func = new_e.as<FunctionNode>();
@@ -45,13 +47,13 @@ class StatsCollector : private ExprMutator {
 
  private:
   Array<Expr> profile_data_;
+  const Op& simulated_quantize_op_;
 
   Expr VisitExpr_(const CallNode* call) {
-    static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
     Expr new_e = ExprMutator::VisitExpr_(call);
     const CallNode* new_call = new_e.as<CallNode>();
     CHECK(new_call);
-    if (new_call->op.same_as(simulated_quantize)) {
+    if (new_call->op == simulated_quantize_op_) {
       auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
       // rewrite the annotation
       auto new_attrs = make_node<SimulatedQuantizeAttrs>();
index e28120d..acd5163 100644 (file)
@@ -91,7 +91,6 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
   return out;
 }
 
-
 Expr InstanceNormToInferUnpack(const Attrs attrs,
                                Expr data,
                                Expr gamma,
@@ -125,23 +124,25 @@ Expr InstanceNormToInferUnpack(const Attrs attrs,
   return out;
 }
 
-
 class InferenceSimplifier : public ExprMutator {
  public:
-  Expr VisitExpr_(const TupleGetItemNode* n) final {
-    static const Op& batch_norm = Op::Get("nn.batch_norm");
-    static const Op& dropout = Op::Get("nn.dropout");
+  InferenceSimplifier()
+      : batch_norm_op_(Op::Get("nn.batch_norm")),
+        dropout_op_(Op::Get("nn.dropout")),
+        instance_norm_op_(Op::Get("nn.instance_norm")),
+        layer_norm_op_(Op::Get("nn.layer_norm")) {}
 
+  Expr VisitExpr_(const TupleGetItemNode* n) final {
     Expr new_e = ExprMutator::VisitExpr_(n);
     const auto* new_n = new_e.as<TupleGetItemNode>();
     if (new_n->index != 0) {
       return new_e;
     }
     if (const auto* call = new_n->tuple.as<CallNode>()) {
-      if (call->op.same_as(batch_norm)) {
+      if (call->op == batch_norm_op_) {
         return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                       call->args[3], call->args[4], ty_map_.at(call->args[0]));
-      } else if (call->op.same_as(dropout)) {
+      } else if (call->op == dropout_op_) {
         return call->args[0];
       }
     }
@@ -149,17 +150,14 @@ class InferenceSimplifier : public ExprMutator {
   }
 
   Expr VisitExpr_(const CallNode* n) {
-    static const Op& batch_norm = Op::Get("nn.batch_norm");
-    static const Op& instance_norm = Op::Get("nn.instance_norm");
-    static const Op& layer_norm = Op::Get("nn.layer_norm");
     auto new_n = ExprMutator::VisitExpr_(n);
-    if (n->op.same_as(batch_norm)) {
+    if (n->op == batch_norm_op_) {
       ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
-    } else if (n->op.same_as(layer_norm)) {
+    } else if (n->op == layer_norm_op_) {
       const auto* call = new_n.as<CallNode>();
       return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
                                     call->args[2], n->args[0]->checked_type());
-    } else if (n->op.same_as(instance_norm)) {
+    } else if (n->op == instance_norm_op_) {
       const auto* call = new_n.as<CallNode>();
       return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1],
                                     call->args[2], n->args[0]->checked_type());
@@ -168,6 +166,13 @@ class InferenceSimplifier : public ExprMutator {
   }
 
  private:
+  // Cache the following ops. They will be used in the passes repeatedly for
+  // operator equivalence checking so that the registry lookup overhead can be
+  // reduced.
+  const Op& batch_norm_op_;
+  const Op& dropout_op_;
+  const Op& instance_norm_op_;
+  const Op& layer_norm_op_;
   std::unordered_map<Expr, Type, NodeHash, NodeEqual> ty_map_;
 };
 
index 334c98b..17c527b 100644 (file)
@@ -25,6 +25,7 @@
  */
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/pattern_functor.h>
 #include "pass_util.h"
 #include "../ir/type_functor.h"
@@ -360,13 +361,14 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
   return true;
 }
 
+// Cache the operators that are checked recursively to reduce lookup overhead.
+static const auto& expand_dims_op = Op::Get("expand_dims");
+static const auto& reshape_op = Op::Get("reshape");
+static const auto& transpose_op = Op::Get("transpose");
+static const auto& squeeze_op = Op::Get("squeeze");
+
 bool IsAllPositiveConstant(const Expr& expr) {
   // peel through a few common transform ops.
-  static const auto& expand_dims = Op::Get("expand_dims");
-  static const auto& reshape = Op::Get("reshape");
-  static const auto& transpose = Op::Get("transpose");
-  static const auto& squeeze = Op::Get("squeeze");
-
   if (const auto* constant = expr.as<ConstantNode>()) {
     const auto& tensor = constant->data;
     const auto& dtype = tensor->dtype;
@@ -389,10 +391,10 @@ bool IsAllPositiveConstant(const Expr& expr) {
     }
   } else if (const auto* op = expr.as<CallNode>()) {
     // tail recursion.
-    if (op->op.same_as(expand_dims) ||
-        op->op.same_as(reshape) ||
-        op->op.same_as(transpose) ||
-        op->op.same_as(squeeze)) {
+    if (op->op == expand_dims_op ||
+        op->op == reshape_op ||
+        op->op == transpose_op ||
+        op->op == squeeze_op) {
       return IsAllPositiveConstant(op->args[0]);
     } else {
       return false;