return loco::NodeShape{tensor_shape};
}
- // TODO Support ReLU
+ // CASE: ReLU
+ loco::NodeShape visit(const loco::ReLU *node) final { return loco::shape_get(node->input()); }
+
// TODO Support ReLU6
// TODO Support TensorBiasAdd
// TODO SUpport TensorConcat
ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3);
ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4);
}
+
+TEST(CanonicalShapeInferenceRuleTest, relu)
+{
+ // Create a sample network
+ GraphTestcase<GraphCode::Relu> testcase;
+
+ testcase.pull_node->shape({1, 2, 3, 4});
+
+ // Run Inference
+ loco::CanonicalShapeInferenceRule rule;
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Verify!
+ ASSERT_TRUE(loco::shape_known(testcase.push_node));
+ ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 4);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 1);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4);
+}
enum class GraphCode
{
Identity,
+ Relu,
};
template <GraphCode Code> class GraphTestcase;
std::unique_ptr<loco::Graph> _graph;
};
+template <> class GraphTestcase<GraphCode::Relu> final
+{
+public:
+ GraphTestcase()
+ {
+ // Create a sample network
+ _graph = loco::make_graph();
+
+ // Create Nodes
+ pull_node = _graph->nodes()->create<loco::Pull>();
+
+ relu_node = _graph->nodes()->create<loco::ReLU>();
+ relu_node->input(pull_node);
+
+ push_node = _graph->nodes()->create<loco::Push>();
+ push_node->from(relu_node);
+
+ // Create Graph Input
+ auto graph_input = _graph->inputs()->create();
+
+ graph_input->name("input");
+ graph_input->node(pull_node);
+ pull_node->index(0);
+
+ // Create Graph Output
+ auto graph_output = _graph->outputs()->create();
+
+ graph_output->name("output");
+ graph_output->node(push_node);
+ push_node->index(0);
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::ReLU *relu_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
#endif // __GRAPH_TESTCASE_H__