auto padding = node->padding();
assert(padding == "VALID" || padding == "SAME");
- // TODO move this to some new Transformation...
- {
- {
- auto stride_data = node->annot<StrideData>();
- assert(stride_data == nullptr);
- }
- auto stride_data = stdex::make_unique<StrideData>();
- auto strides = node->strides();
- auto data_layout = plier::tf::as_data_layout(node->data_layout());
- if (data_layout == plier::tf::DataLayout::NHWC)
- {
- stride_data->stride()->vertical(strides[1]);
- stride_data->stride()->horizontal(strides[2]);
- }
- else if (data_layout == plier::tf::DataLayout::NCHW)
- {
- stride_data->stride()->vertical(strides[2]);
- stride_data->stride()->horizontal(strides[3]);
- }
- node->annot(std::move(stride_data));
-
- {
- auto window_data = node->annot<WindowData>();
- assert(window_data == nullptr);
- }
- auto window_data = stdex::make_unique<WindowData>();
- auto ksize = node->ksize();
- if (data_layout == plier::tf::DataLayout::NHWC)
- {
- window_data->window()->vertical(ksize[1]);
- window_data->window()->horizontal(ksize[2]);
- }
- else if (data_layout == plier::tf::DataLayout::NCHW)
- {
- window_data->window()->vertical(ksize[2]);
- window_data->window()->horizontal(ksize[3]);
- }
- node->annot(std::move(window_data));
- }
+ update_stride_data(node);
+ update_window_data(node);
auto value_feature_shape = as_feature_shape(*value_shapedata, node->data_layout());
auto padding = node->padding();
assert(padding == "VALID" || padding == "SAME");
- // TODO move this to some new Transformation...
- {
- {
- auto stride_data = node->annot<StrideData>();
- assert(stride_data == nullptr);
- }
- auto stride_data = stdex::make_unique<StrideData>();
- auto strides = node->strides();
- auto data_layout = plier::tf::as_data_layout(node->data_layout());
- if (data_layout == plier::tf::DataLayout::NHWC)
- {
- stride_data->stride()->vertical(strides[1]);
- stride_data->stride()->horizontal(strides[2]);
- }
- else if (data_layout == plier::tf::DataLayout::NCHW)
- {
- stride_data->stride()->vertical(strides[2]);
- stride_data->stride()->horizontal(strides[3]);
- }
- node->annot(std::move(stride_data));
- }
+ update_stride_data(node);
auto stride_data = node->annot<StrideData>();
assert(stride_data != nullptr);
return false;
}
- {
- auto stride_data = node->annot<StrideData>();
- assert(stride_data == nullptr);
- }
-
- auto stride_copy = stdex::make_unique<StrideData>();
- auto strides = node->strides();
- auto data_layout = plier::tf::as_data_layout(node->data_layout());
- if (data_layout == plier::tf::DataLayout::NHWC)
- {
- stride_copy->stride()->vertical(strides[1]);
- stride_copy->stride()->horizontal(strides[2]);
- }
- else if (data_layout == plier::tf::DataLayout::NCHW)
- {
- stride_copy->stride()->vertical(strides[2]);
- stride_copy->stride()->horizontal(strides[3]);
- }
- else
- {
- throw std::runtime_error{"Not supported for other data layout"};
- }
- node->annot(std::move(stride_copy));
+ update_stride_data(node);
auto stride_data = node->annot<StrideData>();
assert(stride_data != nullptr);
auto padding = node->padding();
assert(padding == "VALID" || padding == "SAME");
- // TODO move this to some new Transformation...
- {
- {
- auto stride_data = node->annot<StrideData>();
- assert(stride_data == nullptr);
- }
- auto stride_data = stdex::make_unique<StrideData>();
- auto strides = node->strides();
- auto data_layout = plier::tf::as_data_layout(node->data_layout());
- if (data_layout == plier::tf::DataLayout::NHWC)
- {
- stride_data->stride()->vertical(strides[1]);
- stride_data->stride()->horizontal(strides[2]);
- }
- else if (data_layout == plier::tf::DataLayout::NCHW)
- {
- stride_data->stride()->vertical(strides[2]);
- stride_data->stride()->horizontal(strides[3]);
- }
- node->annot(std::move(stride_data));
-
- {
- auto window_data = node->annot<WindowData>();
- assert(window_data == nullptr);
- }
- auto window_data = stdex::make_unique<WindowData>();
- auto ksize = node->ksize();
- if (data_layout == plier::tf::DataLayout::NHWC)
- {
- window_data->window()->vertical(ksize[1]);
- window_data->window()->horizontal(ksize[2]);
- }
- else if (data_layout == plier::tf::DataLayout::NCHW)
- {
- window_data->window()->vertical(ksize[2]);
- window_data->window()->horizontal(ksize[3]);
- }
- node->annot(std::move(window_data));
- }
+ update_stride_data(node);
+ update_window_data(node);
auto value_feature_shape = as_feature_shape(*value_shapedata, node->data_layout());