*/
#include "Conv2DCanonicalizer.h"
-
-#include "Annotations/PadData.h"
-#include "Annotations/StrideData.h"
+#include "TFShapeInferenceHelper.h"
#include "Dialect/TFDialect.h"
#include "Dialect/TFNodes.h"
set_filter_enc(filter_enc);
set_feature_dec(feature_dec, data_layout);
- // Set Conv2D attributes from TFConv2D
- auto pad_data = node->annot<moco::tf::PadData>();
- assert(pad_data != nullptr);
+ auto input_shape = moco::tf::node_shape(node->input());
+ assert(input_shape.domain() != loco::Domain::Unknown);
+
+ auto ker_shape = moco::tf::node_shape(node->filter());
+ auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWIO
+
+ auto node_stride = moco::tf::stride_of(node->strides(), node->data_layout());
+ auto node_window = moco::tf::window_of(ker_tensor_shape, "HWIO");
+
+ moco::tf::Padding2DInference infer_padding2d;
- conv2d->pad()->top(pad_data->pad()->top());
- conv2d->pad()->bottom(pad_data->pad()->bottom());
- conv2d->pad()->left(pad_data->pad()->left());
- conv2d->pad()->right(pad_data->pad()->right());
+ infer_padding2d.padding(node->padding());
+ infer_padding2d.stride(node_stride);
+ infer_padding2d.window(node_window);
- auto stride_data = node->annot<moco::tf::StrideData>();
- assert(stride_data != nullptr);
+ auto input_feature_shape = moco::tf::as_feature_shape(input_shape, node->data_layout());
+ auto input_plane_shape = moco::tf::make_plane_shape(input_feature_shape);
- conv2d->stride()->vertical(stride_data->stride()->vertical());
- conv2d->stride()->horizontal(stride_data->stride()->horizontal());
+ *conv2d->pad() = infer_padding2d(input_plane_shape);
+ *conv2d->stride() = node_stride;
// update graph
auto node_A = node->input();