Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ReplaceNonConstFCWithBatchMatMulPass.test.cpp
index 93024f3..194893f 100644 (file)
@@ -16,6 +16,8 @@
 
 #include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
 
+#include "helpers/CreateCircleConst.h"
+
 #include <luci/test/TestIOGraph.h>
 #include <luci/IR/CircleNodes.h>
 
@@ -26,52 +28,6 @@ namespace
 
 using namespace luci::test;
 
-// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
-template <typename T>
-luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
-                                     const std::vector<uint32_t> &shape,
-                                     const std::vector<T> &values)
-{
-  auto node = g->nodes()->create<luci::CircleConst>();
-  node->dtype(dtype);
-  node->rank(shape.size());
-
-  uint32_t size = 1;
-  for (uint32_t i = 0; i < shape.size(); ++i)
-  {
-    node->dim(i) = shape.at(i);
-    size *= shape.at(i);
-  }
-  node->shape_status(luci::ShapeStatus::VALID);
-
-#define INIT_VALUES(DT)                          \
-  {                                              \
-    node->size<DT>(size);                        \
-    for (uint32_t i = 0; i < values.size(); ++i) \
-      node->at<DT>(i) = values[i];               \
-  }
-
-  switch (dtype)
-  {
-    case loco::DataType::U8:
-      INIT_VALUES(loco::DataType::U8);
-      break;
-    case loco::DataType::S16:
-      INIT_VALUES(loco::DataType::S16);
-      break;
-    case loco::DataType::S32:
-      INIT_VALUES(loco::DataType::S32);
-      break;
-    case loco::DataType::FLOAT32:
-      INIT_VALUES(loco::DataType::FLOAT32)
-      break;
-    default:
-      INTERNAL_EXN("create_const_node called with unsupported type");
-      break;
-  }
-  return node;
-}
-
 /**
  *  Simple graph for test
  *
@@ -104,7 +60,7 @@ public:
     _tr_y = g->nodes()->create<luci::CircleTranspose>();
     _tr_y->a(_y);
     std::vector<int32_t> tr_val = {1, 0};
-    _tr_y->perm(create_const_node(g, loco::DataType::S32, {2}, tr_val));
+    _tr_y->perm(luci::create_const_node(g, loco::DataType::S32, {2}, tr_val));
 
     _fc = g->nodes()->create<luci::CircleFullyConnected>();
     _fc->input(_x);
@@ -114,7 +70,7 @@ public:
     _fc->shape(r_shape);
     auto l = _fc->dim(_fc->rank() - 1).value();
     std::vector<float> bias_val(l, bv);
-    _fc->bias(create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val));
+    _fc->bias(luci::create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val));
     _fc->name("fc");
   }