[RELAY] Partition graph codestyle fixes (#5202)
authormanupa-arm <61496855+manupa-arm@users.noreply.github.com>
Wed, 1 Apr 2020 17:43:44 +0000 (18:43 +0100)
committerGitHub <noreply@github.com>
Wed, 1 Apr 2020 17:43:44 +0000 (02:43 +0900)
* [RELAY] Codestyle fixes for Graph Partitioner
*ran through clang-format

* *formatting comments

* *further codestyle changes (after clang-format)

src/relay/transforms/partition_graph.cc

index 32db40a..d8e93ed 100644 (file)
  * external functions, and they will use the provided compiler for codegen.
  */
 
+#include <tvm/ir/error.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
 #include <tvm/relay/expr.h>
-#include <tvm/ir/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 
-#include <utility>
 #include <unordered_map>
 #include <unordered_set>
+#include <utility>
 #include <vector>
 
-#include "../backend/utils.h"
 #include "../analysis/annotated_region_set.h"
-
+#include "../backend/utils.h"
 
 namespace tvm {
 namespace relay {
@@ -73,7 +72,7 @@ class AnnotationChecker : public ExprVisitor {
     return true;
   }
 
-  void VisitExpr_(const CallNode *call) final {
+  void VisitExpr_(const CallNodecall) final {
     auto op_node = call->op.as<OpNode>();
     if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
       return;
@@ -95,31 +94,33 @@ class AnnotationChecker : public ExprVisitor {
  * in the TVM stack.
  *
  * Input : A Relay module that have functions with disjoint annotated regions
- *         using compiler_begin and compiler_end. There could be multiple outputs.
+ *         using compiler_begin and compiler_end. There could be multiple
+ * outputs.
  *
- * Output : A Relay module with global functions for such disjoint annotated regions
- *          with calls inserted at the respective location
+ * Output : A Relay module with global functions for such disjoint annotated
+ * regions with calls inserted at the respective location
  *
- * Dependencies : RegionSet Utility class.
+ * Dependencies : AnnotatedRegionSet Utility class.
  *
  * Methodology :
- *      1) The RegionSet utility class is able to construct a collection of
- *         nodes that are bound by a given annotation -- here we use compiler_begin
- *         and compiler_end
+ *      1) The AnnotatedRegionSet utility class is able to construct a collection
+ *      of nodes that are bound by a given annotation -- here we use
+ *      compiler_begin and compiler_end
  *      2) Initially, for each function in the module RegionSets are populated.
  *      3) Then, Vistor pass is traversed until a compiler_end node is encountered
  *         that belongs to a "region".
- *      4) When the first compiler_end of a given annotated region is found, a function is
- *         formed and inserted.
- *         a) if the region has multiple outputs, a Tuple node (capturing all outputs)
- *            is returned.
- *      5) Thereafter, if we encounter an another output of the same annotated region,
- *         it is important to note that the function is already formed. Therefore, it will
- *         lookup the function and add a TupleGetItemNode.
- *          a) We will use the location index of "rets" of each "Region" of RegionSet
- *             as TupleGetItemNode index.
- *      6) Therefore, functions will be created for all annotated regions. The name for each
- *         global function is created using "Region" id and the compiler name.
+ *      4) When the first compiler_end of a given annotated region is found,
+ *         a function is formed and inserted.
+ *         a) if the region has multiple outputs, a Tuple node (capturing
+ *            all outputs) is returned.
+ *      5) Thereafter, if we encounter an another output of the same annotated
+ *         region, it is important to note that the function is already formed.
+ *         Therefore, it will lookup the function and add a TupleGetItemNode.
+ *         a) We will use the location index of "rets" of each Region" of
+ *         AnnotatedRegionSet as TupleGetItemNode index.
+ *      6) Therefore, functions will be created for all annotated regions.
+ *         The name for each global function is created using "Region" id and
+ *         the compiler name.
  */
 
 class Partitioner : public ExprMutator {
@@ -136,12 +137,13 @@ class Partitioner : public ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const CallNode *call) final {
+  Expr VisitExpr_(const CallNodecall) final {
     auto op_node = call->op.as<OpNode>();
     if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
       return ExprMutator::VisitExpr_(call);
     } else if (call->op == compiler_begin_op) {
-      // The annotation node is inserted on edge so it must have only one argument.
+      // The annotation node is inserted on edge so it must have only one
+      // argument.
       CHECK_EQ(call->args.size(), 1U);
 
       // Traverse the rest graph.
@@ -153,20 +155,21 @@ class Partitioner : public ExprMutator {
       // The type of the created variable is the same as the compiler_begin
       // node.
       std::string target = call->attrs.as<CompilerAttrs>()->compiler;
-      std::string varname = target + "_" + std::to_string(sg->GetID())
-                            + "_i" + std::to_string(index);
+      std::string varname =
+          target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
       auto var = Var(varname, GetRef<Call>(call)->checked_type_);
 
       auto cand = std::make_pair(var, input_expr);
-      if (std::find(region_args[sg].begin(),
-                    region_args[sg].end(), cand) == region_args[sg].end()) {
+      if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
+          region_args[sg].end()) {
         region_args[sg].push_back(cand);
-     }
+      }
 
       return std::move(var);
     } else {
       CHECK_EQ(call->op, compiler_end_op);
-      // The annotation node is inserted on edge so it must have only one argument.
+      // The annotation node is inserted on edge so it must have only one
+      // argument.
       CHECK_EQ(call->args.size(), 1U);
 
       AnnotatedRegion region = GetRegion(GetRef<Call>(call));
@@ -185,11 +188,11 @@ class Partitioner : public ExprMutator {
       // (each annotated regions) --> created function
 
       if (region_function_calls.find(region) != region_function_calls.end()) {
-      // This section is executed only if there are multiple outputs in the region
-      // Thus, the function is always created and at the end there would be a tuple node
-      // Therefore, we insert a tuple get item node.
+        // This section is executed only if there are multiple outputs in the
+        // region Thus, the function is always created and at the end there
+        // would be a tuple node Therefore, we insert a tuple get item node.
 
-      // Use the already created tuple node
+        // Use the already created tuple node
         auto sg_call = region_function_calls[region];
         int index = GetRetIdx(region, GetRef<Call>(call));
         CHECK_NE(index, -1);
@@ -226,8 +229,8 @@ class Partitioner : public ExprMutator {
         Function global_region_func;
         if (region->GetOutputs().size() == 1) {
           // If there are only a single output; no need to add a tuple
-          global_region_func = Function(params, fields[0],
-                                  call->args[0]->checked_type_, {}, DictAttrs());
+          global_region_func =
+              Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
         } else {
           auto tuple = Tuple(fields);
           global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
@@ -238,12 +241,12 @@ class Partitioner : public ExprMutator {
 
         global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol,
                                       tir::StringImmNode::make(name));
-        global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive,
-                            tvm::Integer(1));
+        global_region_func =
+            WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
         global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
                                       tvm::tir::StringImmNode::make(target));
-        global_region_func = WithAttr(std::move(global_region_func), attr::kInline,
-                            tvm::Integer(1));
+        global_region_func =
+            WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
 
         // Constant propagation
         if (!params_bind.empty()) {
@@ -252,11 +255,12 @@ class Partitioner : public ExprMutator {
 
         std::string fname = name;
         CHECK(!module_->ContainGlobalVar(fname))
-                << "Global function " << fname << " already exists";
+            << "Global function " << fname << " already exists";
         // Create a global function and add it to the IRModule for the region.
         // This way we lift the functions that should be handled by external
-        // codegen to the module scope and rely on the pass manager to prevent relay
-        // function level passes (i.e. simplify inference and fusion) optimizing it.
+        // codegen to the module scope and rely on the pass manager to prevent
+        // relay function level passes (i.e. simplify inference and fusion)
+        // optimizing it.
         GlobalVar glob_func(fname);
         module_->Add(glob_func, global_region_func);
 
@@ -266,7 +270,8 @@ class Partitioner : public ExprMutator {
         region_function_calls[region] = ret;
 
         if (region->GetOutputs().size() == 1) {
-          // If there is only a single output; no need to add a tuplegetitem node
+          // If there is only a single output; no need to add a tuplegetitem
+          // node
           return std::move(ret);
         } else {
           // Add a tuplegetitem node to select this output out of many
@@ -278,7 +283,7 @@ class Partitioner : public ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const TupleNode *op) final {
+  Expr VisitExpr_(const TupleNodeop) final {
     auto region = GetRegion(GetRef<Tuple>(op));
     if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
@@ -291,7 +296,7 @@ class Partitioner : public ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const TupleGetItemNode *g) final {
+  Expr VisitExpr_(const TupleGetItemNodeg) final {
     auto region = GetRegion(GetRef<TupleGetItem>(g));
     if (!region.defined()) {
       return ExprMutator::VisitExpr_(g);
@@ -301,7 +306,7 @@ class Partitioner : public ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const FunctionNode *op) final {
+  Expr VisitExpr_(const FunctionNodeop) final {
     auto region = GetRegion(GetRef<Function>(op));
     if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
@@ -316,7 +321,7 @@ class Partitioner : public ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const LetNode *op) final {
+  Expr VisitExpr_(const LetNodeop) final {
     auto region = GetRegion(GetRef<Let>(op));
     if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
@@ -328,7 +333,7 @@ class Partitioner : public ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const IfNode *op) final {
+  Expr VisitExpr_(const IfNodeop) final {
     auto region = GetRegion(GetRef<If>(op));
     if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
@@ -340,7 +345,7 @@ class Partitioner : public ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const RefCreateNode *op) final {
+  Expr VisitExpr_(const RefCreateNodeop) final {
     auto region = GetRegion(GetRef<RefCreate>(op));
     if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
@@ -350,7 +355,7 @@ class Partitioner : public ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const RefReadNode *op) final {
+  Expr VisitExpr_(const RefReadNodeop) final {
     auto region = GetRegion(GetRef<RefRead>(op));
     if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
@@ -360,7 +365,7 @@ class Partitioner : public ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const RefWriteNode *op) final {
+  Expr VisitExpr_(const RefWriteNodeop) final {
     auto region = GetRegion(GetRef<RefWrite>(op));
     if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
@@ -374,12 +379,9 @@ class Partitioner : public ExprMutator {
   IRModule Partition() {
     auto glob_funcs = module_->functions;
     for (const auto& pair : glob_funcs) {
-      if (auto *fn = pair.second.as<FunctionNode>()) {
+      if (autofn = pair.second.as<FunctionNode>()) {
         auto func = GetRef<Function>(fn);
-        func = Function(func->params,
-                        VisitExpr(func->body),
-                        func->ret_type,
-                        func->type_params,
+        func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
                         func->attrs);
         module_->Update(pair.first, func);
       }
@@ -428,7 +430,7 @@ class Partitioner : public ExprMutator {
     int idx = 0;
     for (auto arg_ : sg->GetInputs()) {
       if (arg == arg_) {
-       return idx;
+        return idx;
       }
       idx++;
     }
@@ -452,41 +454,40 @@ class Partitioner : public ExprMutator {
 
   /*!
    * \brief This map maintains the already created function calls.
-   * This is required in the multi-output scenario, to link rest of the outputs to call
+   * This is required in the multi-output scenario, to link rest of the outputs
+   * to call
    */
   std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
 
   /*!
-   * \brief This map maintains arguments (of region) visits through visitor patterns.
-   * Those arguement var and expression will be used to when creating the function.
+   * \brief This map maintains arguments (of region) visits through visitor
+   * patterns. Those arguement var and expression will be used to when creating
+   * the function.
    */
-  std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>,
-                                         ObjectHash, ObjectEqual> region_args;
+  std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
+      region_args;
 
   /*!
    * \brief Each region set is associated with a function in the module.
-   * This map maintains the mapping between regionsets and the function it belongs to
+   * This map maintains the mapping between regionsets and the function it
+   * belongs to
    */
   std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
   IRModule module_;
 };
 
-
 }  // namespace partitioning
 
 namespace transform {
 
 Pass PartitionGraph() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
-                                    [=](IRModule m, PassContext pc) {
-     return partitioning::Partitioner(m).Partition();
-  };
+      [=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); };
   auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
   return Sequential({partitioned, InferType()});
 }
 
-TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph")
-.set_body_typed(transform::PartitionGraph);
+TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed(transform::PartitionGraph);
 
 }  // namespace transform