#include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
+#include "helpers/CreateCircleConst.h"
+
#include <luci/test/TestIOGraph.h>
#include <luci/IR/CircleNodes.h>
using namespace luci::test;
-// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
-template <typename T>
-luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
- const std::vector<uint32_t> &shape,
- const std::vector<T> &values)
-{
- auto node = g->nodes()->create<luci::CircleConst>();
- node->dtype(dtype);
- node->rank(shape.size());
-
- uint32_t size = 1;
- for (uint32_t i = 0; i < shape.size(); ++i)
- {
- node->dim(i) = shape.at(i);
- size *= shape.at(i);
- }
- node->shape_status(luci::ShapeStatus::VALID);
-
-#define INIT_VALUES(DT) \
- { \
- node->size<DT>(size); \
- for (uint32_t i = 0; i < values.size(); ++i) \
- node->at<DT>(i) = values[i]; \
- }
-
- switch (dtype)
- {
- case loco::DataType::U8:
- INIT_VALUES(loco::DataType::U8);
- break;
- case loco::DataType::S16:
- INIT_VALUES(loco::DataType::S16);
- break;
- case loco::DataType::S32:
- INIT_VALUES(loco::DataType::S32);
- break;
- case loco::DataType::FLOAT32:
- INIT_VALUES(loco::DataType::FLOAT32)
- break;
- default:
- INTERNAL_EXN("create_const_node called with unsupported type");
- break;
- }
- return node;
-}
-
/**
* Simple graph for test
*
_tr_y = g->nodes()->create<luci::CircleTranspose>();
_tr_y->a(_y);
std::vector<int32_t> tr_val = {1, 0};
- _tr_y->perm(create_const_node(g, loco::DataType::S32, {2}, tr_val));
+ _tr_y->perm(luci::create_const_node(g, loco::DataType::S32, {2}, tr_val));
_fc = g->nodes()->create<luci::CircleFullyConnected>();
_fc->input(_x);
_fc->shape(r_shape);
auto l = _fc->dim(_fc->rank() - 1).value();
std::vector<float> bias_val(l, bv);
- _fc->bias(create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val));
+ _fc->bias(luci::create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val));
_fc->name("fc");
}