return window;
}
+loco::Window<2> window_of(const loco::DepthwiseFilterShape &depthwise_filter_shape)
+{
+ loco::Window<2> window;
+
+ window.vertical(depthwise_filter_shape.height().value());
+ window.horizontal(depthwise_filter_shape.width().value());
+
+ return window;
+}
+
class PlaneInference final
{
public:
return loco::NodeShape{output_feature_shape};
}
- // TODO Support DepthwiseConv2D
+ // CASE: DepthwiseConv2D
+ loco::NodeShape visit(const loco::DepthwiseConv2D *node) final
+ {
+ auto depthwise_filter_shape = loco::shape_get(node->ker()).as<loco::DepthwiseFilterShape>();
+ auto dpethwise_filter_window = window_of(depthwise_filter_shape);
+
+ PlaneInference infer_plane_shape;
+
+ infer_plane_shape.pad(node->pad());
+ infer_plane_shape.window(&dpethwise_filter_window);
+ infer_plane_shape.stride(node->stride());
+
+ auto input_feature_shape = loco::shape_get(node->ifm()).as<loco::FeatureShape>();
+ auto input_plane_shape = make_plane_shape(input_feature_shape);
+ auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+ loco::FeatureShape output_feature_shape;
+
+ // "COUNT" does not change
+ output_feature_shape.count() = input_feature_shape.count();
+ // "DEPTH" depends on [in_channels * channel_multiplier] of filters
+ output_feature_shape.depth() = loco::Dimension(depthwise_filter_shape.depth().value() *
+ depthwise_filter_shape.multiplier().value());
+ // Update the height/width of output_feature_shape with that of output_plane_shape
+ update(output_feature_shape).with(output_plane_shape);
+
+ return loco::NodeShape{output_feature_shape};
+ }
+
// CASE: DepthwiseFilterEncode
loco::NodeShape visit(const loco::DepthwiseFilterEncode *node) final
{
ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().width(), 2);
}
+TEST(CanonicalShapeInferenceRuleTest, depthwiseconv2d)
+{
+ using namespace loco;
+
+ // Create a sample network
+ GraphTestcase<GraphCode::DepthwiseConv2D> testcase;
+
+ testcase.pull_node->shape({1, 4, 4, 3});
+
+ testcase.const_node->dtype(loco::DataType::FLOAT32);
+ testcase.const_node->shape({2, 2, 3, 2});
+
+ testcase.depthwiseconv2d_node->stride()->vertical(1);
+ testcase.depthwiseconv2d_node->stride()->horizontal(1);
+
+ // Run Inference
+ loco::CanonicalShapeInferenceRule rule;
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Verify!
+ //
+ // NOTE DepthwiseConv2D testcase assumes NHWC layout
+ ASSERT_TRUE(loco::shape_known(testcase.depthwiseconv2d_node));
+ ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).domain(), loco::Domain::Feature);
+ ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().count(), 1);
+ ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().depth(), 6);
+ ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().height(), 3);
+ ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().width(), 3);
+}
+
TEST(CanonicalShapeInferenceRuleTest, maxpool2d)
{
using namespace loco;