// TODO Support FilterEncode
// TODO Support FixedReshape
- // TODO Support MaxPool2D
+
+ // CASE: MaxPool2D
+ loco::NodeShape visit(const loco::MaxPool2D *node) final
+ {
+ PlaneInference infer_plane_shape;
+
+ infer_plane_shape.pad(node->pad());
+ infer_plane_shape.window(node->window());
+ infer_plane_shape.stride(node->stride());
+
+ auto input_feature_shape = loco::shape_get(node->ifm()).as<loco::FeatureShape>();
+ auto input_plane_shape = make_plane_shape(input_feature_shape);
+ auto output_plane_shape = infer_plane_shape(input_plane_shape);
+ auto output_feature_shape = input_feature_shape; // MaxPool2D does not change count/depth
+
+ // Update the height/width of output_feature_shape with that of output_plane_shape
+ update(output_feature_shape).with(output_plane_shape);
+
+ return loco::NodeShape{output_feature_shape};
+ }
// CASE: Push
loco::NodeShape visit(const loco::Push *node) final
ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().height(), 4);
ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().width(), 2);
}
+
+TEST(CanonicalShapeInferenceRuleTest, maxpool2d)
+{
+ using namespace loco;
+
+ // Create a sample network
+ GraphTestcase<GraphCode::MaxPool2D> testcase;
+
+ auto perm = make_NHWC_perm<Domain::Feature>();
+
+ testcase.pull_node->shape({1, 8, 4, 3});
+
+ testcase.encode_node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm));
+
+ testcase.maxpool2d_node->window()->vertical(2);
+ testcase.maxpool2d_node->window()->horizontal(2);
+
+ testcase.maxpool2d_node->stride()->vertical(2);
+ testcase.maxpool2d_node->stride()->horizontal(2);
+
+ testcase.decode_node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm));
+
+ // Run Inference
+ loco::CanonicalShapeInferenceRule rule;
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Verify!
+ //
+ // NOTE MaxPool2D testcase assumes NHWC layout
+ ASSERT_TRUE(loco::shape_known(testcase.maxpool2d_node));
+ ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).domain(), loco::Domain::Feature);
+ ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().count(), 1);
+ ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().depth(), 3);
+ ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().height(), 4);
+ ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().width(), 2);
+}
Relu,
FeatureCodec,
AvgPool2D,
+ MaxPool2D,
};
template <GraphCode Code> class GraphTestcase;
std::unique_ptr<loco::Graph> _graph;
};
+template <> class GraphTestcase<GraphCode::MaxPool2D> final
+{
+public:
+ GraphTestcase()
+ {
+ using namespace loco;
+
+ // Create a sample network
+ _graph = make_graph();
+
+ // Create Graph Input/Output
+ auto graph_input = _graph->inputs()->create();
+ auto graph_output = _graph->outputs()->create();
+
+ graph_input->name("input");
+ graph_output->name("output");
+
+ // Create and connect nodes
+ pull_node = _graph->nodes()->create<Pull>();
+ pull_node->index(0);
+
+ encode_node = _graph->nodes()->create<FeatureEncode>();
+ encode_node->input(pull_node);
+
+ maxpool2d_node = _graph->nodes()->create<MaxPool2D>();
+ maxpool2d_node->ifm(encode_node);
+
+ decode_node = _graph->nodes()->create<FeatureDecode>();
+ decode_node->input(maxpool2d_node);
+
+ push_node = _graph->nodes()->create<loco::Push>();
+ push_node->index(0);
+ push_node->from(decode_node);
+
+ // Create a link between input/output and corresponding nodes
+ loco::link(graph_input, pull_node);
+ loco::link(graph_output, push_node);
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::FeatureEncode *encode_node = nullptr;
+ loco::MaxPool2D *maxpool2d_node = nullptr;
+ loco::FeatureDecode *decode_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
namespace
{