* external functions, and they will use the provided compiler for codegen.
*/
+#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
-#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
-#include <utility>
#include <unordered_map>
#include <unordered_set>
+#include <utility>
#include <vector>
-#include "../backend/utils.h"
#include "../analysis/annotated_region_set.h"
-
+#include "../backend/utils.h"
namespace tvm {
namespace relay {
return true;
}
- void VisitExpr_(const CallNode *call) final {
+ void VisitExpr_(const CallNode* call) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return;
* in the TVM stack.
*
* Input : A Relay module that have functions with disjoint annotated regions
- * using compiler_begin and compiler_end. There could be multiple outputs.
+ * using compiler_begin and compiler_end. There could be multiple
+ * outputs.
*
- * Output : A Relay module with global functions for such disjoint annotated regions
- * with calls inserted at the respective location
+ * Output : A Relay module with global functions for such disjoint annotated
+ * regions with calls inserted at the respective location
*
- * Dependencies : RegionSet Utility class.
+ * Dependencies : AnnotatedRegionSet Utility class.
*
* Methodology :
- * 1) The RegionSet utility class is able to construct a collection of
- * nodes that are bound by a given annotation -- here we use compiler_begin
- * and compiler_end
+ * 1) The AnnotatedRegionSet utility class is able to construct a collection
+ * of nodes that are bound by a given annotation -- here we use
+ * compiler_begin and compiler_end
* 2) Initially, for each function in the module RegionSets are populated.
* 3) Then, Vistor pass is traversed until a compiler_end node is encountered
* that belongs to a "region".
- * 4) When the first compiler_end of a given annotated region is found, a function is
- * formed and inserted.
- * a) if the region has multiple outputs, a Tuple node (capturing all outputs)
- * is returned.
- * 5) Thereafter, if we encounter an another output of the same annotated region,
- * it is important to note that the function is already formed. Therefore, it will
- * lookup the function and add a TupleGetItemNode.
- * a) We will use the location index of "rets" of each "Region" of RegionSet
- * as TupleGetItemNode index.
- * 6) Therefore, functions will be created for all annotated regions. The name for each
- * global function is created using "Region" id and the compiler name.
+ * 4) When the first compiler_end of a given annotated region is found,
+ * a function is formed and inserted.
+ * a) if the region has multiple outputs, a Tuple node (capturing
+ * all outputs) is returned.
+ * 5) Thereafter, if we encounter an another output of the same annotated
+ * region, it is important to note that the function is already formed.
+ * Therefore, it will lookup the function and add a TupleGetItemNode.
+ * a) We will use the location index of "rets" of each Region" of
+ * AnnotatedRegionSet as TupleGetItemNode index.
+ * 6) Therefore, functions will be created for all annotated regions.
+ * The name for each global function is created using "Region" id and
+ * the compiler name.
*/
class Partitioner : public ExprMutator {
}
}
- Expr VisitExpr_(const CallNode *call) final {
+ Expr VisitExpr_(const CallNode* call) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return ExprMutator::VisitExpr_(call);
} else if (call->op == compiler_begin_op) {
- // The annotation node is inserted on edge so it must have only one argument.
+ // The annotation node is inserted on edge so it must have only one
+ // argument.
CHECK_EQ(call->args.size(), 1U);
// Traverse the rest graph.
// The type of the created variable is the same as the compiler_begin
// node.
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
- std::string varname = target + "_" + std::to_string(sg->GetID())
- + "_i" + std::to_string(index);
+ std::string varname =
+ target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
auto var = Var(varname, GetRef<Call>(call)->checked_type_);
auto cand = std::make_pair(var, input_expr);
- if (std::find(region_args[sg].begin(),
- region_args[sg].end(), cand) == region_args[sg].end()) {
+ if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
+ region_args[sg].end()) {
region_args[sg].push_back(cand);
- }
+ }
return std::move(var);
} else {
CHECK_EQ(call->op, compiler_end_op);
- // The annotation node is inserted on edge so it must have only one argument.
+ // The annotation node is inserted on edge so it must have only one
+ // argument.
CHECK_EQ(call->args.size(), 1U);
AnnotatedRegion region = GetRegion(GetRef<Call>(call));
// (each annotated regions) --> created function
if (region_function_calls.find(region) != region_function_calls.end()) {
- // This section is executed only if there are multiple outputs in the region
- // Thus, the function is always created and at the end there would be a tuple node
- // Therefore, we insert a tuple get item node.
+ // This section is executed only if there are multiple outputs in the
+ // region Thus, the function is always created and at the end there
+ // would be a tuple node Therefore, we insert a tuple get item node.
- // Use the already created tuple node
+ // Use the already created tuple node
auto sg_call = region_function_calls[region];
int index = GetRetIdx(region, GetRef<Call>(call));
CHECK_NE(index, -1);
Function global_region_func;
if (region->GetOutputs().size() == 1) {
// If there are only a single output; no need to add a tuple
- global_region_func = Function(params, fields[0],
- call->args[0]->checked_type_, {}, DictAttrs());
+ global_region_func =
+ Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
} else {
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol,
tir::StringImmNode::make(name));
- global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive,
- tvm::Integer(1));
+ global_region_func =
+ WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::tir::StringImmNode::make(target));
- global_region_func = WithAttr(std::move(global_region_func), attr::kInline,
- tvm::Integer(1));
+ global_region_func =
+ WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
// Constant propagation
if (!params_bind.empty()) {
std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname))
- << "Global function " << fname << " already exists";
+ << "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
- // codegen to the module scope and rely on the pass manager to prevent relay
- // function level passes (i.e. simplify inference and fusion) optimizing it.
+ // codegen to the module scope and rely on the pass manager to prevent
+ // relay function level passes (i.e. simplify inference and fusion)
+ // optimizing it.
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);
region_function_calls[region] = ret;
if (region->GetOutputs().size() == 1) {
- // If there is only a single output; no need to add a tuplegetitem node
+ // If there is only a single output; no need to add a tuplegetitem
+ // node
return std::move(ret);
} else {
// Add a tuplegetitem node to select this output out of many
}
}
- Expr VisitExpr_(const TupleNode *op) final {
+ Expr VisitExpr_(const TupleNode* op) final {
auto region = GetRegion(GetRef<Tuple>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
}
}
- Expr VisitExpr_(const TupleGetItemNode *g) final {
+ Expr VisitExpr_(const TupleGetItemNode* g) final {
auto region = GetRegion(GetRef<TupleGetItem>(g));
if (!region.defined()) {
return ExprMutator::VisitExpr_(g);
}
}
- Expr VisitExpr_(const FunctionNode *op) final {
+ Expr VisitExpr_(const FunctionNode* op) final {
auto region = GetRegion(GetRef<Function>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
}
}
- Expr VisitExpr_(const LetNode *op) final {
+ Expr VisitExpr_(const LetNode* op) final {
auto region = GetRegion(GetRef<Let>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
}
}
- Expr VisitExpr_(const IfNode *op) final {
+ Expr VisitExpr_(const IfNode* op) final {
auto region = GetRegion(GetRef<If>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
}
}
- Expr VisitExpr_(const RefCreateNode *op) final {
+ Expr VisitExpr_(const RefCreateNode* op) final {
auto region = GetRegion(GetRef<RefCreate>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
}
}
- Expr VisitExpr_(const RefReadNode *op) final {
+ Expr VisitExpr_(const RefReadNode* op) final {
auto region = GetRegion(GetRef<RefRead>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
}
}
- Expr VisitExpr_(const RefWriteNode *op) final {
+ Expr VisitExpr_(const RefWriteNode* op) final {
auto region = GetRegion(GetRef<RefWrite>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
IRModule Partition() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
- if (auto *fn = pair.second.as<FunctionNode>()) {
+ if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
- func = Function(func->params,
- VisitExpr(func->body),
- func->ret_type,
- func->type_params,
+ func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
int idx = 0;
for (auto arg_ : sg->GetInputs()) {
if (arg == arg_) {
- return idx;
+ return idx;
}
idx++;
}
/*!
* \brief This map maintains the already created function calls.
- * This is required in the multi-output scenario, to link rest of the outputs to call
+ * This is required in the multi-output scenario, to link rest of the outputs
+ * to call
*/
std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
/*!
- * \brief This map maintains arguments (of region) visits through visitor patterns.
- * Those arguement var and expression will be used to when creating the function.
+ * \brief This map maintains arguments (of region) visits through visitor
+ * patterns. Those arguement var and expression will be used to when creating
+ * the function.
*/
- std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>,
- ObjectHash, ObjectEqual> region_args;
+ std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
+ region_args;
/*!
* \brief Each region set is associated with a function in the module.
- * This map maintains the mapping between regionsets and the function it belongs to
+ * This map maintains the mapping between regionsets and the function it
+ * belongs to
*/
std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
IRModule module_;
};
-
} // namespace partitioning
namespace transform {
Pass PartitionGraph() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
- [=](IRModule m, PassContext pc) {
- return partitioning::Partitioner(m).Partition();
- };
+ [=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); };
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()});
}
-TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph")
-.set_body_typed(transform::PartitionGraph);
+TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed(transform::PartitionGraph);
} // namespace transform