// Create a sample network
GraphTestcase<GraphCode::ConstGen> testcase;
+ testcase.const_node->dtype(loco::DataType::FLOAT32);
+ testcase.const_node->shape({1, 2});
+
// Run Inference
loco::CanonicalShapeInferenceRule rule;
}
};
+struct ConstGenLayer final
+{
+ class Return
+ {
+ public:
+ Return(loco::ConstGen *node) : _node{node}
+ {
+ // DO NOTHING
+ }
+
+ public:
+ loco::ConstGen *node(void) { return _node; }
+
+ private:
+ loco::ConstGen *_node = nullptr;
+ };
+
+ std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+ {
+ auto const_node = ctx->graph()->nodes()->create<loco::ConstGen>();
+
+ ctx->stack()->push(const_node);
+
+ return stdex::make_unique<Return>(const_node);
+ }
+};
+
#include "loco/IR/PermutingCodec.h"
struct FeatureEncodeLayer final
public:
GraphTestcase()
{
- // Create a sample network
_graph = loco::make_graph();
- // Create Nodes
- const_node = _graph->nodes()->create<loco::ConstGen>();
- const_node->dtype(loco::DataType::FLOAT32);
- const_node->shape({1, 2});
-
- push_node = _graph->nodes()->create<loco::Push>();
- push_node->from(const_node);
-
- // Create Graph Output
- auto graph_output = _graph->outputs()->create();
+ auto graph_builder = make_graph_builder(_graph.get());
- graph_output->name("output");
- graph_output->dtype(loco::DataType::FLOAT32);
- graph_output->shape({1, 2});
+ const_node = graph_builder->push<ConstGenLayer>()->node();
- loco::link(graph_output, push_node);
- push_node->index(0);
+ push_node = graph_builder->push<OutputLayer>()->name("output")->node();
}
public: