return FeatureShapeUpdater{&feature_shape};
}
+loco::Window<2> window_of(const loco::FilterShape &filter_shape)
+{
+ loco::Window<2> window;
+
+ window.vertical(filter_shape.height().value());
+ window.horizontal(filter_shape.height().value());
+
+ return window;
+}
+
class PlaneInference final
{
public:
return loco::NodeShape{tensor_shape};
}
- // TODO Support Conv2D
+ // CASE: Conv2D
+ loco::NodeShape visit(const loco::Conv2D *node) final
+ {
+ auto filter_shape = loco::shape_get(node->ker()).as<loco::FilterShape>();
+ auto filter_window = window_of(filter_shape);
+
+ PlaneInference infer_plane_shape;
+
+ infer_plane_shape.pad(node->pad());
+ infer_plane_shape.window(&filter_window);
+ infer_plane_shape.stride(node->stride());
+
+ auto input_feature_shape = loco::shape_get(node->ifm()).as<loco::FeatureShape>();
+ auto input_plane_shape = make_plane_shape(input_feature_shape);
+ auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+ loco::FeatureShape output_feature_shape;
+
+ // "COUNT" does not change
+ output_feature_shape.count() = input_feature_shape.count();
+ // "DEPTH" depends on # of filters
+ output_feature_shape.depth() = filter_shape.count();
+ // Update the height/width of output_feature_shape with that of output_plane_shape
+ update(output_feature_shape).with(output_plane_shape);
+
+ return loco::NodeShape{output_feature_shape};
+ }
+
// TODO Support DepthwiseConv2D
// TODO Support DepthwiseFilterEncode
return loco::NodeShape{node->encoder()->shape(input_node_shape.as<loco::TensorShape>())};
}
- // TODO Support FilterEncode
+ // CASE: FilterEncode
+ loco::NodeShape visit(const loco::FilterEncode *node) final
+ {
+ auto input_tensor_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ return loco::NodeShape{node->encoder()->shape(input_tensor_shape)};
+ }
+
// TODO Support FixedReshape
// CASE: MaxPool2D