From 9e1f68e3ec9a1d641ed2cda02da9074b0da28ae3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 30 Jul 2019 07:22:42 +0900 Subject: [PATCH] [moco-tf] Introduce as_tensor_shape with FeatureShape (#5973) This will introduce ShapeInferenceData helper to set as TensorShape with FeatureShape and TFDataLayout Signed-off-by: SaeHie Park --- .../moco-tf/src/Annotations/ShapeInferenceData.cpp | 29 ++++++++++++++++++++++ .../moco-tf/src/Annotations/ShapeInferenceData.h | 3 +++ .../src/Annotations/ShapeInferenceData.test.cpp | 24 ++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/compiler/moco-tf/src/Annotations/ShapeInferenceData.cpp b/compiler/moco-tf/src/Annotations/ShapeInferenceData.cpp index e1338d7..14705f2 100644 --- a/compiler/moco-tf/src/Annotations/ShapeInferenceData.cpp +++ b/compiler/moco-tf/src/Annotations/ShapeInferenceData.cpp @@ -132,6 +132,35 @@ void ShapeInferenceData::bias_shape(const loco_tobe::BiasShape &shape) dim(0) = shape; } +void as_tensor_shape(ShapeInferenceData &shapedata, const loco::FeatureShape &feature_shape, + const TFDataLayout &data_layout) +{ + loco::TensorShape tensor_shape; + + tensor_shape.rank(4); + if (data_layout == "NHWC") + { + tensor_shape.dim(0) = feature_shape.count(); + tensor_shape.dim(1) = feature_shape.height(); + tensor_shape.dim(2) = feature_shape.width(); + tensor_shape.dim(3) = feature_shape.depth(); + } + else if (data_layout == "NCHW") + { + tensor_shape.dim(0) = feature_shape.count(); + tensor_shape.dim(1) = feature_shape.depth(); + tensor_shape.dim(2) = feature_shape.height(); + tensor_shape.dim(3) = feature_shape.width(); + } + else + { + // TODO support for other data_layout if needed + throw std::runtime_error("as_tensor_shape: only supports NHWC or NCHW"); + } + + shapedata.tensor_shape(tensor_shape); +} + loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata, const TFDataLayout &data_layout) { diff --git a/compiler/moco-tf/src/Annotations/ShapeInferenceData.h b/compiler/moco-tf/src/Annotations/ShapeInferenceData.h index 164db93..ddacf53 100644 --- a/compiler/moco-tf/src/Annotations/ShapeInferenceData.h +++ b/compiler/moco-tf/src/Annotations/ShapeInferenceData.h @@ -66,6 +66,9 @@ private: loco::Domain _domain{loco::Domain::Tensor}; }; +void as_tensor_shape(ShapeInferenceData &shapedata, const loco::FeatureShape &shape, + const TFDataLayout &data_layout); + loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata, const TFDataLayout &data_layout); diff --git a/compiler/moco-tf/src/Annotations/ShapeInferenceData.test.cpp b/compiler/moco-tf/src/Annotations/ShapeInferenceData.test.cpp index 616d0d8..486a935 100644 --- a/compiler/moco-tf/src/Annotations/ShapeInferenceData.test.cpp +++ b/compiler/moco-tf/src/Annotations/ShapeInferenceData.test.cpp @@ -98,6 +98,30 @@ TEST(TensorFlowImport, shapeinferencedata_bias) ASSERT_EQ(bias_g, 3); } +TEST(TensorFlowImport, shapeinferencedata_as_tensor_set) +{ + loco::FeatureShape feature_s; + + feature_s.count() = 1; + feature_s.height() = 2; + feature_s.width() = 3; + feature_s.depth() = 4; + + moco::tf::ShapeInferenceData shapedata; + + as_tensor_shape(shapedata, feature_s, "NHWC"); + + loco::TensorShape tensor_g; + + tensor_g = shapedata.tensor_shape(); + + ASSERT_EQ(tensor_g.rank(), 4); + ASSERT_EQ(tensor_g.dim(0), 1); + ASSERT_EQ(tensor_g.dim(1), 2); + ASSERT_EQ(tensor_g.dim(2), 3); + ASSERT_EQ(tensor_g.dim(3), 4); +} + TEST(TensorFlowImport, shapeinferencedata_as_feature) { loco::TensorShape tensor_s; -- 2.7.4