return _func_wrapper
-_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu")
_register_external_op_helper("add")
_register_external_op_helper("subtract")
_register_external_op_helper("multiply")
+
+
+@reg.register("nn.batch_norm", "target.dnnl")
+def batch_norm(attrs, args):
+ """Check if the external DNNL codegen should be used for nn.batch_norm.
+
+ Both ``attrs`` and ``args`` are ignored: the check always returns False, i.e. the
+ external DNNL codegen is not used for this op.
+ FIXME(@zhiics, @comaniac): Turned off because multiple outputs are not supported yet.
+ """
+ return False
-def AnnotateTarget(target):
+def AnnotateTarget(targets):
"""Annotate ops in an expression with the provided compilers/targets and then
use them for codegen.
Parameters
----------
- target : String
- The target compiler used for codegen.
+ targets : str or List[str]
+ The list of target compilers used for codegen. A single str is treated as a
+ one-element list.
Returns
-------
The annotated pass that wraps ops with subgraph_start and
subgraph_end.
"""
- return _ffi_api.AnnotateTarget(target)
+ if isinstance(targets, str):
+ targets = [targets]
+ return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])
def Inline():
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
+#include <tvm/runtime/container.h>
#include <unordered_map>
#include <vector>
AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
for (auto candidate : regions_) {
- if (candidate->nodes.find(expr) != candidate->nodes.end()) {
+ if (candidate->nodes_.find(expr) != candidate->nodes_.end()) {
return candidate;
}
}
}
// Merge src to dest and erase src.
- dest->nodes.insert(src->nodes.begin(), src->nodes.end());
- for (const auto& input : src->ins) {
- dest->ins.push_back(input);
+ dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end());
+ for (const auto& input : src->ins_) {
+ dest->ins_.push_back(input);
}
- for (const auto& output : src->outs) {
- dest->outs.push_back(output);
+ for (const auto& output : src->outs_) {
+ dest->outs_.push_back(output);
}
// if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs
std::vector<Expr> ins_to_remove;
- for (const auto& input : dest->ins) {
+ for (const auto& input : dest->ins_) {
auto call = Downcast<Call>(input);
- auto it = src->nodes.find(call->args[0]);
- if (it != src->nodes.end()) {
- dest->outs.remove(*it);
+ auto it = src->nodes_.find(call->args[0]);
+ if (it != src->nodes_.end()) {
+ dest->outs_.remove(*it);
ins_to_remove.push_back(input);
}
}
for (const auto& input : ins_to_remove) {
- dest->ins.remove(input);
+ dest->ins_.remove(input);
}
regions_.erase(src);
}
if (src.defined()) {
MergeRegions(src, dest);
} else {
- dest->nodes.insert(expr);
+ dest->nodes_.insert(expr);
}
}
-AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() {
+AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
auto ret = regions_.emplace(AnnotatedRegion());  // Add a new region object to regions_.
- (*ret.first)->id = region_id_++;
+ (*ret.first)->id_ = region_id_++;  // Take the next region id and advance the counter.
+ (*ret.first)->target_ = target;  // Record the target this region is created for.
return *ret.first;
}
class AnnotatedRegionSet::Creator : public ExprVisitor {
public:
- Creator(const Op& region_begin_op, const Op& region_end_op) :
- begin_op_(region_begin_op), end_op_(region_end_op) {}
-
- AnnotatedRegionSet Create(const Expr& expr) {
- VisitExpr(expr);
- return std::move(region_set_);
- }
+ Creator(const Op& region_begin_op, const Op& region_end_op)
+ : begin_op_(region_begin_op), end_op_(region_end_op) {}
void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>();
<< "Cannot find the corresponding region for start annotation:\n"
<< AsText(GetRef<Call>(call), false));
}
- region->ins.push_back(GetRef<Call>(call));
+ region->ins_.push_back(GetRef<Call>(call));
} else {
CHECK_EQ(call->op, end_op_);
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
+ std::string target = call->attrs.as<CompilerAttrs>()->compiler;
// Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) {
- region = region_set_->MakeRegion();
- region->nodes.insert(call->args[0]);
+ // Create a new region if the argument does not belong to any region yet.
+ region = region_set_->MakeRegion(target);
+ region->nodes_.insert(call->args[0]);
+ } else {
+ // If the argument already belongs to a region, it must have the same target.
+ // Otherwise we should have seen a region_begin op.
+ CHECK_EQ(region->GetTarget(), target);
}
- region->nodes.insert(GetRef<Call>(call));
- region->outs.push_back(GetRef<Call>(call));
+ region->nodes_.insert(GetRef<Call>(call));
+ region->outs_.push_back(GetRef<Call>(call));
}
ExprVisitor::VisitExpr_(call);
}
+ AnnotatedRegionSet Create(const Expr& expr) {
+ VisitExpr(expr);
+ return std::move(region_set_);
+ }
+
void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op));
if (region.defined()) {
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
+#include <tvm/runtime/container.h>
#include <tvm/relay/transform.h>
#include <string>
class AnnotatedRegionNode : public Object {
public:
void VisitAttrs(AttrVisitor* v) {
- v->Visit("id", &id);
- Array<Expr> nodes_array(nodes.begin(), nodes.end());
+ v->Visit("id", &id_);
+ v->Visit("target", &target_);
+ Array<Expr> nodes_array(nodes_.begin(), nodes_.end());
v->Visit("nodes", &nodes_array);
- Array<Expr> args_array(ins.begin(), ins.end());
+ Array<Expr> args_array(ins_.begin(), ins_.end());
v->Visit("args", &args_array);
- Array<Expr> rets_array(outs.begin(), outs.end());
+ Array<Expr> rets_array(outs_.begin(), outs_.end());
v->Visit("rets", &rets_array);
}
/*! \brief Get the region ID. */
int GetID() const {
- return id;
+ return id_;
+ }
+
+ /*! \brief Get the region target. */
+ std::string GetTarget() const {
+ return target_;
}
/*! \brief Get the region's inputs. */
std::list<Expr> GetInputs() const {
- return ins;
+ return ins_;
}
/*! \brief Get the region's outputs. */
std::list<Expr> GetOutputs() const {
- return outs;
+ return outs_;
}
/*! \brief Get the region's nodes. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
- return nodes;
+ return nodes_;
}
static constexpr const char* _type_key = "relay.AnnotatedRegion";
protected:
/*! \brief The region ID. */
- int id{-1};
+ int id_{-1};
+ /*! \brief The target for this region. */
+ std::string target_ = "default";
/*! \brief The inputs to this region. */
- std::list<Expr> ins;
+ std::list<Expr> ins_;
/*! \brief The outputs of this region */
- std::list<Expr> outs;
+ std::list<Expr> outs_;
/*! \brief Nodes in this region. */
- std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
+ std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes_;
friend class AnnotatedRegionSet;
friend class AnnotatedRegionSetNode;
void AddToRegion(AnnotatedRegion dest, const Expr& expr);
/*!
- * \brief Make a new region.
+ * \brief Make a new region for a target.
*
+ * \param target The target to assign to the new region.
* \return The new region.
*/
- AnnotatedRegion MakeRegion();
+ AnnotatedRegion MakeRegion(const std::string& target);
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_;
/*! \brief The next region ID to assign. */
}
void VisitExpr_(const TupleGetItemNode* op) final {
- VisitExpr(op->tuple);
- CHECK(out_.size() > static_cast<size_t>(op->index));
-
- // Only keep the item we want for the child node.
- // FIXME(@comaniac): The other items should still be requried for the primary outputs.
- auto item = out_[op->index];
- out_.clear();
- out_.push_back(item);
+ // Do nothing
}
void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream;
-
+ std::ostringstream buf_stream;
// Args: ID
std::vector<std::string> args;
}
}
- // Analyze the output buffers
- std::vector<Type> out_types;
- if (call->checked_type()->IsInstance<TupleTypeNode>()) {
- auto type_node = call->checked_type().as<TupleTypeNode>();
- for (auto field : type_node->fields) {
- CHECK(field->IsInstance<TensorTypeNode>());
- out_types.push_back(field);
- }
- } else if (call->checked_type()->IsInstance<TensorTypeNode>()) {
- CHECK(call->checked_type()->IsInstance<TensorTypeNode>());
- out_types.push_back(call->checked_type());
- } else {
- LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false);
- }
-
- out_.clear();
- for (auto out_type : out_types) {
- const auto& dtype = GetDtypeString(out_type.as<TensorTypeNode>());
-
- std::string out = "buf_" + std::to_string(buf_idx_++);
- auto out_shape = GetShape(out_type);
- int out_size = 1;
- for (size_t i = 0; i < out_shape.size(); ++i) {
- out_size *= out_shape[i];
- }
- this->PrintIndents();
- std::ostringstream buf_stream;
- buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
- buf_decl_.push_back(buf_stream.str());
- decl_stream << ", " << out;
-
- // Update output buffer
- Output output;
- output.name = out;
- output.dtype = dtype;
- output.need_copy = true;
- output.size = out_size;
- out_.push_back(output);
+ // Analyze the output buffer
+ auto type_node = call->checked_type().as<TensorTypeNode>();
+ CHECK(type_node);
+ const auto& dtype = GetDtypeString(type_node);
+ std::string out = "buf_" + std::to_string(buf_idx_++);
+ auto out_shape = GetShape(call->checked_type());
+ int out_size = 1;
+ for (size_t i = 0; i < out_shape.size(); ++i) {
+ out_size *= out_shape[i];
}
+ this->PrintIndents();
+ buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
+ buf_decl_.push_back(buf_stream.str());
+ decl_stream << ", " << out;
// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
}
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());
+
+ // Update output buffer
+ out_.clear();
+ Output output;
+ output.name = out;
+ output.dtype = dtype;
+ output.need_copy = true;
+ output.size = out_size;
+ out_.push_back(output);
}
std::string JIT(void) {
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());
- // Manifest the allocations.
- pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
- // Compute away possibly introduced constant computation.
- pass_seqs.push_back(transform::FoldConstant());
- // Fuse the shape functions.
- pass_seqs.push_back(transform::FuseOps());
-
// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
// external codegen.
pass_seqs.push_back(transform::Inline());
+ // Manifest the allocations.
+ pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
+ // Compute away possibly introduced constant computation.
+ pass_seqs.push_back(transform::FoldConstant());
+ // Fuse the shape functions.
+ pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
/*!
* \file src/relay/transforms/annotate_target.cc
- * \brief Wraps a call with compiler_begin and compiler_end to indicate that
- * the op of this call node will use external compiler.
+ * \brief Wraps an expr with compiler_begin and compiler_end to indicate that
+ * this expr should be handled by the external compiler.
*/
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
+#include <tvm/runtime/container.h>
namespace tvm {
namespace relay {
namespace annotate_target {
-// Cache compiler_begin op for equivalence check.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
+
+// Makers of the annotation ops, looked up once from the global function registry.
+// They are only used within this file, so give them internal linkage like the cached ops
+// above instead of exporting mutable pointers from the namespace.
+static const PackedFunc* make_begin_op =
+ runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+static const PackedFunc* make_end_op =
+ runtime::Registry::Get("relay.op.annotation._make.compiler_end");
// A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler.
class AnnotateTargetWrapper : public ExprMutator {
public:
- explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}
-
- Expr Annotate(const Expr& expr) {
- return InsertEnd(Mutate(expr));
- }
-
- bool IsSupported(const Expr& expr) {
- if (expr->IsInstance<CallNode>()) {
- Call call = Downcast<Call>(expr);
- auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
- if (call->op->IsInstance<OpNode>()) {
- Op op = Downcast<Op>(call->op);
- CHECK(op.defined());
- if (fannotate.count(op)) {
- return fannotate[op](call->attrs, call->args);
- }
- } else if (call->op->IsInstance<FunctionNode>()) {
- // handle composite functions
- Function func = Downcast<Function>(call->op);
- CHECK(func.defined());
- auto comp_name = func->GetAttr<String>(attr::kComposite);
- if (comp_name.defined()) {
- std::string comp_name_str = comp_name;
- size_t i = comp_name_str.find('.');
- if (i != std::string::npos) {
- std::string target = comp_name_str.substr(0, i);
- if (target == target_) return true;
- }
+ explicit AnnotateTargetWrapper(Array<runtime::String> targets) : targets_(std::move(targets)) {}
+
+ /*!
+ * \brief Annotate every argument with a compiler end followed by a compiler begin.
+ *
+ * The compiler end uses the target of the argument, while the compiler begin uses the given
+ * target. If target is not given (empty) and all arguments go to the same target, that target
+ * is used; otherwise "default" is used for this op. Arguments without a recorded target
+ * (e.g., input vars) count as "default". Note that all arg exprs must be available in
+ * op_expr_to_target_ before calling this function.
+ *
+ * \param args An array of arguments of the given node.
+ * \param target The target of the current node, or an empty string to infer it from args.
+ * \return A pair of target and annotated argument expressions.
+ */
+ std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
+ const std::string& target = "") {
+ std::string ref_target = "";
+ Array<Expr> compiler_ends;
+ for (auto arg : args) {
+ // Arguments without a known target (e.g., input vars) fall back to "default".
+ std::string arg_target = "default";
+ const CallNode* call = arg.as<CallNode>();
+ auto known_target = op_expr_to_target_.find(arg);
+
+ if (call && call->op == compiler_begin_op) {
+ // Argument is already a compiler begin node, meaning that this is not the first time
+ // running this pass, so we simply remove it and will add a new one later.
+ CHECK_EQ(call->args.size(), 1U);
+ // The wrapped expr is not necessarily a call (e.g., a var), so guard the cast result.
+ const CallNode* end = call->args[0].as<CallNode>();
+ if (end && end->op == compiler_end_op) {
+ arg_target = end->attrs.as<CompilerAttrs>()->compiler;
}
+ compiler_ends.push_back(call->args[0]);
+ } else if (known_target != op_expr_to_target_.end()) {
+ arg_target = known_target->second;
+ compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
+ } else {
+ // Input vars.
+ compiler_ends.push_back(arg);
}
- }
- if (expr->IsInstance<TupleGetItemNode>()) {
- TupleGetItem get = Downcast<TupleGetItem>(expr);
- if (get->tuple->IsInstance<CallNode>() &&
- get->tuple.as<CallNode>()->op == compiler_begin_op) {
- return true;
+
+ // Maintain reference target in case the target of the current node is unassigned.
+ if (ref_target == "") {
+ ref_target = arg_target;
+ } else if (ref_target != arg_target) {
+ ref_target = "default";
}
}
- return false;
- }
- Expr InsertEnd(const Expr& arg) {
- if (IsSupported(arg)) {
- const auto *end_op =
- runtime::Registry::Get("relay.op.annotation._make.compiler_end");
- CHECK(end_op);
- Expr end = (*end_op)(arg, target_);
- return end;
+ // Determine compiler begin target.
+ std::string op_target = (target == "") ? ref_target : target;
+
+ Array<Expr> compiler_begins;
+ for (const auto& end : compiler_ends) {
+ compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
}
- return arg;
+
+ return {op_target, compiler_begins};
}
- Expr VisitExpr_(const CallNode* cn) {
- auto new_e = ExprMutator::VisitExpr_(cn);
+ /*!
+ * \brief Wrap an expression with an annotation op for the given target.
+ *
+ * \param expr The expression to be wrapped.
+ * \param target The target passed to the annotation op maker.
+ * \param ann_op The annotation op maker (make_begin_op or make_end_op).
+ * \return The annotation call, carrying the checked type of expr.
+ */
+ Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) {
+ // The maker comes from a registry lookup, so fail loudly instead of dereferencing null.
+ CHECK(ann_op) << "The annotation op maker is not registered";
+ Expr new_op = (*ann_op)(expr, target);
+ // Carry the checked type of the wrapped expression over to the annotation node.
+ new_op->checked_type_ = expr->checked_type_;
+ return new_op;
+ }
- Call call = Downcast<Call>(new_e);
+ Expr VisitExpr_(const CallNode* cn) final {
+ // Supported targets for this node. The order implies the priority.
+ std::vector<std::string> supported_targets;
+
+ auto op_node = cn->op.as<OpNode>();
+
+ // This graph has annotations, meaning that this is not the first time running this pass.
+ if (op_node && cn->op == compiler_begin_op) {
+ // Bypass compiler begin due to lack of target information. It will be processed
+ // when the following op handles its arguments.
+ CHECK_EQ(cn->args.size(), 1U);
+ return VisitExpr(cn->args[0]);
+ } else if (op_node && cn->op == compiler_end_op) {
+ // Override compiler end with the new target.
+ CHECK_EQ(cn->args.size(), 1U);
+ auto input_expr = VisitExpr(cn->args[0]);
+ CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
+ return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
+ }
- // add end annotations if the args are supported
- Array<Expr> compiler_ends;
- for (const auto& it : call->args) {
- compiler_ends.push_back(InsertEnd(it));
+ // Peek the first argument. If it is compiler begin then this node was annotated by
+ // another target before, so we also consider that target as a supported target.
+ // A call may have no arguments at all, in which case there is nothing to peek.
+ const CallNode* first_arg_call =
+ cn->args.size() > 0 ? cn->args[0].as<CallNode>() : nullptr;
+ if (first_arg_call && first_arg_call->op == compiler_begin_op) {
+ std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
+ if (arg_target != "default") {
+ supported_targets.push_back(arg_target);
+ }
}
- call = Call(call->op, compiler_ends, call->attrs);
-
- // add begin annotations if the call node is supported
- if (IsSupported(call)) {
- tvm::Array<tvm::relay::Expr> compiler_begins;
- const auto* begin_op =
- runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
- for (const auto& it : call->args) {
- CHECK(begin_op);
- Expr begin = (*begin_op)(it, target_);
- compiler_begins.push_back(begin);
+
+ // Check which targets this op can be offloaded to.
+ if (op_node) {
+ // TVM operators: Check target specific op checking function and add to supported_targets
+ // if it is supported.
+ Op op = Downcast<Op>(cn->op);
+ CHECK(op.defined());
+ for (const auto& target : this->targets_) {
+ // Build the attribute key once per target.
+ const std::string attr_name = "target." + std::string(target);
+ if (!Op::HasAttr(attr_name)) {
+ continue;
+ }
+ auto fannotate = Op::GetAttr<FTVMAnnotateTarget>(attr_name);
+ if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) {
+ supported_targets.push_back(target);
+ }
+ }
+ } else if (cn->op->IsInstance<FunctionNode>()) {
+ // Composite function: Add the target of a composite function to supported_targets
+ // if it is in the target list.
+ Function func = Downcast<Function>(cn->op);
+ CHECK(func.defined());
+ auto comp_name = func->GetAttr<String>(attr::kComposite);
+ if (comp_name.defined()) {
+ std::string comp_name_str = comp_name;
+ size_t i = comp_name_str.find('.');
+ if (i != std::string::npos) {
+ std::string comp_target = comp_name_str.substr(0, i);
+ for (const auto& target : this->targets_) {
+ if (std::string(target) == comp_target) {
+ supported_targets.push_back(comp_target);
+ break;
+ }
+ }
+ }
}
- call = Call(call->op, compiler_begins, call->attrs);
}
+ supported_targets.push_back("default"); // Make default as the last option.
+
+ // TODO(@comaniac, @zhiics): Now we simply assign this node to the target with
+ // the highest priority, but we should preserve all supported targets so that
+ // we can make a better decision.
+ std::string target = supported_targets[0];
+
+ // Visit and mutate arguments after the target of this op has been determined.
+ auto new_call = Downcast<Call>(ExprMutator::VisitExpr_(cn));
+
+ // Add annotations to each arg.
+ auto target_n_args = AnnotateArgs(new_call->args, target);
+ Array<Expr> compiler_begins = std::get<1>(target_n_args);
+ Call call = Call(new_call->op, compiler_begins, new_call->attrs);
+ call->checked_type_ = cn->checked_type_;
+
+ // Update the target map.
+ op_expr_to_target_[call] = target;
return std::move(call);
}
- Expr VisitExpr_(const TupleNode* op) {
+ Expr VisitExpr_(const TupleNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
+ auto expr = Downcast<Tuple>(new_e);
- auto tup = Downcast<Tuple>(new_e);
- Array<Expr> new_fields;
- for (auto field : tup->fields) {
- new_fields.push_back(InsertEnd(field));
- }
- return Tuple(new_fields);
+ auto target_n_args = AnnotateArgs(expr->fields);
+ auto new_expr = Tuple(std::get<1>(target_n_args));
+ op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+ return std::move(new_expr);
}
- Expr VisitExpr_(const TupleGetItemNode* op) {
+ Expr VisitExpr_(const TupleGetItemNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
+ auto expr = Downcast<TupleGetItem>(new_e);
- auto get = Downcast<TupleGetItem>(new_e);
- if (IsSupported(get->tuple)) {
- const auto* begin_op =
- runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
- CHECK(begin_op);
- return TupleGetItem((*begin_op)(InsertEnd(get->tuple), target_), get->index);
- } else {
- return TupleGetItem(InsertEnd(get->tuple), get->index);
- }
+ auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple}));
+ auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index);
+ op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+ return std::move(new_expr);
}
- Expr VisitExpr_(const FunctionNode* fn) {
+ Expr VisitExpr_(const FunctionNode* fn) final {
Function func;
Expr new_body;
// don't step into composite functions
} else {
auto new_e = ExprMutator::VisitExpr_(fn);
func = Downcast<Function>(new_e);
- new_body = InsertEnd(func->body);
+ new_body = func->body;
+ if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) {
+ new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op);
+ op_expr_to_target_[new_body] = op_expr_to_target_[func->body];
+ }
}
-
- return Function(
- func->params,
- new_body,
- func->ret_type,
- func->type_params,
- func->attrs);
+ return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
}
- Expr VisitExpr_(const LetNode* op) {
+ Expr VisitExpr_(const LetNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
-
auto let = Downcast<Let>(new_e);
- return Let(
- let->var,
- InsertEnd(let->value),
- InsertEnd(let->body));
+
+ auto target_n_args = AnnotateArgs({let->value, let->body});
+ auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
+ op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+ return std::move(new_expr);
}
- Expr VisitExpr_(const IfNode* op) {
+ Expr VisitExpr_(const IfNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
-
- auto iff = Downcast<If>(new_e);
- return If(
- InsertEnd(iff->cond),
- InsertEnd(iff->true_branch),
- InsertEnd(iff->false_branch));
+ auto expr = Downcast<If>(new_e);
+
+ auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch});
+ CHECK_EQ(std::get<1>(target_n_args).size(), 3U);
+ auto new_expr = If(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1],
+ std::get<1>(target_n_args)[2]);
+ op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+ return std::move(new_expr);
}
- Expr VisitExpr_(const RefCreateNode* op) {
+ Expr VisitExpr_(const RefCreateNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
+ auto expr = Downcast<RefCreate>(new_e);
- auto create = Downcast<RefCreate>(new_e);
- return RefCreate(InsertEnd(create->value));
+ auto target_n_args = AnnotateArgs(Array<Expr>({expr->value}));
+ auto new_expr = RefCreate(std::get<1>(target_n_args)[0]);
+ op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+ return std::move(new_expr);
}
- Expr VisitExpr_(const RefReadNode* op) {
+ Expr VisitExpr_(const RefReadNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
+ auto expr = Downcast<RefRead>(new_e);
- auto read = Downcast<RefRead>(new_e);
- return RefRead(InsertEnd(read->ref));
+ auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref}));
+ auto new_expr = RefRead(std::get<1>(target_n_args)[0]);
+ op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+ return std::move(new_expr);
}
- Expr VisitExpr_(const RefWriteNode* op) {
+ Expr VisitExpr_(const RefWriteNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
+ auto expr = Downcast<RefWrite>(new_e);
- auto write = Downcast<RefWrite>(new_e);
- return RefWrite(
- InsertEnd(write->ref),
- InsertEnd(write->value));
+ auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value}));
+ auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
+ op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+ return std::move(new_expr);
}
private:
- std::string target_;
+ /*! \brief The target backends for annotation. */
+ Array<runtime::String> targets_;
+ /*! \brief Maintain the decision of the target for each op expr. */
+ std::unordered_map<Expr, std::string, ObjectHash, ObjectEqual> op_expr_to_target_;
};
-Expr AnnotateTarget(const Expr& expr, const std::string& target) {
- return AnnotateTargetWrapper(target).Annotate(expr);
+Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) {
+ return AnnotateTargetWrapper(targets).Mutate(expr);
}
} // namespace annotate_target
namespace transform {
-Pass AnnotateTarget(const std::string& target) {
+Pass AnnotateTarget(const Array<runtime::String>& targets) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target));
+ return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, targets));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
{"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
}
-TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget")
-.set_body_typed(AnnotateTarget);
+TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget").set_body_typed(AnnotateTarget);
} // namespace transform
namespace tvm {
namespace relay {
-namespace partitioning {
+namespace merge_compiler_region {
// Cache compiler_begin and compiler_end annotation ops for equivalence check to
// reduce registry lookup overhead.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
-/*! \brief This is a pre-requisite pass to merge-supported pass.
- * The AnnotateRestDefault pass will put "default" Compiler Annotations to
- * nodes that are not annotated already. This is there to ensure that the
- * user will not leave un-annotated nodes MergeCompilerRegions pass is run.
- * Why? Because, MergeCompilerRegions pass assumes every node to be annotated.
- */
-class AnnotateRestDefault : public ExprMutator {
- public:
- explicit AnnotateRestDefault(const Expr& expr) {
- regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
- }
-
- Expr Annotate(const Expr& expr) {
- // Its a function that is being passed on to annotate
- func_ = Downcast<Function>(expr);
-
- // Corner Case CC1 : If the last node does not belong
- // to a region node to add a compiler_end
- auto region = regions_->GetRegion(func_->body);
- auto mutated_expr = this->VisitExpr(expr);
- if (!region.defined()) {
- func_ = Downcast<Function>(mutated_expr);
- // CC1 : add that compiler end after mutation
- auto body = InsertEnd(func_->body);
- func_ = Function(func_->params, body, body->checked_type_, {}, DictAttrs());
- return Downcast<Expr>(func_);
- }
- return mutated_expr;
- }
-
- /*! \brief This function adds compiler ends to nodes that
- * don't belong to a region already (default).
- * \param expr The expression to add a compiler end to.
- * \return expr The expression with or without a compiler end added.
- */
- Expr InsertEnd(const Expr& expr) {
- if (annotated_nodes_.find(expr) == annotated_nodes_.end() && !expr->IsInstance<VarNode>() &&
- !expr->IsInstance<ConstantNode>()) {
- const auto* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
- CHECK(end_op);
- Expr end = (*end_op)(expr, target_);
- return end;
- }
- return expr;
- }
-
- /*! \brief This function adds compiler begins to nodes that
- * don't belong to a region already (default).
- * \param expr The expression to add a compiler begin to.
- * \return expr The expression with or without a compiler begin added.
- */
- Expr InsertBegin(const Expr& expr) {
- const auto* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
- CHECK(begin_op);
- Expr begin = (*begin_op)(expr, target_);
- annotated_nodes_.insert(begin);
- return begin;
- }
-
- Expr VisitExpr_(const CallNode* cn) final {
- auto region = regions_->GetRegion(GetRef<Call>(cn));
- auto new_e = ExprMutator::VisitExpr_(cn);
- Call call = Downcast<Call>(new_e);
-
- // Add compiler ends if the parent isn't annotated
- Array<Expr> args;
- for (auto arg : call->args) {
- args.push_back(InsertEnd(arg));
- }
-
- Expr updated_call = Call(call->op, args, call->attrs);
- if (!region.defined()) {
- // if the current node does not belong to annotated region
- // annotate the all incoming edges (args)
- // with "default" compiler_begin annotations.
- Array<Expr> compiler_begins;
- for (auto arg : args) {
- compiler_begins.push_back(InsertBegin(arg));
- }
- updated_call = Call(call->op, compiler_begins, call->attrs);
- } else {
- annotated_nodes_.insert(updated_call);
- }
- return updated_call;
- };
-
- Expr VisitExpr_(const TupleNode* op) {
- auto region = regions_->GetRegion(GetRef<Tuple>(op));
- auto new_e = ExprMutator::VisitExpr_(op);
- Tuple tup = Downcast<Tuple>(new_e);
-
- Array<Expr> fields;
- for (auto field : tup->fields) {
- fields.push_back(InsertEnd(field));
- }
-
- Expr updated_tuple = Tuple(fields);
- if (!region.defined()) {
- Array<Expr> compiler_begins;
- for (const auto& field : fields) {
- compiler_begins.push_back(InsertBegin(field));
- }
- updated_tuple = Tuple(compiler_begins);
- } else {
- annotated_nodes_.insert(updated_tuple);
- }
- return updated_tuple;
- }
-
- Expr VisitExpr_(const TupleGetItemNode* op) {
- auto region = regions_->GetRegion(GetRef<TupleGetItem>(op));
- auto new_e = ExprMutator::VisitExpr_(op);
- auto get = Downcast<TupleGetItem>(new_e);
-
- auto updated_tuple = InsertEnd(get->tuple);
- Expr updated_get = TupleGetItem(updated_tuple, get->index);
- if (!region.defined()) {
- updated_get = TupleGetItem(InsertBegin(updated_tuple), get->index);
- } else {
- annotated_nodes_.insert(updated_get);
- }
- return updated_get;
- }
-
- Expr VisitExpr_(const IfNode* op) {
- auto region = regions_->GetRegion(GetRef<If>(op));
- auto new_e = ExprMutator::VisitExpr_(op);
- auto iff = Downcast<If>(new_e);
-
- if (!region.defined()) {
- return If(InsertBegin(InsertEnd(iff->cond)), InsertBegin(InsertEnd(iff->true_branch)),
- InsertBegin(InsertEnd(iff->false_branch)));
- } else {
- Expr updated_iff =
- If(InsertEnd(iff->cond), InsertEnd(iff->true_branch), InsertEnd(iff->false_branch));
- annotated_nodes_.insert(updated_iff);
- return updated_iff;
- }
- }
-
- Expr VisitExpr_(const LetNode* op) {
- auto new_e = ExprMutator::VisitExpr_(op);
- auto let = Downcast<Let>(new_e);
- return Let(let->var, InsertEnd(let->value), InsertEnd(let->body));
- }
-
- Expr VisitExpr_(const RefCreateNode* op) {
- auto new_e = ExprMutator::VisitExpr_(op);
- auto create = Downcast<RefCreate>(new_e);
- return RefCreate(InsertEnd(create->value));
- }
-
- Expr VisitExpr_(const RefReadNode* op) {
- auto new_e = ExprMutator::VisitExpr_(op);
- auto read = Downcast<RefRead>(new_e);
- return RefRead(InsertEnd(read->ref));
- }
-
- Expr VisitExpr_(const RefWriteNode* op) {
- auto new_e = ExprMutator::VisitExpr_(op);
- auto write = Downcast<RefWrite>(new_e);
- return RefWrite(InsertEnd(write->ref), InsertEnd(write->value));
- }
-
- private:
- AnnotatedRegionSet regions_;
- const std::string target_ = "default";
- Function func_;
- std::unordered_set<Expr, ObjectHash, ObjectEqual> annotated_nodes_;
-};
-
-class MergeAnnotations : public ExprMutator {
- public:
- explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
-
- Expr VisitExpr_(const CallNode* call) final {
- // remove 'default' annotations
- auto attrs = call->attrs.as<CompilerAttrs>();
- if (attrs != nullptr && attrs->compiler == "default") {
- return VisitExpr(call->args[0]);
- }
- // Merge annotations which are now internal to a region.
- // This happens if we see a compiler begin next to a
- // compiler end and they're both in the same region.
- if (call->op == compiler_begin_op) {
- if (call->args[0]->IsInstance<CallNode>()) {
- auto arg = Downcast<Call>(call->args[0]);
- if (arg->op == compiler_end_op) {
- auto region1 = regions_->GetRegion(GetRef<Call>(call));
- auto region2 = regions_->GetRegion(arg);
- if (region1 == region2) {
- return VisitExpr(arg->args[0]);
- }
- }
- }
- }
- return ExprMutator::VisitExpr_(call);
- }
-
- private:
- AnnotatedRegionSet regions_;
-};
-
class RegionMerger : public ExprVisitor {  // Merges same-target parent regions into their child.
public:
explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
void VisitExpr_(const CallNode* call) final {
if (call->op == compiler_end_op) {
auto region = regions_->GetRegion(GetRef<Call>(call));
- if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return;
- // set the region target
+
+ // Skip this region if it has already been processed by this visitor.
+ if (merged_regions_.find(region->GetID()) != merged_regions_.end()) {
+ return;
+ }
+
+ // Check that the region target agrees with the end annotation's compiler.
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
- region_targets_[region->GetID()] = compiler_attrs->compiler;
- // first look at the region args to determine the parent regions
+ CHECK_EQ(region->GetTarget(), compiler_attrs->compiler);
+
+ // Visit the parent regions that have not been processed yet.
for (const auto& arg : region->GetInputs()) {
- // all args should be begin annotations
+ // Region inputs must be begin annotations, and the region of
+ // the begin annotation's argument is the parent region.
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
- // the arguments of the begin annotations will be in the parent regions
auto parent_region = regions_->GetRegion(begin->args[0]);
- // if there is no parent region, move on
- if (!parent_region.defined()) continue;
- // merge the parent region if it hasn't been done already
- if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) {
+
+ // Skip inputs that have no parent region; otherwise process the parent first if needed.
+ if (!parent_region.defined()) {
+ continue;
+ } else if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) {
VisitExpr(begin->args[0]);
}
}
- // get the mergeable regions now all the parents have been visited
+
+ // Collect every defined parent region as a merge candidate.
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
for (const auto& arg : region->GetInputs()) {
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
auto parent_region = regions_->GetRegion(begin->args[0]);
- if (!parent_region.defined()) continue;
- mergeable_regions.insert(parent_region);
+ if (parent_region.defined()) {
+ mergeable_regions.insert(parent_region);
+ }
}
+
+ // Propagate all the parent restrictions to the current region.
auto& region_restrictions = region_restrictions_[region->GetID()];
for (const auto& parent_region : mergeable_regions) {
- // add all the parent restrictions to the current region
auto parent_restrictions = region_restrictions_[parent_region->GetID()];
region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end());
}
+
for (const auto& parent_region : mergeable_regions) {
- bool merged = false;
- // check the parent region has the same target
- if (region_targets_[parent_region->GetID()] == compiler_attrs->compiler) {
- // check the parent region isn't in the restrictions
- if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) {
- // merge the parent region into the current region
- regions_->MergeRegions(parent_region, region);
- // update the restrictions of all other regions to reflect the
- // change in id
- for (const auto& r : regions_) {
- auto& restrictions = region_restrictions_[r->GetID()];
- if (restrictions.find(parent_region->GetID()) != restrictions.end()) {
- restrictions.erase(parent_region->GetID());
- restrictions.insert(region->GetID());
- }
- }
- merged = true;
+ // A parent region with a different target is not merged; record it as a restriction.
+ if (parent_region->GetTarget() != compiler_attrs->compiler) {
+ region_restrictions.insert(parent_region->GetID());
+ continue;
+ }
+
+ // Skip the parent region if it is in the restriction set.
+ if (region_restrictions.find(parent_region->GetID()) != region_restrictions.end()) {
+ continue;
+ }
+
+ // Merge the parent region into the current one.
+ regions_->MergeRegions(parent_region, region);
+
+ // Replace the parent region ID with the current region ID in every
+ // region's restriction set, since the parent no longer exists on its own.
+ for (const auto& r : regions_) {
+ auto& restrictions = region_restrictions_[r->GetID()];
+ if (restrictions.find(parent_region->GetID()) != restrictions.end()) {
+ restrictions.erase(parent_region->GetID());
+ restrictions.insert(region->GetID());
}
}
- // if the parent wasn't merged, add it as a restriction to the current
- // region
- if (!merged) region_restrictions.insert(parent_region->GetID());
}
merged_regions_.insert(region->GetID());
}
private:
AnnotatedRegionSet regions_;
std::unordered_set<int> merged_regions_;  // IDs of regions already processed by VisitExpr_
- std::map<int, std::unordered_set<int>> region_restrictions_;
- std::map<int, std::string> region_targets_;
+ std::unordered_map<int, std::unordered_set<int>> region_restrictions_;  // region ID -> IDs it must not merge with
};
-Expr MergeCompilerRegions(const Expr& expr) {
- // Annotate all the nodes that aren't annotated as 'default'.
- AnnotateRestDefault anno_default(expr);
- auto expr_all_annotated = anno_default.Annotate(expr);
+class MergeAnnotations : public ExprMutator {  // Drops begin/end pairs internal to one region.
+ public:
+ explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
+
+ Expr VisitExpr_(const CallNode* call) final {
+ // Merge annotations which are now internal to a region.
+ // This happens when a compiler_begin directly wraps a
+ // compiler_end and both calls belong to the same region.
+ if (call->op == compiler_begin_op && call->args[0]->IsInstance<CallNode>()) {
+ auto arg = Downcast<Call>(call->args[0]);
+ if (arg->op == compiler_end_op) {
+ auto region1 = regions_->GetRegion(GetRef<Call>(call));
+ auto region2 = regions_->GetRegion(arg);
+ if (region1 == region2) {
+ return VisitExpr(arg->args[0]);  // drop both annotations, keep rewriting the wrapped expr
+ }
+ }
+ }
+ return ExprMutator::VisitExpr_(call);  // every other call is rewritten by the default mutator
+ }
+
+ private:
+ AnnotatedRegionSet regions_;
+};
+Expr MergeCompilerRegions(const Expr& expr) {
// Create regions using the annotations.
- AnnotatedRegionSet regions =
- AnnotatedRegionSet::Create(expr_all_annotated, compiler_begin_op, compiler_end_op);
+ AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
- // By now, all the nodes have some sort of annotation.
- // Region merger is an ExprVisitor that will update the
- // AnnotatedRegionSet, merging all the regions that can be merged.
+ // Walk the graph and merge the regions in `regions` wherever merging is allowed.
RegionMerger merger(regions);
- merger.VisitExpr(expr_all_annotated);
+ merger.VisitExpr(expr);
- // This updates the expression to remove annotations that are now
- // 'internal' to a merged region.
+ // Remove the annotations that are now internal to a merged region.
MergeAnnotations merge_anno(regions);
- return merge_anno.Mutate(expr_all_annotated);
+ return merge_anno.Mutate(expr);
}
-} // namespace partitioning
+} // namespace merge_compiler_region
namespace transform {
Pass MergeCompilerRegions() {  // Function pass that merges compiler regions, then runs InferType.
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
[=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(partitioning::MergeCompilerRegions(f));
+ return Downcast<Function>(merge_compiler_region::MergeCompilerRegions(f));
};
- auto partitioned = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {});
- return Sequential({partitioned, InferType()});
+ auto merged = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {});
+ return Sequential({merged, InferType()});
}
TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions")
IRModule module_;
};
+class DefaultRemover : public ExprMutator {  // Strips annotation calls whose compiler is "default".
+ public:
+ explicit DefaultRemover(const IRModule& module) : module_(module) {}
+
+ IRModule Remove() {  // Rewrites the body of every Function in the module; returns the module.
+ auto glob_funcs = module_->functions;  // NOTE(review): presumably taken up front because Update() runs in the loop - confirm
+ for (const auto& pair : glob_funcs) {
+ if (auto* fn = pair.second.as<FunctionNode>()) {  // entries that are not FunctionNode are left untouched
+ auto func = GetRef<Function>(fn);
+ func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
+ func->attrs);
+ module_->Update(pair.first, func);  // store the rewritten function under the same global var
+ }
+ }
+ return module_;
+ }
+
+ Expr VisitExpr_(const CallNode* call) final {
+ auto attrs = call->attrs.as<CompilerAttrs>();  // nullptr when the call carries no CompilerAttrs
+ if (attrs != nullptr && attrs->compiler == "default") {
+ return VisitExpr(call->args[0]);  // replace the annotation call with its rewritten argument
+ }
+ return ExprMutator::VisitExpr_(call);
+ }
+
+ private:
+ IRModule module_;
+};
+
} // namespace partitioning
namespace transform {
Pass PartitionGraph() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
- [=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); };
+ [=](IRModule m, PassContext pc) {
+ // TODO(@comaniac, @zhiics): Annotations with the "default" attribute should be handled by
+ // treating them as un-annotated, but that is not supported yet. This workaround removes
+ // all "default" annotations before partitioning and should be deleted in the future.
+ auto new_m = partitioning::DefaultRemover(m).Remove();
+ return partitioning::Partitioner(new_m).Partition();
+ };
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()});
}
read_from_dnnl_memory(out, dst_memory);
}
-extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance,
- float* out, float* new_mean, float* new_variance, int p_N_, int p_C_,
+extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
+ float* variance, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) {
- // FIXME(@comaniac): BN has 3 outputs: out, new_mean and new_variance, but we do not update
- // the rest two because no one cares about them for now. Should update it in the future.
using tag = memory::format_tag;
using dt = memory::data_type;
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
- float* variance, float* out, float* new_mean, float* new_variance,
- int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);
+ float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_,
+ int p_e_);
extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_);
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
+import tvm
from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end
-def check_region(region_set, args, nodes, rets):
+def check_region(region_set, target, args, nodes, rets):
region = region_set.get_region(args[0])
assert region
+ assert target == region.target
assert set(args) == set(region.args)
assert set(nodes) == set(region.nodes)
assert set(rets) == set(region.rets)
assert len(region_set) == 4
check_region(
region_set,
+ 'test_target',
[cb_1],
[cb_1, O_1, ce_1, ce_2],
[ce_1, ce_2],
)
check_region(
region_set,
+ 'test_target',
[cb_2],
[cb_2, O_2, ce_3],
[ce_3],
)
check_region(
region_set,
+ 'default',
[cb_d],
[cb_d, X, ce_d],
[ce_d],
)
check_region(
region_set,
+ 'test_target',
[cb_3, cb_4],
[cb_3, cb_4, O_3, ce_4],
[ce_4],
cb_3 = compiler_begin(ce_3, 'test_target')
cb_4 = compiler_begin(ce_d, 'test_target')
O_3 = relay.add(cb_3, cb_4)
- ce_4 = compiler_end(O_3, 'test_target')
+ O_4 = relay.add(cb_3, cb_4)
+ O_5 = relay.Tuple([O_3, O_4])
+ ce_4 = compiler_end(O_5, 'test_target')
merged = relay.Function([data], ce_4)
region_set = relay.analysis.AnnotatedRegionSet(merged,
assert len(region_set) == 3
check_region(
region_set,
+ 'test_target',
[cb_1],
[cb_1, O_1, O_2, ce_2, ce_3],
[ce_2, ce_3],
)
check_region(
region_set,
+ 'default',
[cb_d],
[cb_d, X, ce_d],
[ce_d],
)
check_region(
region_set,
+ 'test_target',
[cb_3, cb_4],
- [cb_3, cb_4, O_3, ce_4],
+ [cb_3, cb_4, O_3, O_4, O_5, ce_4],
[ce_4],
)
if __name__ == "__main__":
test_region_set_creator_diamond()
test_region_set_creator_merged()
-
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
-@reg.register("nn.relu", "target.test")
-def relu(attrs, args):
- return True
-
-
def test_multiple_ends():
+ @reg.register("nn.relu", "target.test")
+ def relu(attrs, args): # pylint: disable=unused-variable
+ return True
+
def before():
x = relay.var("x", shape=(10, 10))
r = relay.nn.relu(x)
r = relay.nn.relu(cb_1)
ce_1 = relay.annotation.compiler_end(r, "test")
ce_2 = relay.annotation.compiler_end(r, "test")
- a_1 = relay.abs(ce_1)
- a_2 = relay.abs(ce_2)
- out = relay.add(a_1, a_2)
- f = relay.Function([x], out)
+ cb_2 = relay.annotation.compiler_begin(ce_1, "default")
+ cb_3 = relay.annotation.compiler_begin(ce_2, "default")
+ a_1 = relay.abs(cb_2)
+ a_2 = relay.abs(cb_3)
+ ce_3 = relay.annotation.compiler_end(a_1, "default")
+ ce_4 = relay.annotation.compiler_end(a_2, "default")
+ cb_4 = relay.annotation.compiler_begin(ce_3, "default")
+ cb_5 = relay.annotation.compiler_begin(ce_4, "default")
+ out = relay.add(cb_4, cb_5)
+ ce_6 = relay.annotation.compiler_end(out, "default")
+ f = relay.Function([x], ce_6)
mod = tvm.IRModule.from_expr(f)
return mod
assert tvm.ir.structural_equal(expected, result)
+def test_type_propagation():
+ target = "test_type_propagation"
+
+ @reg.register("nn.relu", "target." + target)
+ def relu(attrs, args): # pylint: disable=unused-variable
+ return args[0].checked_type.dtype == "float32"
+
+ def before():
+ x = relay.var("x", shape=(10, 10))
+ r = relay.nn.relu(x)
+ out = relay.nn.relu(r)
+ f = relay.Function([x], out)
+ mod = tvm.IRModule.from_expr(f)
+ return mod
+
+ # If the type isn't propagated, then the relu checker function will fail to get the dtype.
+ assert transform.AnnotateTarget(target)(before())
+
+
+def test_tuple():
+ target = "test_tuple"
+
+ @reg.register("nn.relu", "target." + target)
+ def relu(attrs, args): # pylint: disable=unused-variable
+ return True
+
+ @reg.register("concatenate", "target." + target)
+ def concatenate(attrs, args): # pylint: disable=unused-variable
+ return True
+
+ """Test that TupleNode is included in annotation when surrounded by supported nodes."""
+ def before():
+ x = relay.var("x", shape=(10, 5))
+ y = relay.var("y", shape=(10, 5))
+ a_1 = relay.nn.relu(x)
+ a_2 = relay.nn.relu(y)
+ out = relay.concatenate((a_1, a_2), axis=1)
+ f = relay.Function([x, y], out)
+ mod = tvm.IRModule.from_expr(f)
+ return mod
+
+ def after():
+ x = relay.var("x", shape=(10, 5))
+ y = relay.var("y", shape=(10, 5))
+ cb_1 = relay.annotation.compiler_begin(x, target)
+ cb_2 = relay.annotation.compiler_begin(y, target)
+ a_1 = relay.nn.relu(cb_1)
+ a_2 = relay.nn.relu(cb_2)
+ ce_1 = relay.annotation.compiler_end(a_1, target)
+ ce_2 = relay.annotation.compiler_end(a_2, target)
+ cb_3 = relay.annotation.compiler_begin(ce_1, target)
+ cb_4 = relay.annotation.compiler_begin(ce_2, target)
+ tup = relay.Tuple([cb_3, cb_4])
+ ce_3 = relay.annotation.compiler_end(tup, target)
+ cb_3 = relay.annotation.compiler_begin(ce_3, target)
+ out = relay.op._make.concatenate(cb_3, 1)
+ ce_4 = relay.annotation.compiler_end(out, target)
+ f = relay.Function([x, y], ce_4)
+ mod = tvm.IRModule.from_expr(f)
+ return mod
+
+ result = transform.AnnotateTarget(target)(before())
+ expected = transform.InferType()(after())
+ assert tvm.ir.structural_equal(expected, result)
+
+
def test_composite_function():
def before():
a = relay.var('a', shape=(10, 10))
assert tvm.ir.structural_equal(expected, result)
+def test_multiple_runs():
+ @reg.register("nn.relu", "target.A")
+ def relu(attrs, args): # pylint: disable=unused-variable
+ return True
+
+ @reg.register("add", "target.B")
+ def add(attrs, args): # pylint: disable=unused-variable
+ return True
+
+ def before():
+ x = relay.var("x", shape=(10, 5))
+ a_1 = relay.nn.relu(x)
+ a_2 = relay.abs(a_1)
+ a_3 = relay.nn.relu(a_1)
+ out = relay.add(a_2, a_3)
+
+ f = relay.Function([x], out)
+ mod = tvm.IRModule.from_expr(f)
+ return mod
+
+ mod = transform.AnnotateTarget("A")(before())
+ mod = transform.AnnotateTarget("B")(mod)
+ expected = transform.AnnotateTarget(["A", "B"])(before())
+ assert tvm.ir.structural_equal(expected, mod)
+
+
if __name__ == "__main__":
- test_multiple_ends()
test_extern_dnnl()
- #test_extern_dnnl_mobilenet()
test_composite_function()
+ #test_extern_dnnl_mobilenet()
+ test_multiple_ends()
+ test_type_propagation()
+ test_tuple()
+ test_multiple_runs()
X = not supported by target
O O
- / \ / \
+ / \\ / \\
O X --> O + + X
- \ / \ /
+ \\ / \\ /
O O
Note that we can't just merge the three supported operators together,
ce_1 = compiler_end(O_1, "test")
ce_2 = compiler_end(O_1, "test")
cb_2 = compiler_begin(ce_1, "test")
+ cb_3 = compiler_begin(ce_2, "default")
O_2 = relay.nn.relu(cb_2)
ce_3 = compiler_end(O_2, "test")
- X = relay.tanh(ce_2)
- cb_3 = compiler_begin(ce_3, "test")
- cb_4 = compiler_begin(X, "test")
- O_3 = relay.add(cb_3, cb_4)
- ce_4 = compiler_end(O_3, "test")
+ X = relay.tanh(cb_3)
+ ce_4 = compiler_end(X, "default")
- diamond = relay.Function([data], ce_4)
+ cb_4 = compiler_begin(ce_3, "test")
+ cb_5 = compiler_begin(ce_4, "test")
+ O_3 = relay.add(cb_4, cb_5)
+ ce_5 = compiler_end(O_3, "test")
+
+ diamond = relay.Function([data], ce_5)
return diamond
def expected():
O_2 = relay.nn.relu(O_1)
ce_3 = compiler_end(O_2, "test")
- X = relay.tanh(ce_2)
+ cb_3 = compiler_begin(ce_2, "default")
+ X = relay.tanh(cb_3)
+ ce_4 = compiler_end(X, "default")
- cb_3 = compiler_begin(ce_3, "test")
- cb_4 = compiler_begin(X, "test")
- O_3 = relay.add(cb_3, cb_4)
- ce_4 = compiler_end(O_3, "test")
+ cb_4 = compiler_begin(ce_3, "test")
+ cb_5 = compiler_begin(ce_4, "test")
+ O_3 = relay.add(cb_4, cb_5)
+ ce_5 = compiler_end(O_3, "test")
- func = relay.Function([data], ce_4)
+ func = relay.Function([data], ce_5)
return func
result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions())
"""This tests the merging algorithm on the example used in the RFC.
See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830
- Blue nodes are adds, red nodes are subtracts.
+ Blue nodes are adds (target: test), red nodes are subtracts (target: default).
"""
def annotated():
in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
node2 = relay.add(begin4, begin5)
end2 = compiler_end(node2, "test")
- node3 = relay.subtract(in_5, in_6)
- node4 = relay.subtract(in_7, node3)
+ dbegin0 = compiler_begin(in_5, "default")
+ dbegin1 = compiler_begin(in_6, "default")
+ node3 = relay.subtract(dbegin0, dbegin1)
+ dbegin2 = compiler_begin(in_7, "default")
+ dend1 = compiler_end(node3, "default")
+ dbegin3 = compiler_begin(dend1, "default")
+ node4 = relay.subtract(dbegin2, dbegin3)
+ dend2 = compiler_end(node4, "default")
begin6 = compiler_begin(end2, "test")
- begin7 = compiler_begin(node4, "test")
+ begin7 = compiler_begin(dend2, "test")
node5 = relay.add(begin6, begin7)
end3 = compiler_end(node5, "test")
end4 = compiler_end(node5, "test")
- node6 = relay.subtract(in_8, end3)
+ dbegin4 = compiler_begin(in_8, "default")
+ dbegin5 = compiler_begin(end3, "default")
+ node6 = relay.subtract(dbegin4, dbegin5)
begin8 = compiler_begin(in_9, "test")
begin9 = compiler_begin(end4, "test")
node7 = relay.add(begin8, begin9)
end5 = compiler_end(node7, "test")
- begin10 = compiler_begin(node6, "test")
+ dend3 = compiler_end(node6, "default")
+ begin10 = compiler_begin(dend3, "test")
begin11 = compiler_begin(end5, "test")
node8 = relay.add(begin10, begin11)
end6 = compiler_end(node8, "test")
node1 = relay.add(begin2, begin3)
node2 = relay.add(node0, node1)
- node3 = relay.subtract(in_5, in_6)
- node4 = relay.subtract(in_7, node3)
+ dbegin0 = compiler_begin(in_5, "default")
+ dbegin1 = compiler_begin(in_6, "default")
+ dbegin2 = compiler_begin(in_7, "default")
+ node3 = relay.subtract(dbegin0, dbegin1)
+ node4 = relay.subtract(dbegin2, node3)
+ dend0 = compiler_end(node4, "default")
- begin4 = compiler_begin(node4, "test")
+ begin4 = compiler_begin(dend0, "test")
begin5 = compiler_begin(in_9, "test")
node5 = relay.add(node2, begin4)
end1 = compiler_end(node5, "test")
- node6 = relay.subtract(in_8, end1)
+ dbegin4 = compiler_begin(end1, "default")
+ dbegin5 = compiler_begin(in_8, "default")
+ node6 = relay.subtract(dbegin5, dbegin4)
+ dend1 = compiler_end(node6, "default")
node7 = relay.add(begin5, node5)
end2 = compiler_end(node7, "test")
begin6 = compiler_begin(end2, "test")
- begin7 = compiler_begin(node6, "test")
+ begin7 = compiler_begin(dend1, "test")
node8 = relay.add(begin7, begin6)
"""Unit tests for graph partitioning."""
import os
import sys
+
import numpy as np
import pytest
from tvm import runtime
from tvm.relay import transform
from tvm.contrib import util
-from tvm.relay.op.annotation import compiler_begin, compiler_end
+from tvm.relay import transform
+from tvm.relay.backend import compile_engine
from tvm.relay.expr_functor import ExprMutator
+from tvm.relay.op.annotation import compiler_begin, compiler_end
+from tvm.runtime import container
+
# Leverage the pass manager to write a simple white list based annotator
@transform.function_pass(opt_level=0)
return lib
def check_vm_result():
+ compile_engine.get().clear()
with relay.build_config(opt_level=3):
exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save()
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result():
+ compile_engine.get().clear()
with relay.build_config(opt_level=3):
json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib)
mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32')
- op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"]
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
- mod = WhiteListAnnotator(op_list, "dnnl")(mod)
+ mod = transform.AnnotateTarget(["dnnl"])(mod)
+ mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
test_extern_ccompiler_default_ops()
test_extern_ccompiler()
test_extern_dnnl()
+ # TODO(@comaniac, @zhiics): Fix constant node and re-open this case.
#test_extern_dnnl_mobilenet()
test_function_lifting()
test_function_lifting_inline()