// https://www.corvil.com/kb/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-tensorflow
TEST(TFLShapeInferenceRuleTest, avgpool2d_valid)
{
- exo::test::PullPushGraph<locoex::TFLAveragePool2D> test_graph;
- auto pull = test_graph.pull;
+ exo::test::TestGraph graph;
+ auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull);
+ graph.complete();
+
+ auto pull = graph.pull;
{
pull->shape({1, 4, 3, 1});
}
- auto tfl_node = test_graph.middle_node;
+ // setting TFLAveragePool2D
{
tfl_node->filter()->h(2);
tfl_node->filter()->w(2);
rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
.bind(locoex::TFLDialect::get(), &tfl_rule);
- loco::apply(&rules).to(test_graph.g.get());
+ loco::apply(&rules).to(graph.g.get());
// Verify
{
TEST(TFLShapeInferenceRuleTest, avgpool2d_same)
{
- exo::test::PullPushGraph<locoex::TFLAveragePool2D> test_graph;
- auto pull = test_graph.pull;
+ exo::test::TestGraph graph;
+ auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull);
+ graph.complete();
+
+ auto pull = graph.pull;
{
pull->shape({1, 4, 3, 1});
}
- auto tfl_node = test_graph.middle_node;
+ // setting TFLAveragePool2D
{
tfl_node->filter()->h(2);
tfl_node->filter()->w(2);
rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
.bind(locoex::TFLDialect::get(), &tfl_rule);
- loco::apply(&rules).to(test_graph.g.get());
+ loco::apply(&rules).to(graph.g.get());
// Verify
{
namespace test
{
-// THIS WILL BE DEPRECATED. USE TestGraph instead.
-// graph to build [Pull - some node of type T - Push]
-template <typename T> struct PullPushGraph
-{
-public:
- std::unique_ptr<loco::Graph> g;
- loco::Pull *pull;
- loco::Push *push;
- T *middle_node;
-
- PullPushGraph()
- {
- // g = Pull - T - Push
- g = loco::make_graph();
-
- pull = g->nodes()->create<loco::Pull>();
-
- middle_node = g->nodes()->create<T>();
- {
- setInput();
- }
-
- push = g->nodes()->create<loco::Push>();
- {
- push->from(middle_node);
- }
-
- auto input = g->inputs()->create();
- {
- input->name("input");
- loco::link(input, pull);
- }
- auto output = g->outputs()->create();
- {
- output->name("output");
- loco::link(output, push);
- }
- }
-
-private:
- void setInput(); // set the input of T
-};
-
-// setInput of TFL nodes
-template <> void PullPushGraph<locoex::TFLAveragePool2D>::setInput() { middle_node->value(pull); }
-
class TestGraph
{
public: