From e69e5cef4661593905cf152a787f53d5edb95a47 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 4 Sep 2019 10:51:07 +0900 Subject: [PATCH] [moco-tf] Revise FixShape test with Conv2D (#7116) This will revise FixShapeTransform test with Conv2D to use moco::tf::TFConv2D from loco::Conv2D IR Signed-off-by: SaeHie Park --- .../src/Transforms/FixShapeTransform.test.cpp | 51 +++++++++++----------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.test.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.test.cpp index 8b77a3b..1cb0a4d 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.test.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.test.cpp @@ -125,33 +125,34 @@ void conv2d_test(const std::array ifm_shape, const std::arraycreate(); + auto conv2d_node = graph.nodes()->create(); + conv2d_node->data_layout("NHWC"); + conv2d_node->strides({1, stride_h_w[0], stride_h_w[1], 1}); + conv2d_node->padding(padding); - auto stride = conv2d_node->stride(); - stride->vertical(stride_h_w[0]); - stride->horizontal(stride_h_w[1]); - - auto ifm_node = graph.nodes()->create(); + auto ifm_node = graph.nodes()->create(); { auto shapedata = stdex::make_unique(); - loco::FeatureShape cshape; - cshape.count() = ifm_shape[0]; - cshape.height() = ifm_shape[1]; - cshape.width() = ifm_shape[2]; - cshape.depth() = ifm_shape[3]; - shapedata->feature_shape(cshape); + loco::TensorShape tshape; + tshape.rank(4); + tshape.dim(0).set(ifm_shape[0]); + tshape.dim(1).set(ifm_shape[1]); + tshape.dim(2).set(ifm_shape[2]); + tshape.dim(3).set(ifm_shape[3]); + shapedata->tensor_shape(tshape); ifm_node->annot(std::move(shapedata)); } auto ker_node = graph.nodes()->create(); { auto shapedata = stdex::make_unique(); - loco::FilterShape cshape; - cshape.height() = ker_shape[0]; - cshape.width() = ker_shape[1]; - cshape.depth() = ker_shape[2]; - cshape.count() = ker_shape[3]; - shapedata->filter_shape(cshape); + loco::TensorShape tshape; + tshape.rank(4); + tshape.dim(0).set(ker_shape[0]); + tshape.dim(1).set(ker_shape[1]); + tshape.dim(2).set(ker_shape[2]); + tshape.dim(3).set(ker_shape[3]); + shapedata->tensor_shape(tshape); ker_node->annot(std::move(shapedata)); } @@ -160,19 +161,17 @@ void conv2d_test(const std::array ifm_shape, const std::array(padding); - conv2d_node->annot(std::move(padding_data)); - moco::tf::FixShapeTransform transform; transform.run(&graph); auto shapedata = conv2d_node->annot(); ASSERT_NE(shapedata, nullptr); - auto fshape = shapedata->feature_shape(); - ASSERT_EQ(fshape.count(), expected_shape[0]); - ASSERT_EQ(fshape.height(), expected_shape[1]); - ASSERT_EQ(fshape.width(), expected_shape[2]); - ASSERT_EQ(fshape.depth(), expected_shape[3]); + auto tshape = shapedata->tensor_shape(); + ASSERT_EQ(tshape.rank(), 4); + ASSERT_EQ(tshape.dim(0).value(), expected_shape[0]); + ASSERT_EQ(tshape.dim(1).value(), expected_shape[1]); + ASSERT_EQ(tshape.dim(2).value(), expected_shape[2]); + ASSERT_EQ(tshape.dim(3).value(), expected_shape[3]); } } // namespace -- 2.7.4