assert(node->template annot<PadData>() != nullptr);
}
+template <class T> void update_stride_data(T *node)
+{
+ auto stride_data = stdex::make_unique<StrideData>();
+ auto strides = node->strides();
+ auto data_layout = plier::tf::as_data_layout(node->data_layout());
+ if (data_layout == plier::tf::DataLayout::NHWC)
+ {
+ stride_data->stride()->vertical(strides[1]);
+ stride_data->stride()->horizontal(strides[2]);
+ }
+ else if (data_layout == plier::tf::DataLayout::NCHW)
+ {
+ stride_data->stride()->vertical(strides[2]);
+ stride_data->stride()->horizontal(strides[3]);
+ }
+ node->annot(std::move(stride_data));
+}
+
+template <class T> void update_window_data(T *node)
+{
+ auto window_data = stdex::make_unique<WindowData>();
+ auto ksize = node->ksize();
+ auto data_layout = plier::tf::as_data_layout(node->data_layout());
+ if (data_layout == plier::tf::DataLayout::NHWC)
+ {
+ window_data->window()->vertical(ksize[1]);
+ window_data->window()->horizontal(ksize[2]);
+ }
+ else if (data_layout == plier::tf::DataLayout::NCHW)
+ {
+ window_data->window()->vertical(ksize[2]);
+ window_data->window()->horizontal(ksize[3]);
+ }
+ node->annot(std::move(window_data));
+}
+
bool fix_shape(loco::AvgPool2D *node)
{
LOGGER(l);