moco::tf::FixShapeTransform fstransform;
loco::Graph graph;
- auto conv2d_node = graph.nodes()->create<loco::Conv2D>();
+ auto conv2d_node = graph.nodes()->create<moco::tf::TFConv2D>();
+ conv2d_node->data_layout("NHWC");
+ conv2d_node->strides({1, stride_h_w[0], stride_h_w[1], 1});
+ conv2d_node->padding(padding);
- auto stride = conv2d_node->stride();
- stride->vertical(stride_h_w[0]);
- stride->horizontal(stride_h_w[1]);
-
- auto ifm_node = graph.nodes()->create<loco::ConstGen>();
+ auto ifm_node = graph.nodes()->create<moco::tf::TFConst>();
{
auto shapedata = stdex::make_unique<moco::tf::ShapeInferenceData>();
- loco::FeatureShape cshape;
- cshape.count() = ifm_shape[0];
- cshape.height() = ifm_shape[1];
- cshape.width() = ifm_shape[2];
- cshape.depth() = ifm_shape[3];
- shapedata->feature_shape(cshape);
+ loco::TensorShape tshape;
+ tshape.rank(4);
+ tshape.dim(0).set(ifm_shape[0]);
+ tshape.dim(1).set(ifm_shape[1]);
+ tshape.dim(2).set(ifm_shape[2]);
+ tshape.dim(3).set(ifm_shape[3]);
+ shapedata->tensor_shape(tshape);
ifm_node->annot(std::move(shapedata));
}
auto ker_node = graph.nodes()->create<loco::ConstGen>();
{
auto shapedata = stdex::make_unique<moco::tf::ShapeInferenceData>();
- loco::FilterShape cshape;
- cshape.height() = ker_shape[0];
- cshape.width() = ker_shape[1];
- cshape.depth() = ker_shape[2];
- cshape.count() = ker_shape[3];
- shapedata->filter_shape(cshape);
+ loco::TensorShape tshape;
+ tshape.rank(4);
+ tshape.dim(0).set(ker_shape[0]);
+ tshape.dim(1).set(ker_shape[1]);
+ tshape.dim(2).set(ker_shape[2]);
+ tshape.dim(3).set(ker_shape[3]);
+ shapedata->tensor_shape(tshape);
ker_node->annot(std::move(shapedata));
}
setup_output_node(&graph, conv2d_node);
- auto padding_data = stdex::make_unique<moco::tf::PaddingData>(padding);
- conv2d_node->annot(std::move(padding_data));
-
moco::tf::FixShapeTransform transform;
transform.run(&graph);
auto shapedata = conv2d_node->annot<moco::tf::ShapeInferenceData>();
ASSERT_NE(shapedata, nullptr);
- auto fshape = shapedata->feature_shape();
- ASSERT_EQ(fshape.count(), expected_shape[0]);
- ASSERT_EQ(fshape.height(), expected_shape[1]);
- ASSERT_EQ(fshape.width(), expected_shape[2]);
- ASSERT_EQ(fshape.depth(), expected_shape[3]);
+ auto tshape = shapedata->tensor_shape();
+ ASSERT_EQ(tshape.rank(), 4);
+ ASSERT_EQ(tshape.dim(0).value(), expected_shape[0]);
+ ASSERT_EQ(tshape.dim(1).value(), expected_shape[1]);
+ ASSERT_EQ(tshape.dim(2).value(), expected_shape[2]);
+ ASSERT_EQ(tshape.dim(3).value(), expected_shape[3]);
}
} // namespace