[luci] Extract export_node to reduce class LOC 2 (#3960)
authorSaeHie Park <saehie.park@gmail.com>
Tue, 25 Aug 2020 02:16:11 +0000 (11:16 +0900)
committerGitHub <noreply@github.com>
Tue, 25 Aug 2020 02:16:11 +0000 (11:16 +0900)
This will revise OperationExporter with extracting class methods to functions to reduce class LOC

ONE-DCO-1.0-Signed-off-by: SaeHie Park <saehie.park@gmail.com>

compiler/luci/export/src/CircleOperationExporter.cpp

index c6c48d9..fef08ac 100644 (file)
@@ -354,6 +354,258 @@ void export_node(ExportContext &ctx, luci::CircleReverseV2 *node)
   ctx.gd._operators.push_back(op_offset);
 }
 
+void export_node(ExportContext &ctx, luci::CircleSplit *node)
+{
+  auto split_outs = loco::succs(node);
+  assert(int32_t(split_outs.size()) == node->num_split());
+
+  uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT, node->op_version());
+  // NOTE BuiltinOperator_SPLIT input is placed at second position
+  std::vector<int32_t> inputs_vec{get_tensor_index(node->split_dim()),
+                                  get_tensor_index(node->input())};
+  std::vector<int32_t> outputs_vec;
+
+  for (int32_t index = 0; index < node->num_split(); index++)
+  {
+    // store in order of index
+    bool found = false;
+    for (auto out : split_outs)
+    {
+      auto split_out = loco::must_cast<luci::CircleSplitOut *>(out);
+      if (split_out->index() == index)
+      {
+        outputs_vec.push_back(get_tensor_index(split_out));
+        found = true;
+        break;
+      }
+    }
+    if (!found)
+    {
+      INTERNAL_EXN("Invalid Split output");
+    }
+  }
+
+  auto inputs = ctx.builder.CreateVector(inputs_vec);
+  auto outputs = ctx.builder.CreateVector(outputs_vec);
+  auto options = CreateSplitOptions(ctx.builder, node->num_split());
+  auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
+                                  circle::BuiltinOptions_SplitOptions, options.Union());
+  ctx.gd._operators.push_back(op_offset);
+}
+
+void export_node(ExportContext &ctx, luci::CircleSplitV *node)
+{
+  auto split_outs = loco::succs(node);
+  assert(int32_t(split_outs.size()) == node->num_split());
+
+  uint32_t op_idx =
+      ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT_V, node->op_version());
+  std::vector<int32_t> inputs_vec{get_tensor_index(node->input()),
+                                  get_tensor_index(node->size_splits()),
+                                  get_tensor_index(node->split_dim())};
+  std::vector<int32_t> outputs_vec;
+
+  for (int32_t index = 0; index < node->num_split(); index++)
+  {
+    // store in order of index
+    bool found = false;
+    for (auto out : split_outs)
+    {
+      auto split_out = loco::must_cast<luci::CircleSplitVOut *>(out);
+      if (split_out->index() == index)
+      {
+        outputs_vec.push_back(get_tensor_index(split_out));
+        found = true;
+        break;
+      }
+    }
+    if (!found)
+    {
+      INTERNAL_EXN("Invalid SplitV output");
+    }
+  }
+
+  auto inputs = ctx.builder.CreateVector(inputs_vec);
+  auto outputs = ctx.builder.CreateVector(outputs_vec);
+  auto options = CreateSplitVOptions(ctx.builder, node->num_split());
+  auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
+                                  circle::BuiltinOptions_SplitVOptions, options.Union());
+  ctx.gd._operators.push_back(op_offset);
+}
+
+void export_node(ExportContext &ctx, luci::CircleTopKV2 *node)
+{
+  auto topkv2_outs = loco::succs(node);
+  int outs_count = int32_t(topkv2_outs.size());
+  assert(outs_count == 2);
+
+  uint32_t op_idx =
+      ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_TOPK_V2, node->op_version());
+  std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->k())};
+  std::vector<int32_t> outputs_vec;
+
+  for (int32_t index = 0; index < outs_count; index++)
+  {
+    // store in order of index
+    bool found = false;
+    for (auto out : topkv2_outs)
+    {
+      auto topkv2_out = loco::must_cast<luci::CircleTopKV2Out *>(out);
+      if (topkv2_out->index() == index)
+      {
+        outputs_vec.push_back(get_tensor_index(topkv2_out));
+        found = true;
+        break;
+      }
+    }
+    if (!found)
+    {
+      INTERNAL_EXN("Invalid TopKV2 output");
+    }
+  }
+
+  auto inputs = ctx.builder.CreateVector(inputs_vec);
+  auto outputs = ctx.builder.CreateVector(outputs_vec);
+  auto options = CreateTopKV2Options(ctx.builder);
+  auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
+                                  circle::BuiltinOptions_TopKV2Options, options.Union());
+  ctx.gd._operators.push_back(op_offset);
+}
+
+void export_node(ExportContext &ctx, luci::CircleUnique *node)
+{
+  auto unique_outs = loco::succs(node);
+  assert(int32_t(unique_outs.size()) == 2);
+  uint32_t op_idx =
+      ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNIQUE, node->op_version());
+
+  std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+  std::vector<int32_t> outputs_vec;
+
+  for (int32_t index = 0; index < 2; index++)
+  {
+    // store in order of index
+    bool found = false;
+    for (auto out : unique_outs)
+    {
+      auto unique_out = loco::must_cast<luci::CircleUniqueOut *>(out);
+      if (unique_out->index() == index)
+      {
+        outputs_vec.push_back(get_tensor_index(unique_out));
+        found = true;
+        break;
+      }
+    }
+    if (!found)
+    {
+      INTERNAL_EXN("Invalid Unique output");
+    }
+  }
+
+  auto inputs = ctx.builder.CreateVector(inputs_vec);
+  auto outputs = ctx.builder.CreateVector(outputs_vec);
+  auto options = CreateUniqueOptions(ctx.builder, to_circle_tensortype(node->idx_out_type()));
+  auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
+                                  circle::BuiltinOptions_UniqueOptions, options.Union());
+  ctx.gd._operators.push_back(op_offset);
+}
+
+void export_node(ExportContext &ctx, luci::CircleUnpack *node)
+{
+  LOGGER(l);
+  auto settings = luci::UserSettings::settings();
+
+  auto unpack_outs = loco::succs(node);
+  // NOTE real models may not use all of the outputs
+  if (static_cast<int32_t>(unpack_outs.size()) != node->num())
+  {
+    if (settings->get(luci::UserSettings::Key::DisableValidation))
+    {
+      WARN(l) << "Warning: export Unpack(" << node->name() << ") 'num' not same as outputs";
+    }
+    else
+      assert(false);
+  }
+
+  uint32_t op_idx =
+      ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNPACK, node->op_version());
+  std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
+  std::vector<int32_t> outputs_vec;
+
+  for (int32_t index = 0; index < node->num(); index++)
+  {
+    // store in order of index
+    bool found = false;
+    for (auto out : unpack_outs)
+    {
+      auto unpack_out = loco::must_cast<luci::CircleUnpackOut *>(out);
+      if (unpack_out->index() == index)
+      {
+        outputs_vec.push_back(get_tensor_index(unpack_out));
+        found = true;
+        break;
+      }
+    }
+    // NOTE real models may not use all of the outputs
+    if (!found)
+    {
+      if (settings->get(luci::UserSettings::Key::DisableValidation))
+      {
+        WARN(l) << "Warning: export Unpack(" << node->name() << ") output " << index << " not used";
+      }
+      else
+        assert(false);
+    }
+  }
+
+  auto inputs = ctx.builder.CreateVector(inputs_vec);
+  auto outputs = ctx.builder.CreateVector(outputs_vec);
+  auto options = CreateUnpackOptions(ctx.builder, node->num(), node->axis());
+  auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
+                                  circle::BuiltinOptions_UnpackOptions, options.Union());
+  ctx.gd._operators.push_back(op_offset);
+}
+
+void export_node(ExportContext &ctx, luci::CircleWhile *node)
+{
+  auto while_outs = loco::succs(node);
+  assert(while_outs.size() == node->output_count());
+
+  uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_WHILE, node->op_version());
+  std::vector<int32_t> inputs_vec;
+  std::vector<int32_t> outputs_vec;
+
+  for (uint32_t idx = 0; idx < node->input_count(); ++idx)
+    inputs_vec.push_back(get_tensor_index(node->input(idx)));
+
+  for (uint32_t idx = 0; idx < node->output_count(); ++idx)
+  {
+    // store in order of index
+    bool found = false;
+    for (auto out : while_outs)
+    {
+      auto while_out = loco::must_cast<luci::CircleWhileOut *>(out);
+      if (while_out->index() == static_cast<int32_t>(idx))
+      {
+        outputs_vec.push_back(get_tensor_index(while_out));
+        found = true;
+        break;
+      }
+    }
+    if (!found)
+    {
+      INTERNAL_EXN("Invalid CircleWhile output");
+    }
+  }
+
+  auto inputs = ctx.builder.CreateVector(inputs_vec);
+  auto outputs = ctx.builder.CreateVector(outputs_vec);
+  auto options = CreateWhileOptions(ctx.builder, node->cond_branch(), node->body_branch());
+  auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
+                                  circle::BuiltinOptions_WhileOptions, options.Union());
+  ctx.gd._operators.push_back(op_offset);
+}
+
 class OperationExporter final : public luci::CircleNodeMutableVisitor<void>,
                                 public loco::CanonicalNodeMutableVisitor<void>
 {
@@ -1078,80 +1330,14 @@ void OperationExporter::visit(luci::CircleSparseToDense *node)
 
 void OperationExporter::visit(luci::CircleSplit *node)
 {
-  auto split_outs = loco::succs(node);
-  assert(int32_t(split_outs.size()) == node->num_split());
-
-  uint32_t op_idx = md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT, node->op_version());
-  // NOTE BuiltinOperator_SPLIT input is placed at second position
-  std::vector<int32_t> inputs_vec{get_tensor_index(node->split_dim()),
-                                  get_tensor_index(node->input())};
-  std::vector<int32_t> outputs_vec;
-
-  for (int32_t index = 0; index < node->num_split(); index++)
-  {
-    // store in order of index
-    bool found = false;
-    for (auto out : split_outs)
-    {
-      auto split_out = loco::must_cast<luci::CircleSplitOut *>(out);
-      if (split_out->index() == index)
-      {
-        outputs_vec.push_back(get_tensor_index(split_out));
-        found = true;
-        break;
-      }
-    }
-    if (!found)
-    {
-      INTERNAL_EXN("Invalid Split output");
-    }
-  }
-
-  auto inputs = builder.CreateVector(inputs_vec);
-  auto outputs = builder.CreateVector(outputs_vec);
-  auto options = CreateSplitOptions(builder, node->num_split());
-  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
-                                  circle::BuiltinOptions_SplitOptions, options.Union());
-  gd._operators.push_back(op_offset);
+  ExportContext ctx{builder, md, gd};
+  export_node(ctx, node);
 }
 
 void OperationExporter::visit(luci::CircleSplitV *node)
 {
-  auto split_outs = loco::succs(node);
-  assert(int32_t(split_outs.size()) == node->num_split());
-
-  uint32_t op_idx = md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT_V, node->op_version());
-  std::vector<int32_t> inputs_vec{get_tensor_index(node->input()),
-                                  get_tensor_index(node->size_splits()),
-                                  get_tensor_index(node->split_dim())};
-  std::vector<int32_t> outputs_vec;
-
-  for (int32_t index = 0; index < node->num_split(); index++)
-  {
-    // store in order of index
-    bool found = false;
-    for (auto out : split_outs)
-    {
-      auto split_out = loco::must_cast<luci::CircleSplitVOut *>(out);
-      if (split_out->index() == index)
-      {
-        outputs_vec.push_back(get_tensor_index(split_out));
-        found = true;
-        break;
-      }
-    }
-    if (!found)
-    {
-      INTERNAL_EXN("Invalid SplitV output");
-    }
-  }
-
-  auto inputs = builder.CreateVector(inputs_vec);
-  auto outputs = builder.CreateVector(outputs_vec);
-  auto options = CreateSplitVOptions(builder, node->num_split());
-  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
-                                  circle::BuiltinOptions_SplitVOptions, options.Union());
-  gd._operators.push_back(op_offset);
+  ExportContext ctx{builder, md, gd};
+  export_node(ctx, node);
 }
 
 void OperationExporter::visit(luci::CircleSqrt *node)
@@ -1215,40 +1401,8 @@ void OperationExporter::visit(luci::CircleTile *node)
 
 void OperationExporter::visit(luci::CircleTopKV2 *node)
 {
-  auto topkv2_outs = loco::succs(node);
-  int outs_count = int32_t(topkv2_outs.size());
-  assert(outs_count == 2);
-
-  uint32_t op_idx = md.registerBuiltinOpcode(circle::BuiltinOperator_TOPK_V2, node->op_version());
-  std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->k())};
-  std::vector<int32_t> outputs_vec;
-
-  for (int32_t index = 0; index < outs_count; index++)
-  {
-    // store in order of index
-    bool found = false;
-    for (auto out : topkv2_outs)
-    {
-      auto topkv2_out = loco::must_cast<luci::CircleTopKV2Out *>(out);
-      if (topkv2_out->index() == index)
-      {
-        outputs_vec.push_back(get_tensor_index(topkv2_out));
-        found = true;
-        break;
-      }
-    }
-    if (!found)
-    {
-      INTERNAL_EXN("Invalid TopKV2 output");
-    }
-  }
-
-  auto inputs = builder.CreateVector(inputs_vec);
-  auto outputs = builder.CreateVector(outputs_vec);
-  auto options = CreateTopKV2Options(builder);
-  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
-                                  circle::BuiltinOptions_TopKV2Options, options.Union());
-  gd._operators.push_back(op_offset);
+  ExportContext ctx{builder, md, gd};
+  export_node(ctx, node);
 }
 
 void OperationExporter::visit(luci::CircleTranspose *node)
@@ -1268,94 +1422,14 @@ void OperationExporter::visit(luci::CircleTransposeConv *node)
 
 void OperationExporter::visit(luci::CircleUnique *node)
 {
-  auto unique_outs = loco::succs(node);
-  assert(int32_t(unique_outs.size()) == 2);
-  uint32_t op_idx = md.registerBuiltinOpcode(circle::BuiltinOperator_UNIQUE, node->op_version());
-
-  std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
-  std::vector<int32_t> outputs_vec;
-
-  for (int32_t index = 0; index < 2; index++)
-  {
-    // store in order of index
-    bool found = false;
-    for (auto out : unique_outs)
-    {
-      auto unique_out = loco::must_cast<luci::CircleUniqueOut *>(out);
-      if (unique_out->index() == index)
-      {
-        outputs_vec.push_back(get_tensor_index(unique_out));
-        found = true;
-        break;
-      }
-    }
-    if (!found)
-    {
-      INTERNAL_EXN("Invalid Unique output");
-    }
-  }
-
-  auto inputs = builder.CreateVector(inputs_vec);
-  auto outputs = builder.CreateVector(outputs_vec);
-  auto options = CreateUniqueOptions(builder, to_circle_tensortype(node->idx_out_type()));
-  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
-                                  circle::BuiltinOptions_UniqueOptions, options.Union());
-  gd._operators.push_back(op_offset);
+  ExportContext ctx{builder, md, gd};
+  export_node(ctx, node);
 }
 
 void OperationExporter::visit(luci::CircleUnpack *node)
 {
-  LOGGER(l);
-  auto settings = luci::UserSettings::settings();
-
-  auto unpack_outs = loco::succs(node);
-  // NOTE real models may not use all of the outputs
-  if (static_cast<int32_t>(unpack_outs.size()) != node->num())
-  {
-    if (settings->get(luci::UserSettings::Key::DisableValidation))
-    {
-      WARN(l) << "Warning: export Unpack(" << node->name() << ") 'num' not same as outputs";
-    }
-    else
-      assert(false);
-  }
-
-  uint32_t op_idx = md.registerBuiltinOpcode(circle::BuiltinOperator_UNPACK, node->op_version());
-  std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
-  std::vector<int32_t> outputs_vec;
-
-  for (int32_t index = 0; index < node->num(); index++)
-  {
-    // store in order of index
-    bool found = false;
-    for (auto out : unpack_outs)
-    {
-      auto unpack_out = loco::must_cast<luci::CircleUnpackOut *>(out);
-      if (unpack_out->index() == index)
-      {
-        outputs_vec.push_back(get_tensor_index(unpack_out));
-        found = true;
-        break;
-      }
-    }
-    // NOTE real models may not use all of the outputs
-    if (!found)
-    {
-      if (settings->get(luci::UserSettings::Key::DisableValidation))
-      {
-        WARN(l) << "Warning: export Unpack(" << node->name() << ") output " << index << " not used";
-      }
-      else
-        assert(false);
-    }
-  }
-
-  auto inputs = builder.CreateVector(inputs_vec);
-  auto outputs = builder.CreateVector(outputs_vec);
-  auto options = CreateUnpackOptions(builder, node->num(), node->axis());
-  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
-                                  circle::BuiltinOptions_UnpackOptions, options.Union());
-  gd._operators.push_back(op_offset);
+  ExportContext ctx{builder, md, gd};
+  export_node(ctx, node);
 }
 
 void OperationExporter::visit(luci::CircleWhere *node)
@@ -1366,42 +1440,8 @@ void OperationExporter::visit(luci::CircleWhere *node)
 
 void OperationExporter::visit(luci::CircleWhile *node)
 {
-  auto while_outs = loco::succs(node);
-  assert(while_outs.size() == node->output_count());
-
-  uint32_t op_idx = md.registerBuiltinOpcode(circle::BuiltinOperator_WHILE, node->op_version());
-  std::vector<int32_t> inputs_vec;
-  std::vector<int32_t> outputs_vec;
-
-  for (uint32_t idx = 0; idx < node->input_count(); ++idx)
-    inputs_vec.push_back(get_tensor_index(node->input(idx)));
-
-  for (uint32_t idx = 0; idx < node->output_count(); ++idx)
-  {
-    // store in order of index
-    bool found = false;
-    for (auto out : while_outs)
-    {
-      auto while_out = loco::must_cast<luci::CircleWhileOut *>(out);
-      if (while_out->index() == static_cast<int32_t>(idx))
-      {
-        outputs_vec.push_back(get_tensor_index(while_out));
-        found = true;
-        break;
-      }
-    }
-    if (!found)
-    {
-      INTERNAL_EXN("Invalid CircleWhile output");
-    }
-  }
-
-  auto inputs = builder.CreateVector(inputs_vec);
-  auto outputs = builder.CreateVector(outputs_vec);
-  auto options = CreateWhileOptions(builder, node->cond_branch(), node->body_branch());
-  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
-                                  circle::BuiltinOptions_WhileOptions, options.Union());
-  gd._operators.push_back(op_offset);
+  ExportContext ctx{builder, md, gd};
+  export_node(ctx, node);
 }
 
 void OperationExporter::visit(luci::CircleZerosLike *node)