This will revise shape inference for DepthwiseConv2dNative to be done in TFShapeInferenceRule
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
*/
#include "DepthwiseConv2dNativeCanonicalizer.h"
-
-#include "Annotations/PadData.h"
-#include "Annotations/StrideData.h"
+#include "TFShapeInferenceHelper.h"
#include "Dialect/TFDialect.h"
#include "Dialect/TFNodes.h"
set_filter_enc(filter_enc);
set_feature_dec(feature_dec, data_layout);
- // Set DetphwiseConv2D attributes from TFDepthwiseConv2dNative
- auto pad_data = node->annot<moco::tf::PadData>();
- assert(pad_data != nullptr);
+ // Calculate Pad and Stride from inference
+ auto input_shape = moco::tf::node_shape(node->input());
+ auto ker_shape = moco::tf::node_shape(node->filter());
+ auto ker_tensor_shape = ker_shape.as<loco::TensorShape>();
+ auto node_stride = moco::tf::stride_of(node->strides(), node->data_layout());
+ auto node_window = moco::tf::window_of(ker_tensor_shape, "HWCM");
+
+ moco::tf::Padding2DInference infer_padding2d;
- depthwiseconv2d->pad()->top(pad_data->pad()->top());
- depthwiseconv2d->pad()->bottom(pad_data->pad()->bottom());
- depthwiseconv2d->pad()->left(pad_data->pad()->left());
- depthwiseconv2d->pad()->right(pad_data->pad()->right());
+ infer_padding2d.padding(node->padding());
+ infer_padding2d.stride(node_stride);
+ infer_padding2d.window(node_window);
- auto stride_data = node->annot<moco::tf::StrideData>();
- assert(stride_data != nullptr);
+ auto input_feature_shape = moco::tf::as_feature_shape(input_shape, node->data_layout());
+ auto input_plane_shape = moco::tf::make_plane_shape(input_feature_shape);
- depthwiseconv2d->stride()->vertical(stride_data->stride()->vertical());
- depthwiseconv2d->stride()->horizontal(stride_data->stride()->horizontal());
+ *depthwiseconv2d->pad() = infer_padding2d(input_plane_shape);
+ *depthwiseconv2d->stride() = node_stride;
// update graph
auto node_A = node->input();
return loco::NodeShape(output_tensor_shape);
}
+ loco::NodeShape visit(const moco::tf::TFDepthwiseConv2dNative *node) final
+ {
+ auto input_shape = moco::tf::node_shape(node->input()); // NHWC
+ auto ker_shape = moco::tf::node_shape(node->filter());
+ auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWCM
+ auto node_stride = moco::tf::stride_of(node->strides(), node->data_layout());
+ auto node_window = moco::tf::window_of(ker_tensor_shape, "HWCM");
+
+ moco::tf::PlaneInference infer_plane_shape;
+
+ infer_plane_shape.padding(node->padding());
+ infer_plane_shape.stride(node_stride);
+ infer_plane_shape.window(node_window);
+
+ auto input_feature_shape = moco::tf::as_feature_shape(input_shape, node->data_layout());
+ auto input_plane_shape = moco::tf::make_plane_shape(input_feature_shape);
+ // output count is from input count, depth is from kernel 'CM' which is dim(2) * dim(3)
+ auto output_feature_shape = input_feature_shape;
+ output_feature_shape.depth() =
+ loco::Dimension(ker_tensor_shape.dim(2).value() * ker_tensor_shape.dim(3).value());
+
+ auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+ moco::tf::update(output_feature_shape).with(output_plane_shape);
+
+ return moco::tf::as_tensor_shape(output_feature_shape, node->data_layout());
+ }
+
public:
loco::NodeShape visit(const moco::tf::TFNode *node) final
{
window.vertical(shape.dim(0).value());
window.horizontal(shape.dim(1).value());
}
+ else if (datalayout == "HWCM")
+ {
+ window.vertical(shape.dim(0).value());
+ window.horizontal(shape.dim(1).value());
+ }
else
{
// TODO add more datalayout supports if needed
return true;
}
-bool fix_shape(moco::tf::TFDepthwiseConv2dNative *node)
-{
- LOGGER(l);
-
- if (shape_inference_done(node))
- return false;
-
- auto ifm = node->input();
- loco::NodeShape ifm_shape;
- if (!node_shape(ifm, ifm_shape))
- {
- // input node shape inference is not ready
- return false;
- }
-
- auto ker = node->filter();
- loco::NodeShape ker_shape;
- if (!node_shape(ker, ker_shape))
- {
- return false;
- }
-
- update_stride_data(node);
-
- auto stride_data = node->annot<StrideData>();
- assert(stride_data != nullptr);
-
- INFO(l) << "FixShape TFDepthwiseConv2dNative strides = " << stride_data->stride()->vertical()
- << ", " << stride_data->stride()->horizontal();
-
- auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>(); // in NHWC
- auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWCM
- assert(ifm_tensor_shape.rank() == 4);
- assert(ker_tensor_shape.rank() == 4);
-
- uint32_t input_height = ifm_tensor_shape.dim(1).value();
- uint32_t input_width = ifm_tensor_shape.dim(2).value();
- uint32_t stride_height = stride_data->stride()->vertical();
- uint32_t stride_width = stride_data->stride()->horizontal();
- uint32_t ker_height = ker_tensor_shape.dim(0).value();
- uint32_t ker_width = ker_tensor_shape.dim(1).value();
- uint32_t dilation_height = 1; // TODO Consider dilation
- uint32_t dilation_width = 1;
- uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
- uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
- uint32_t output_height;
- uint32_t output_width;
-
- auto padding = node->padding();
- assert(padding == "VALID" || padding == "SAME");
-
- if (padding == "VALID")
- {
- output_height = (input_height + stride_height - effective_ker_height) / stride_height;
- output_width = (input_width + stride_width - effective_ker_width) / stride_width;
- }
- else // padding == "SAME"
- {
- output_height = (input_height + stride_height - 1) / stride_height;
- output_width = (input_width + stride_width - 1) / stride_width;
- }
-
- loco::TensorShape ofm_tensor_shape;
- ofm_tensor_shape.rank(4);
- ofm_tensor_shape.dim(0) = ifm_tensor_shape.dim(0);
- ofm_tensor_shape.dim(1) = output_height;
- ofm_tensor_shape.dim(2) = output_width;
- ofm_tensor_shape.dim(3) =
- loco::Dimension(ker_tensor_shape.dim(2).value() * ker_tensor_shape.dim(3).value());
-
- auto shape_data = stdex::make_unique<ShapeInferenceData>();
- shape_data->tensor_shape(ofm_tensor_shape);
- node->annot(std::move(shape_data));
-
- FixPadContext ctx = {input_height, input_width, output_height, output_width,
- stride_height, stride_width, effective_ker_height, effective_ker_width};
-
- calc_annot_paddata(node, ctx);
-
- INFO(l) << "Fix TFDepthwiseConv2dNative shape = ifm" << ifm_tensor_shape << " ker"
- << ker_tensor_shape << " --> ofm" << ofm_tensor_shape;
- INFO(l) << " pad = " << *node->annot<PadData>();
-
- return true;
-}
+bool fix_shape(moco::tf::TFDepthwiseConv2dNative *node) { return false; }
bool fix_shape(moco::tf::TFFusedBatchNorm *node)
{