namespace
{
+struct PlaneShape
+{
+ loco::Dimension height;
+ loco::Dimension width;
+};
+
+PlaneShape make_plane_shape(const loco::FeatureShape &feature_shape)
+{
+ PlaneShape plane_shape;
+
+ plane_shape.height = feature_shape.height();
+ plane_shape.width = feature_shape.width();
+
+ return plane_shape;
+}
+
+class FeatureShapeUpdater final
+{
+public:
+ FeatureShapeUpdater(loco::FeatureShape *ptr) : _feature_shape_ptr{ptr}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void with(const PlaneShape &plane_shape) const
+ {
+ _feature_shape_ptr->height() = plane_shape.height;
+ _feature_shape_ptr->width() = plane_shape.width;
+ }
+
+private:
+ loco::FeatureShape *_feature_shape_ptr;
+};
+
+/**
+ * HOW TO USE
+ *
+ * loco::FeatureShape feature_shape = ...;
+ *
+ * update(feature_shape).with(...)
+ */
+FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
+{
+ return FeatureShapeUpdater{&feature_shape};
+}
+
+class PlaneInference final
+{
+public:
+ PlaneShape operator()(const PlaneShape &in) const
+ {
+ assert(_pad != nullptr);
+ assert(_window != nullptr);
+ assert(_stride != nullptr);
+
+ uint32_t const raw_input_height = in.height.value();
+ uint32_t const raw_input_width = in.width.value();
+
+ uint32_t const raw_window_height = _window->vertical();
+ uint32_t const raw_window_width = _window->horizontal();
+
+ uint32_t const vertical_padding = _pad->top() + _pad->bottom();
+ uint32_t const horizontal_padding = _pad->left() + _pad->right();
+
+ uint32_t const effective_input_height = raw_input_height + vertical_padding;
+ uint32_t const effective_input_width = raw_input_width + horizontal_padding;
+
+ // NOTE To support "dilation" later
+ uint32_t const effective_window_height = raw_window_height;
+ uint32_t const effective_window_width = raw_window_width;
+
+ uint32_t const vertical_stride = _stride->vertical();
+ uint32_t const horizontal_stride = _stride->horizontal();
+
+ assert((effective_input_height - effective_window_height) % vertical_stride == 0);
+ assert((effective_input_width - effective_window_width) % horizontal_stride == 0);
+
+ PlaneShape res;
+
+ res.height = (effective_input_height - effective_window_height) / vertical_stride + 1;
+ res.width = (effective_input_width - effective_window_width) / horizontal_stride + 1;
+
+ return res;
+ }
+
+public:
+ void pad(const loco::Pad<2> *value) { _pad = value; }
+ void window(const loco::Window<2> *value) { _window = value; }
+ void stride(const loco::Stride<2> *value) { _stride = value; }
+
+public:
+ const loco::Pad<2> *_pad = nullptr;
+ const loco::Window<2> *_window = nullptr;
+ const loco::Stride<2> *_stride = nullptr;
+};
+
loco::NodeShape eltwise_binary_node_shape(const loco::Node *node)
{
// This helper works only for binary node.
class ForwardShapeInferenceAlgorithm final : public loco::CanonicalNodeVisitor<loco::NodeShape>
{
public:
- // TODO Support AvgPool2D
+ // CASE: AvgPool2D
+ loco::NodeShape visit(const loco::AvgPool2D *node) final
+ {
+ PlaneInference infer_plane_shape;
+
+ infer_plane_shape.pad(node->pad());
+ infer_plane_shape.window(node->window());
+ infer_plane_shape.stride(node->stride());
+
+ auto input_feature_shape = loco::shape_get(node->ifm()).as<loco::FeatureShape>();
+ auto input_plane_shape = make_plane_shape(input_feature_shape);
+ auto output_plane_shape = infer_plane_shape(input_plane_shape);
+ auto output_feature_shape = input_feature_shape; // AvgPool2D does not change count/depth
+
+ // Update the height/width of output_feature_shape with that of output_plane_shape
+ update(output_feature_shape).with(output_plane_shape);
+
+ return loco::NodeShape{output_feature_shape};
+ }
+
// TODO Support BiasEncode
// TODO Support ConstGen
// TODO Support Conv2D
ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(2), 3);
ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(3), 4);
}
+
+TEST(CanonicalShapeInferenceRuleTest, avgpool2d)
+{
+ using namespace loco;
+
+ // Create a sample network
+ GraphTestcase<GraphCode::AvgPool2D> testcase;
+
+ auto perm = make_NHWC_perm<Domain::Feature>();
+
+ testcase.pull_node->shape({1, 8, 4, 3});
+
+ testcase.encode_node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm));
+
+ testcase.avgpool2d_node->window()->vertical(2);
+ testcase.avgpool2d_node->window()->horizontal(2);
+
+ testcase.avgpool2d_node->stride()->vertical(2);
+ testcase.avgpool2d_node->stride()->horizontal(2);
+
+ testcase.decode_node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm));
+
+ // Run Inference
+ loco::CanonicalShapeInferenceRule rule;
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Verify!
+ //
+ // NOTE AvgPool2D testcase assumes NHWC layout
+ ASSERT_TRUE(loco::shape_known(testcase.avgpool2d_node));
+ ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).domain(), loco::Domain::Feature);
+ ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().count(), 1);
+ ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().depth(), 3);
+ ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().height(), 4);
+ ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().width(), 2);
+}
Identity,
Relu,
FeatureCodec,
+ AvgPool2D,
};
template <GraphCode Code> class GraphTestcase;
std::unique_ptr<loco::Graph> _graph;
};
+template <> class GraphTestcase<GraphCode::AvgPool2D> final
+{
+public:
+ GraphTestcase()
+ {
+ using namespace loco;
+
+ // Create a sample network
+ _graph = make_graph();
+
+ // Create Graph Input/Output
+ auto graph_input = _graph->inputs()->create();
+ auto graph_output = _graph->outputs()->create();
+
+ graph_input->name("input");
+ graph_output->name("output");
+
+ // Create and connect nodes
+ pull_node = _graph->nodes()->create<Pull>();
+ pull_node->index(0);
+
+ encode_node = _graph->nodes()->create<FeatureEncode>();
+ encode_node->input(pull_node);
+
+ avgpool2d_node = _graph->nodes()->create<AvgPool2D>();
+ avgpool2d_node->ifm(encode_node);
+
+ decode_node = _graph->nodes()->create<FeatureDecode>();
+ decode_node->input(avgpool2d_node);
+
+ push_node = _graph->nodes()->create<loco::Push>();
+ push_node->index(0);
+ push_node->from(decode_node);
+
+ // Create a link between input/output and corresponding nodes
+ loco::link(graph_input, pull_node);
+ loco::link(graph_output, push_node);
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::FeatureEncode *encode_node = nullptr;
+ loco::AvgPool2D *avgpool2d_node = nullptr;
+ loco::FeatureDecode *decode_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+namespace
+{
+
+template <loco::Domain D> loco::Permutation<D> make_NHWC_perm(void);
+
+template <> loco::Permutation<loco::Domain::Feature> make_NHWC_perm(void)
+{
+ loco::Permutation<loco::Domain::Feature> perm;
+
+ perm[loco::FeatureAxis::Count] = 0;
+ perm[loco::FeatureAxis::Height] = 1;
+ perm[loco::FeatureAxis::Width] = 2;
+ perm[loco::FeatureAxis::Depth] = 3;
+
+ return perm;
+}
+
+} // namespace
#endif // __GRAPH_TESTCASE_H__