TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu)
{
// Create a simple network
- auto g = loco::make_graph();
-
- auto pull_node = g->nodes()->create<loco::Pull>();
- {
- pull_node->rank(2);
- pull_node->dim(0) = 3;
- pull_node->dim(1) = 4;
- }
-
- auto tfl_node = g->nodes()->create<locoex::TFLRelu>();
- tfl_node->features(pull_node);
-
- auto push_node = g->nodes()->create<loco::Push>();
- push_node->from(tfl_node);
+ exo::test::TestGraph graph;
+ auto tfl_node = graph.append<locoex::TFLRelu>(graph.pull);
+ graph.complete(tfl_node);
- auto input = g->inputs()->create();
+ // set shape
{
- input->name("input");
- loco::link(input, pull_node);
- }
- auto output = g->outputs()->create();
- {
- output->name("output");
- loco::link(output, push_node);
+ graph.pull->rank(2);
+ graph.pull->dim(0) = 3;
+ graph.pull->dim(1) = 4;
}
// pre-check
rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
.bind(locoex::TFLDialect::get(), &tfl_rule);
- loco::apply(&rules).to(g.get());
+ loco::apply(&rules).to(graph.g.get());
// Verify
{