Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / SubstituteSplitVToSplitPass.test.cpp
index 6e30103..43f9cc1 100644 (file)
@@ -16,6 +16,8 @@
 
 #include "luci/Pass/SubstituteSplitVToSplitPass.h"
 
+#include "helpers/CreateCircleConst.h"
+
 #include <luci/test/TestIOGraph.h>
 
 #include <gtest/gtest.h>
@@ -30,51 +32,6 @@ const int C = 32;
 const int H = 8;
 const int W = 8;
 
-// Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
-template <typename T>
-luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
-                                     const std::vector<uint32_t> &shape,
-                                     const std::vector<T> &values)
-{
-  auto node = g->nodes()->create<luci::CircleConst>();
-  node->dtype(dtype);
-  node->rank(shape.size());
-
-  uint32_t size = 1;
-  for (uint32_t i = 0; i < shape.size(); ++i)
-  {
-    node->dim(i) = shape.at(i);
-    size *= shape.at(i);
-  }
-  node->shape_status(luci::ShapeStatus::VALID);
-
-#define INIT_VALUES(DT)                          \
-  {                                              \
-    node->size<DT>(size);                        \
-    for (uint32_t i = 0; i < values.size(); ++i) \
-      node->at<DT>(i) = values[i];               \
-  }
-
-  switch (dtype)
-  {
-    case loco::DataType::U8:
-      INIT_VALUES(loco::DataType::U8);
-      break;
-    case loco::DataType::S16:
-      INIT_VALUES(loco::DataType::S16);
-      break;
-    case loco::DataType::S32:
-      INIT_VALUES(loco::DataType::S32);
-      break;
-    case loco::DataType::FLOAT32:
-      INIT_VALUES(loco::DataType::FLOAT32)
-      break;
-    default:
-      INTERNAL_EXN("create_const_node called with unsupported type");
-      break;
-  }
-  return node;
-}
 /**
  *  graph having SplitV operator
  *
@@ -95,10 +52,10 @@ public:
   void init(loco::Graph *g)
   {
     const std::vector<int32_t> splits{16, 16};
-    auto size_splits = create_const_node(g, loco::DataType::S32, {2}, splits);
+    auto size_splits = luci::create_const_node(g, loco::DataType::S32, {2}, splits);
 
     const std::vector<int32_t> dim{3};
-    auto split_dim = create_const_node(g, loco::DataType::S32, {1}, dim);
+    auto split_dim = luci::create_const_node(g, loco::DataType::S32, {1}, dim);
 
     _sv = g->nodes()->create<luci::CircleSplitV>();
     _sv->size_splits(size_splits);