}
};
+struct FixedReshapeLayer final
+{
+ class Return
+ {
+ public:
+ Return(loco::FixedReshape *node) : _node{node}
+ {
+ // DO NOTHING
+ }
+
+ public:
+ Return *shape(std::initializer_list<uint32_t> dims)
+ {
+ _node->shape(dims);
+ return this;
+ }
+
+ public:
+ loco::FixedReshape *node(void) { return _node; }
+
+ private:
+ loco::FixedReshape *_node = nullptr;
+ };
+
+ std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+ {
+ auto reshape_node = ctx->graph()->nodes()->create<loco::FixedReshape>();
+
+ reshape_node->input(ctx->stack()->pop());
+
+ ctx->stack()->push(reshape_node);
+
+ return stdex::make_unique<Return>(reshape_node);
+ }
+};
+
#endif // __GRAPH_BUILDER_H__
AvgPool2D,
MaxPool2D,
TensorConcat,
+ FixedReshape,
};
template <GraphCode Code> class GraphTestcase;
std::unique_ptr<loco::Graph> _graph;
};
+template <> class GraphTestcase<GraphCode::FixedReshape> final
+{
+public:
+ GraphTestcase()
+ {
+ _graph = loco::make_graph();
+
+ auto graph_builder = make_graph_builder(_graph.get());
+
+ pull_node = graph_builder->push<InputLayer>()->name("input")->node();
+ reshape_node = graph_builder->push<FixedReshapeLayer>()->node();
+ push_node = graph_builder->push<OutputLayer>()->name("output")->node();
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::FixedReshape *reshape_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
namespace
{