#include "luci/Pass/FuseAddWithFullyConnectedPass.h"
+#include "helpers/CreateCircleConst.h"
+
#include <luci/IR/CircleNodes.h>
#include <luci/test/TestIOGraph.h>
using namespace luci::test;
-// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
-template <typename T>
-luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
- const std::vector<uint32_t> &shape,
- const std::vector<T> &values)
-{
- auto node = g->nodes()->create<luci::CircleConst>();
- node->dtype(dtype);
- node->rank(shape.size());
-
- uint32_t size = 1;
- for (uint32_t i = 0; i < shape.size(); ++i)
- {
- node->dim(i) = shape.at(i);
- size *= shape.at(i);
- }
- node->shape_status(luci::ShapeStatus::VALID);
-
-#define INIT_VALUES(DT) \
- { \
- node->size<DT>(size); \
- for (uint32_t i = 0; i < values.size(); ++i) \
- node->at<DT>(i) = values[i]; \
- }
-
- switch (dtype)
- {
- case loco::DataType::U8:
- INIT_VALUES(loco::DataType::U8);
- break;
- case loco::DataType::S16:
- INIT_VALUES(loco::DataType::S16);
- break;
- case loco::DataType::S32:
- INIT_VALUES(loco::DataType::S32);
- break;
- case loco::DataType::FLOAT32:
- INIT_VALUES(loco::DataType::FLOAT32)
- break;
- default:
- INTERNAL_EXN("create_const_node called with unsupported type");
- break;
- }
- return node;
-}
-
/**
* Simple graph for test
*
void init(loco::Graph *g)
{
std::vector<float> weights_val(16 * 4);
- _fc_f = create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val);
+ _fc_f = luci::create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val);
std::vector<float> bias_val(16);
- _fc_b = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val);
+ _fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val);
_fc = g->nodes()->create<luci::CircleFullyConnected>();
_fc->weights(_fc_f);
std::vector<float> addition_val;
for (uint32_t i = 0; i < 16; i++)
addition_val.push_back(static_cast<float>(i));
- _add_c = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val);
+ _add_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val);
_add = g->nodes()->create<luci::CircleAdd>();
_add->x(_fc);