From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Tue, 31 Dec 2019 00:33:50 +0000 (-0800) Subject: [relay][refactor] Cache Op::Get in passes to reduce lookup overhead (#4594) X-Git-Tag: upstream/0.7.0~1463 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=475158f6285c63b42efe574cb9ba8afec24261be;p=platform%2Fupstream%2Ftvm.git [relay][refactor] Cache Op::Get in passes to reduce lookup overhead (#4594) * Refactor to use IsOp utility * retrigger CI --- diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 7f1ef45..90f2937 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -594,12 +594,11 @@ inline ValueType OpMap::get(const Expr& expr, return map_.get(expr, def_value); } - /*! - * \brief Check that an expression is a "primtive operator". + * \brief Check that an expression is a "primitive operator". * * Will return true if the expression is an operator which - * matches the form of primtive operators registered directly + * matches the form of primitive operators registered directly * by the Relay codebase. * * That is the arguments are all type variables, and there is a single diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index fad1cc6..7c33ac9 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -21,6 +21,8 @@ * \file relay/backend/compile_engine.cc * \brief Internal compialtion engine. */ +#include "compile_engine.h" + #include #include #include @@ -29,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -38,7 +41,6 @@ #include #include #include "../ir/type_functor.h" -#include "compile_engine.h" namespace tvm { namespace relay { @@ -102,7 +104,7 @@ class ScheduleGetter : public ExprFunctor(const Expr&)> { public: explicit ScheduleGetter(Target target) - : target_(target) {} + : target_(target), device_copy_op_(Op::Get("device_copy")) {} std::pair Create(const Function& prim_func) { static auto fschedule = @@ -250,11 +252,9 @@ class ScheduleGetter : CHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); - // Check if the op is a device copy op. - bool is_copy_op = op.same_as(Op::Get("device_copy")); Array outputs; // Skip fcompute for device copy operators as it is not registered. - if (is_copy_op) { + if (op == device_copy_op_) { const auto* copy_input = inputs[0].operator->(); outputs.push_back(TensorNode::make(copy_input->shape, copy_input->dtype, Operation(), 0)); @@ -282,7 +282,7 @@ class ScheduleGetter : } // Set the name to `__copy`. It will be detected in graph runtime to perform // data copy across devices. - if (is_copy_op) { + if (op == device_copy_op_) { readable_name_stream_.str(std::string()); readable_name_stream_ << "__copy"; } else { @@ -332,6 +332,9 @@ class ScheduleGetter : std::ostringstream readable_name_stream_; std::unordered_map, NodeHash, NodeEqual> memo_; Array scalars_; + // Cache device copy op for equivalence checking to reduce registry lookup + // overhead for each invocation of call node when retrieving schedules. + const Op& device_copy_op_; }; // Creates shape function from functor. diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 1c57e9d..b5fd0c9 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -246,10 +246,12 @@ class Interpreter : public ExprFunctor, PatternFunctor { public: - Interpreter(Module mod, - DLContext context, - Target target) - : mod_(mod), context_(context), target_(target) { + Interpreter(Module mod, DLContext context, Target target) + : mod_(mod), + context_(context), + target_(target), + debug_op_(Op::Get("debug")), + shape_of_op_(Op::Get("shape_of")) { engine_ = CompileEngine::Global(); } @@ -263,7 +265,7 @@ class Interpreter : stack_.current_frame().locals.Set(id, v); } - inline Value Lookup(const Var& local) { + Value Lookup(const Var& local) { return stack_.Lookup(local); } @@ -307,7 +309,7 @@ class Interpreter : return TupleValueNode::make(values); } - inline Value MakeClosure(const Function& func, Var letrec_name = Var()) { + Value MakeClosure(const Function& func, Var letrec_name = Var()) { tvm::Map captured_mod; Array free_vars = FreeVars(func); @@ -454,9 +456,9 @@ class Interpreter : Value InvokePrimitiveOp(const Function& func, const Array& args) { - auto call_node = func->body.as(); + const auto* call_node = func->body.as(); - if (call_node && call_node->op == Op::Get("debug")) { + if (call_node && call_node->op == debug_op_) { auto dattrs = call_node->attrs.as(); auto interp_state = this->get_state(call_node->args[0]); @@ -540,7 +542,7 @@ class Interpreter : Array out_shapes; auto ret_type = func->body->checked_type(); bool is_dyn = IsDynamic(func->checked_type()); - if (call_node->op == Op::Get("shape_of")) { + if (call_node->op == shape_of_op_) { // The output shape of shape_of must be static since Relay doesn't support // dynamic rank tensors. is_dyn = false; @@ -782,6 +784,9 @@ class Interpreter : Stack stack_; // Backend compile engine. CompileEngine engine_; + // Cache ops that need to be frequently used later to reduce lookup overhead. + const Op& debug_op_; + const Op& shape_of_op_; }; diff --git a/src/relay/pass/canonicalize_cast.cc b/src/relay/pass/canonicalize_cast.cc index fa306ea..6913eb2 100644 --- a/src/relay/pass/canonicalize_cast.cc +++ b/src/relay/pass/canonicalize_cast.cc @@ -62,6 +62,8 @@ namespace relay { // \endcode class CastCanonicalizer : public ExprMutator { public: + CastCanonicalizer() : cast_op_(Op::Get("cast")) {} + Expr VisitExpr_(const CallNode* call) { static auto fpattern = Op::GetAttr("TOpPattern"); @@ -91,15 +93,17 @@ class CastCanonicalizer : public ExprMutator { private: std::unordered_map ref_counter_; + // cast op is frequently checked for equivalence. Therefore, we cache it to + // reduce lookup overhead. + const Op& cast_op_; + Expr GetNewCallArg(const Expr& e) { // if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor - - static auto& cast = Op::Get("cast"); Expr new_expr = this->VisitExpr(e); if (const CallNode* call = e.as()) { - if (call->op.same_as(cast)) { + if (call->op == cast_op_) { auto attrs = call->attrs.as(); const auto* from_type = call->args[0]->type_as(); CHECK(from_type); @@ -108,7 +112,7 @@ class CastCanonicalizer : public ExprMutator { if (++ref_counter_[call] > 1) { const CallNode* new_call = new_expr.as(); CHECK(new_call); - CHECK(new_call->op.same_as(cast)); + CHECK(new_call->op == cast_op_); return CallNode::make(new_call->op, new_call->args, new_call->attrs, new_call->type_args); } diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index 9755154..64b702c 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -24,6 +24,7 @@ */ #include #include +#include #include #include #include "pattern_util.h" @@ -33,10 +34,11 @@ namespace relay { class BiasAddSimplifier : public ExprMutator { public: + BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {} + Expr VisitExpr_(const CallNode* n) { - static const Op& bias_add = Op::Get("nn.bias_add"); auto new_n = ExprMutator::VisitExpr_(n); - if (n->op.same_as(bias_add)) { + if (n->op == bias_add_op_) { Call call = Downcast(new_n); CHECK_EQ(call->args.size(), 2); const BiasAddAttrs* param = call->attrs.as(); @@ -54,6 +56,10 @@ class BiasAddSimplifier : public ExprMutator { } return new_n; } + + private: + // Cache the bias_add for equivalence checking. + const Op& bias_add_op_; }; Expr CanonicalizeOps(const Expr& e) { diff --git a/src/relay/pass/combine_parallel_op.cc b/src/relay/pass/combine_parallel_op.cc index 081216c..6b9926c 100644 --- a/src/relay/pass/combine_parallel_op.cc +++ b/src/relay/pass/combine_parallel_op.cc @@ -27,29 +27,30 @@ #include #include #include +#include #include #include +#include +#include #include #include -#include "./expr_subst.h" -#include "./pattern_util.h" -#include "./combine_parallel_op.h" +#include "expr_subst.h" +#include "pattern_util.h" +#include "combine_parallel_op.h" namespace tvm { namespace relay { -BranchGroupFinder::BranchGroupFinder(const std::string& op_name, +BranchGroupFinder::BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops) - : op_name_(op_name), + : cached_op_(op), fis_supported_op_(fis_supported_op), fare_compatible_ops_(fare_compatible_ops) { } std::vector BranchGroupFinder::Find(const Expr& expr) { - const Op& op = Op::Get(op_name_); - this->VisitExpr(expr); std::vector groups; @@ -57,7 +58,7 @@ std::vector BranchGroupFinder::Find(const Expr& expr) { const auto& children = children_map_.at(root); size_t ngroups = groups.size(); for (const CallNode* child : children) { - if (!child->op.same_as(op)) continue; + if (child->op != cached_op_) continue; auto&& branch = CreateBranch(child); // add the branch to a group, or create a new group @@ -97,9 +98,8 @@ Branch BranchGroupFinder::CreateBranch(const CallNode* op) { } void BranchGroupFinder::VisitExpr_(const CallNode* n) { - const Op& op = Op::Get(op_name_); ExprVisitor::VisitExpr_(n); - if (n->op.same_as(op) && fis_supported_op_(n)) { + if (n->op == cached_op_ && fis_supported_op_(n)) { op_roots_.insert(n->args[0]); children_map_[n->args[0]].push_back(n); } else { @@ -110,12 +110,12 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) { } ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches) - : op_name_(op_name), + : cached_op_(Op::Get(op_name)), min_num_branches_(min_num_branches) { } Expr ParallelOpCombiner::Combine(const Expr& expr) { - auto groups = BranchGroupFinder(op_name_, + auto groups = BranchGroupFinder(cached_op_, [&](const CallNode* n) { return IsSupportedOp(n); }, diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h index 9004b04..858926e 100644 --- a/src/relay/pass/combine_parallel_op.h +++ b/src/relay/pass/combine_parallel_op.h @@ -68,13 +68,13 @@ class BranchGroupFinder : private ExprVisitor { public: /* * \brief Constructor - * \param op_name name of op to start each group + * \param op The op that indicates the start of each group * \param fis_supported_op function that returns true if op * is supported for combining * \param fare_compatible_ops function that returns true if * two ops are compatible for combining */ - BranchGroupFinder(const std::string& op_name, + BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops); @@ -87,8 +87,8 @@ class BranchGroupFinder : private ExprVisitor { std::vector Find(const Expr& expr); private: - /* \brief name of op to find parallel branches for */ - std::string op_name_; + /* \brief Cache the op for finding parallel branches */ + const Op& cached_op_; /* \brief function to return true if op is eligible to be combined, * false otherwise @@ -205,8 +205,8 @@ class ParallelOpCombiner { ExprSubstMap* subst_map) = 0; private: - /* \brief name of op to be combined */ - std::string op_name_; + /* \brief Cache the op to be combined */ + const Op& cached_op_; /* \brief minimum number of parallel branches to combine */ uint64_t min_num_branches_; diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index c37fdae..1e22571 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -33,7 +34,6 @@ namespace relay { using FInterpreter = runtime::TypedPackedFunc; - class ConstantChecker : private ExprVisitor { public: // Check whether an expression is constant. The results are memoized. @@ -78,8 +78,14 @@ TVM_REGISTER_API("relay._analysis.check_constant") class ConstantFolder : public ExprMutator { public: explicit ConstantFolder(FInterpreter executor, Module module) - : executor_(executor), module_(module) { - } + : executor_(executor), + module_(module), + shape_of_op_(Op::Get("shape_of")), + invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")), + shape_func_op_(Op::Get("memory.shape_func")), + alloc_tensor_op_(Op::Get("memory.alloc_tensor")), + alloc_storage_op_(Op::Get("memory.alloc_storage")), + cast_op_(Op::Get("cast")) {} Expr VisitExpr_(const LetNode* op) final { Expr value = this->Mutate(op->value); @@ -119,15 +125,15 @@ class ConstantFolder : public ExprMutator { // skip stateful ops. if (op_stateful.get(GetRef(op), false)) return res; // Try to evaluate shape_of op - if (call->op.same_as(Op::Get("shape_of"))) { + if (call->op == shape_of_op_) { return EvaluateShapeOf(res, origin_args, call->attrs); } // We should think about potentially constant evaluation over these ops too. - if (call->op.same_as(Op::Get("memory.invoke_tvm_op")) || - call->op.same_as(Op::Get("memory.shape_func")) || - call->op.same_as(Op::Get("memory.alloc_tensor")) || - call->op.same_as(Op::Get("memory.alloc_storage"))) { + if (call->op == invoke_tvm_op_ || + call->op == shape_func_op_ || + call->op == alloc_tensor_op_ || + call->op == alloc_storage_op_) { return GetRef(call); } @@ -162,6 +168,14 @@ class ConstantFolder : public ExprMutator { // Module Module module_; + // Cache the following ops for equivalence checking in this pass. + const Op& shape_of_op_; + const Op& invoke_tvm_op_; + const Op& shape_func_op_; + const Op& alloc_tensor_op_; + const Op& alloc_storage_op_; + const Op& cast_op_; + // Convert value to expression. Expr ValueToExpr(Value value) { if (const auto* val = value.as()) { @@ -254,8 +268,7 @@ class ConstantFolder : public ExprMutator { // Cast the constant into correct dtype auto cast_attrs = make_node(); cast_attrs->dtype = param->dtype; - static const Op& cast_op = Op::Get("cast"); - Expr ret = CallNode::make(cast_op, { shape }, Attrs(cast_attrs), {}); + Expr ret = CallNode::make(cast_op_, { shape }, Attrs(cast_attrs), {}); return ConstEvaluate(ret); } }; diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index ba7feed..8209a80 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -78,6 +78,8 @@ using common::LinkedList; constexpr uint32_t kMaxFusedOps = 256; +static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion"); + /*! * \brief Indexed data flow graph in forward direction. * This is a temporary data structure used for operator fusion analysis. @@ -860,7 +862,6 @@ class FuseMutator : private ExprMutator { // Transform calls. Expr VisitExpr_(const CallNode* call) { - static const Op& stop_fusion = Op::Get("annotation.stop_fusion"); if (call->op.as()) { static auto fnoncomputational = Op::GetAttr("TNonComputational"); @@ -872,7 +873,7 @@ class FuseMutator : private ExprMutator { // If it is a primitive op call // then we must have a group assignment for it already. CHECK(gmap_.count(call)); - if (call->op.same_as(stop_fusion)) { + if (call->op == stop_fusion_op) { return ExprMutator::VisitExpr(call->args[0]); } auto* ret_group = gmap_.at(call)->FindRoot(); diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index a44568a..afcc493 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -559,30 +559,28 @@ struct WithFuncIdAttrs : public tvm::AttrsNode { TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs); -Op WithFuncIdOp() { - static const Op& op = Op::Get("annotation.with_funcid"); - return op; -} - -Expr MkWithFuncId(const Expr& expr, FuncId fid) { - auto attrs = make_node(); - attrs->fid = fid; - return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {}); -} - RELAY_REGISTER_OP("annotation.with_funcid") .describe(R"code(Annotate a function with a funcid.)code" TVM_ADD_FILELINE) .set_num_inputs(1) .add_argument("func", "Function", "The input data."); +// Cache with_funcid op to reduce lookup overhead during traversal. +static const Op& with_funcid_op = Op::Get("annotation.with_funcid"); + +Expr MkWithFuncId(const Expr& expr, FuncId fid) { + auto attrs = make_node(); + attrs->fid = fid; + return CallNode::make(with_funcid_op, {expr}, Attrs(attrs), {}); +} + Expr StripWithFuncId(const Expr& e); Function AsFunc(const Expr& e) { if (e.as()) { return Downcast(e); } else if (const CallNode* c = e.as()) { - CHECK(c->op.same_as(WithFuncIdOp())); + CHECK(c->op == with_funcid_op); CHECK_EQ(c->args.size(), 1); return AsFunc(c->args[0]); } else { @@ -604,7 +602,7 @@ class PartialEvaluator : public ExprFunctor PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) { if (const CallNode* c = e.as()) { - if (c->op.same_as(WithFuncIdOp())) { + if (c->op == with_funcid_op) { CHECK_EQ(c->args.size(), 1); return VisitExpr(c->args[0], ll, name); } @@ -722,7 +720,7 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const CallNode* op, LetList* ll) final { - if (op->op.same_as(WithFuncIdOp())) { + if (op->op == with_funcid_op) { CHECK_EQ(op->args.size(), 1); return VisitExpr(op->args[0], ll); } @@ -1096,7 +1094,7 @@ class PartialEvaluator : public ExprFunctor explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { } void VisitExpr_(const CallNode* op) final { - if (op->op.same_as(WithFuncIdOp())) { + if (op->op == with_funcid_op) { CHECK_EQ(op->args.size(), 1); CHECK(op->attrs.defined()); CHECK(op->attrs.as()); @@ -1194,7 +1192,7 @@ Expr Remap(const Expr& e) { Expr StripWithFuncId(const Expr& e) { struct StripWithFuncIdMutator : ExprMutator, PatternMutator { Expr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(WithFuncIdOp())) { + if (op->op == with_funcid_op) { CHECK_EQ(op->args.size(), 1); return VisitExpr(op->args[0]); } else { diff --git a/src/relay/pass/quantize/calibrate.cc b/src/relay/pass/quantize/calibrate.cc index 3ecf449..e78abbf 100644 --- a/src/relay/pass/quantize/calibrate.cc +++ b/src/relay/pass/quantize/calibrate.cc @@ -25,15 +25,17 @@ */ #include #include +#include #include "./quantize.h" - namespace tvm { namespace relay { namespace quantize { class StatsCollector : private ExprMutator { public: + StatsCollector() : simulated_quantize_op_(Op::Get("relay.op.annotation.simulated_quantize")) {} + Expr Collect(const Expr& expr) { auto new_e = this->Mutate(expr); const FunctionNode* func = new_e.as(); @@ -45,13 +47,13 @@ class StatsCollector : private ExprMutator { private: Array profile_data_; + const Op& simulated_quantize_op_; Expr VisitExpr_(const CallNode* call) { - static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); Expr new_e = ExprMutator::VisitExpr_(call); const CallNode* new_call = new_e.as(); CHECK(new_call); - if (new_call->op.same_as(simulated_quantize)) { + if (new_call->op == simulated_quantize_op_) { auto attrs = new_call->attrs.as(); // rewrite the annotation auto new_attrs = make_node(); diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index e28120d..acd5163 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -91,7 +91,6 @@ Expr LayerNormToInferUnpack(const Attrs attrs, return out; } - Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, @@ -125,23 +124,25 @@ Expr InstanceNormToInferUnpack(const Attrs attrs, return out; } - class InferenceSimplifier : public ExprMutator { public: - Expr VisitExpr_(const TupleGetItemNode* n) final { - static const Op& batch_norm = Op::Get("nn.batch_norm"); - static const Op& dropout = Op::Get("nn.dropout"); + InferenceSimplifier() + : batch_norm_op_(Op::Get("nn.batch_norm")), + dropout_op_(Op::Get("nn.dropout")), + instance_norm_op_(Op::Get("nn.instance_norm")), + layer_norm_op_(Op::Get("nn.layer_norm")) {} + Expr VisitExpr_(const TupleGetItemNode* n) final { Expr new_e = ExprMutator::VisitExpr_(n); const auto* new_n = new_e.as(); if (new_n->index != 0) { return new_e; } if (const auto* call = new_n->tuple.as()) { - if (call->op.same_as(batch_norm)) { + if (call->op == batch_norm_op_) { return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], call->args[3], call->args[4], ty_map_.at(call->args[0])); - } else if (call->op.same_as(dropout)) { + } else if (call->op == dropout_op_) { return call->args[0]; } } @@ -149,17 +150,14 @@ class InferenceSimplifier : public ExprMutator { } Expr VisitExpr_(const CallNode* n) { - static const Op& batch_norm = Op::Get("nn.batch_norm"); - static const Op& instance_norm = Op::Get("nn.instance_norm"); - static const Op& layer_norm = Op::Get("nn.layer_norm"); auto new_n = ExprMutator::VisitExpr_(n); - if (n->op.same_as(batch_norm)) { + if (n->op == batch_norm_op_) { ty_map_[new_n.as()->args[0]] = n->args[0]->checked_type(); - } else if (n->op.same_as(layer_norm)) { + } else if (n->op == layer_norm_op_) { const auto* call = new_n.as(); return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], n->args[0]->checked_type()); - } else if (n->op.same_as(instance_norm)) { + } else if (n->op == instance_norm_op_) { const auto* call = new_n.as(); return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], n->args[0]->checked_type()); @@ -168,6 +166,13 @@ class InferenceSimplifier : public ExprMutator { } private: + // Cache the following ops. They will be used in the passes repeatedly for + // operator equivalence checking so that the registry lookup overhead can be + // reduced. + const Op& batch_norm_op_; + const Op& dropout_op_; + const Op& instance_norm_op_; + const Op& layer_norm_op_; std::unordered_map ty_map_; }; diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 334c98b..17c527b 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -25,6 +25,7 @@ */ #include #include +#include #include #include "pass_util.h" #include "../ir/type_functor.h" @@ -360,13 +361,14 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) { return true; } +// Cache the operators that are checked recursively to reduce lookup overhead. +static const auto& expand_dims_op = Op::Get("expand_dims"); +static const auto& reshape_op = Op::Get("reshape"); +static const auto& transpose_op = Op::Get("transpose"); +static const auto& squeeze_op = Op::Get("squeeze"); + bool IsAllPositiveConstant(const Expr& expr) { // peel through a few common transform ops. - static const auto& expand_dims = Op::Get("expand_dims"); - static const auto& reshape = Op::Get("reshape"); - static const auto& transpose = Op::Get("transpose"); - static const auto& squeeze = Op::Get("squeeze"); - if (const auto* constant = expr.as()) { const auto& tensor = constant->data; const auto& dtype = tensor->dtype; @@ -389,10 +391,10 @@ bool IsAllPositiveConstant(const Expr& expr) { } } else if (const auto* op = expr.as()) { // tail recursion. - if (op->op.same_as(expand_dims) || - op->op.same_as(reshape) || - op->op.same_as(transpose) || - op->op.same_as(squeeze)) { + if (op->op == expand_dims_op || + op->op == reshape_op || + op->op == transpose_op || + op->op == squeeze_op) { return IsAllPositiveConstant(op->args[0]); } else { return false;