return true;
}
-// Cache the operators that are checked recursively to reduce lookup overhead.
-static const auto& expand_dims_op = Op::Get("expand_dims");
-static const auto& reshape_op = Op::Get("reshape");
-static const auto& transpose_op = Op::Get("transpose");
-static const auto& squeeze_op = Op::Get("squeeze");
-
bool IsAllPositiveConstant(const Expr& expr) {
+ // Cache the operators that are checked recursively to reduce lookup overhead.
+ static const auto& expand_dims_op = Op::Get("expand_dims");
+ static const auto& reshape_op = Op::Get("reshape");
+ static const auto& transpose_op = Op::Get("transpose");
+ static const auto& squeeze_op = Op::Get("squeeze");
+
// peel through a few common transform ops.
if (const auto* constant = expr.as<ConstantNode>()) {
const auto& tensor = constant->data;
#include <tvm/relay/transform.h>
#include <tvm/runtime/container.h>
+#include "pass_util.h"
+
namespace tvm {
namespace relay {
namespace annotate_target {
-static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
-static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
-
const PackedFunc* make_begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
std::string arg_target = "default";
const CallNode* call = arg.as<CallNode>();
- if (call && call->op == compiler_begin_op) {
+ if (call && call->op == CompilerBeginOp()) {
// Argument is already compiler begin node meaning that this is not the first time
// running this pass, so we simply remove it and will add a new one later.
CHECK_EQ(call->args.size(), 1U);
const CallNode* end = call->args[0].as<CallNode>();
- if (end->op == compiler_end_op) {
+ if (end->op == CompilerEndOp()) {
arg_target = end->attrs.as<CompilerAttrs>()->compiler;
}
compiler_ends.push_back(call->args[0]);
auto op_node = pre->op.as<OpNode>();
// This graph has annotations, meaning that this is not the first time running this pass.
- if (op_node && pre->op == compiler_begin_op) {
+ if (op_node && pre->op == CompilerBeginOp()) {
// Bypass compiler begin due to lack of target information. It will be processed
// when the following op handling arguments.
CHECK_EQ(pre->args.size(), 1U);
return post.as<CallNode>()->args[0];
- } else if (op_node && pre->op == compiler_end_op) {
+ } else if (op_node && pre->op == CompilerEndOp()) {
// Override compiler end with the new target.
CHECK_EQ(pre->args.size(), 1U);
auto input_expr = post.as<CallNode>()->args[0];
// Peek the first argument. If it is compiler begin then this node had annotated by
// another target before, so we also consider that target as a supported target.
const CallNode* first_arg_call = pre->args[0].as<CallNode>();
- if (first_arg_call && first_arg_call->op == compiler_begin_op) {
+ if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
supported_targets.push_back(arg_target);
#include <vector>
#include "../analysis/annotated_region_set.h"
+#include "pass_util.h"
namespace tvm {
namespace relay {
namespace merge_compiler_region {
-// Cache compiler_begin and compiler_end annotation ops for equivalence check to
-// reduce registry lookup overhead.
-static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
-static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
-
class RegionMerger : public MixedModeVisitor {
public:
explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
void VisitExpr_(const CallNode* call) final {
- if (call->op == compiler_end_op) {
+ if (call->op == CompilerEndOp()) {
auto region = regions_->GetRegion(GetRef<Call>(call));
// Skip this region if it has been merged to the other region.
// Region inputs must be begin annotation, and the region of
// the begin annotation's argument is the parent region.
auto begin = Downcast<Call>(arg);
- CHECK_EQ(begin->op, compiler_begin_op);
+ CHECK_EQ(begin->op, CompilerBeginOp());
auto parent_region = regions_->GetRegion(begin->args[0]);
// Skip this region if it has been merged.
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
for (const auto& arg : region->GetInputs()) {
auto begin = Downcast<Call>(arg);
- CHECK_EQ(begin->op, compiler_begin_op);
+ CHECK_EQ(begin->op, CompilerBeginOp());
auto parent_region = regions_->GetRegion(begin->args[0]);
if (parent_region.defined()) {
mergeable_regions.insert(parent_region);
// Merge annotations which are now internal to a region.
// This happens if we see a compiler begin next to a
// compiler end and they're both in the same region.
- if (call->op == compiler_begin_op && call->args[0]->IsInstance<CallNode>()) {
+ if (call->op == CompilerBeginOp() && call->args[0]->IsInstance<CallNode>()) {
auto arg = Downcast<Call>(call->args[0]);
- if (arg->op == compiler_end_op) {
+ if (arg->op == CompilerEndOp()) {
auto region1 = regions_->GetRegion(GetRef<Call>(call));
auto region2 = regions_->GetRegion(arg);
if (region1 == region2) {
Expr MergeCompilerRegions(const Expr& expr) {
// Create regions using the annotations.
- AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
+ AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, CompilerBeginOp(), CompilerEndOp());
// Analyze the graph to explore the opportunities of merging regions.
RegionMerger merger(regions);
#include "../analysis/annotated_region_set.h"
#include "../backend/utils.h"
+#include "pass_util.h"
namespace tvm {
namespace relay {
namespace partitioning {
-// Cache compiler_begin and compiler_end annotation ops for equivalence check to
-// reduce registry lookup overhead.
-static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
-static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
-
/*! \brief This struct maintains the required metadata for a region to generate a corresponding
* global function and function call. Global function will be passed to the target specific codegen
* and function call will be used in the transform Relay graph to invoke the function in runtime.
BaseFunc f_func = f.second;
// Creating regionset per function in the module.
- auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
- partitioning::compiler_end_op);
+ auto region_set = AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp());
regions_sets_[region_set] = f_func;
}
}
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return post;
- } else if (call->op == compiler_begin_op) {
+ } else if (call->op == CompilerBeginOp()) {
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
// Backtrace the parent to find the first ancestor node that is not a begin or end op
while (const auto* parent_call = parent.as<CallNode>()) {
- if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) {
+ if (parent_call->op == CompilerBeginOp() || parent_call->op == CompilerEndOp()) {
parent = parent_call->args[0];
} else {
break;
return std::move(var);
}
} else {
- CHECK_EQ(call->op, compiler_end_op);
+ CHECK_EQ(call->op, CompilerEndOp());
// The annotation node is inserted on edge so it must have only one
// argument.
CHECK_EQ(call->args.size(), 1U);
TupleOutFlattener() = default;
Expr Rewrite_(const CallNode* call, const Expr& post) final {
- if (call->op == compiler_end_op) {
+ if (call->op == CompilerEndOp()) {
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
// Arguments of annotation ops should be 1
CHECK_EQ(call->args.size(), 1U);
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
}
+/*!
+ * \brief Get the compiler_begin annotation op.
+ *
+ * The op is fetched from the registry once and cached in a
+ * function-local static, so repeated calls avoid the registry
+ * lookup overhead.
+ *
+ * \return The cached compiler_begin op.
+ */
+inline const Op& CompilerBeginOp() {
+  static Op op = Op::Get("annotation.compiler_begin");
+  return op;
+}
+
+/*!
+ * \brief Get the compiler_end annotation op.
+ *
+ * The op is fetched from the registry once and cached in a
+ * function-local static, so repeated calls avoid the registry
+ * lookup overhead.
+ *
+ * \return The cached compiler_end op.
+ */
+inline const Op& CompilerEndOp() {
+  static Op op = Op::Get("annotation.compiler_end");
+  return op;
+}
+
template <typename ConditionObjectPtr>
struct TreeNode {
typedef std::shared_ptr<TreeNode<ConditionObjectPtr>> pointer;
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
+#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>