Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseAddWithFullyConnectedPass.test.cpp
index 3007965..b132c6b 100644 (file)
@@ -16,6 +16,8 @@
 
 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
 
+#include "helpers/CreateCircleConst.h"
+
 #include <luci/IR/CircleNodes.h>
 
 #include <luci/test/TestIOGraph.h>
@@ -27,52 +29,6 @@ namespace
 
 using namespace luci::test;
 
-// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
-template <typename T>
-luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
-                                     const std::vector<uint32_t> &shape,
-                                     const std::vector<T> &values)
-{
-  auto node = g->nodes()->create<luci::CircleConst>();
-  node->dtype(dtype);
-  node->rank(shape.size());
-
-  uint32_t size = 1;
-  for (uint32_t i = 0; i < shape.size(); ++i)
-  {
-    node->dim(i) = shape.at(i);
-    size *= shape.at(i);
-  }
-  node->shape_status(luci::ShapeStatus::VALID);
-
-#define INIT_VALUES(DT)                          \
-  {                                              \
-    node->size<DT>(size);                        \
-    for (uint32_t i = 0; i < values.size(); ++i) \
-      node->at<DT>(i) = values[i];               \
-  }
-
-  switch (dtype)
-  {
-    case loco::DataType::U8:
-      INIT_VALUES(loco::DataType::U8);
-      break;
-    case loco::DataType::S16:
-      INIT_VALUES(loco::DataType::S16);
-      break;
-    case loco::DataType::S32:
-      INIT_VALUES(loco::DataType::S32);
-      break;
-    case loco::DataType::FLOAT32:
-      INIT_VALUES(loco::DataType::FLOAT32)
-      break;
-    default:
-      INTERNAL_EXN("create_const_node called with unsupported type");
-      break;
-  }
-  return node;
-}
-
 /**
  *  Simple graph for test
  *
@@ -95,10 +51,10 @@ public:
   void init(loco::Graph *g)
   {
     std::vector<float> weights_val(16 * 4);
-    _fc_f = create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val);
+    _fc_f = luci::create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val);
 
     std::vector<float> bias_val(16);
-    _fc_b = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val);
+    _fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val);
 
     _fc = g->nodes()->create<luci::CircleFullyConnected>();
     _fc->weights(_fc_f);
@@ -111,7 +67,7 @@ public:
     std::vector<float> addition_val;
     for (uint32_t i = 0; i < 16; i++)
       addition_val.push_back(static_cast<float>(i));
-    _add_c = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val);
+    _add_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val);
 
     _add = g->nodes()->create<luci::CircleAdd>();
     _add->x(_fc);