From eed146c624974564c19eda08421137c3aa11fc8c Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 18 Jun 2019 14:11:55 +0900 Subject: [PATCH] [moco/tf] tensor and filter for ShapeInferenceData (#3837) This will add getter of tensor and setter of filter for ShapeInferenceData Signed-off-by: SaeHie Park --- .../tf/src/Annotations/ShapeInferenceData.cpp | 25 ++++++++++++++ .../tf/src/Annotations/ShapeInferenceData.h | 4 ++- .../tf/src/Annotations/ShapeInferenceData.test.cpp | 39 ++++++++++++++++++++++ 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.cpp b/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.cpp index 25053e2..e788c7f 100644 --- a/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.cpp +++ b/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.cpp @@ -21,6 +21,22 @@ namespace moco namespace tf { +loco::TensorShape ShapeInferenceData::tensor_shape(void) const +{ + loco::TensorShape shape; + + shape.rank(rank()); + for (uint32_t r = 0; r < rank(); ++r) + { + if (dim(r).known()) + shape.dim(r) = loco::make_dimension(dim(r).value()); + else + shape.dim(r).unset(); + } + + return shape; +} + loco::FeatureShape ShapeInferenceData::feature_shape(void) const { loco::FeatureShape shape; @@ -46,5 +62,14 @@ void ShapeInferenceData::feature_shape(const loco::FeatureShape &shape) dim(3) = shape.depth(); } +void ShapeInferenceData::filter_shape(const loco::FilterShape &shape) +{ + rank(4); + dim(0) = shape.count(); + dim(1) = shape.height(); + dim(2) = shape.width(); + dim(3) = shape.depth(); +} + } // namespace tf } // namespace moco diff --git a/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.h b/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.h index 8ea7dad..62abd1a 100644 --- a/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.h +++ b/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.h @@ -29,7 +29,7 @@ namespace tf /** * @brief ShapeInferenceData provides shape inference data tracking from the start(input) * - * @note For Feature, NHWC is used for shape layout + * @note For Feature and Filter, NHWC is used for shape layout */ class ShapeInferenceData : public loco::NodeAnnotation, public loco::NodeMixin @@ -38,9 +38,11 @@ public: ~ShapeInferenceData(){}; public: + loco::TensorShape tensor_shape(void) const; loco::FeatureShape feature_shape(void) const; void feature_shape(const loco::FeatureShape &shape); + void filter_shape(const loco::FilterShape &shape); }; } // namespace tf diff --git a/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.test.cpp b/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.test.cpp index fccf29f..7808724 100644 --- a/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.test.cpp +++ b/contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.test.cpp @@ -18,6 +18,25 @@ #include +TEST(TensorFlowFrontend, shapeinferencedata_tensor_get) +{ + moco::tf::ShapeInferenceData shapedata; + + shapedata.rank(4); + shapedata.dim(0) = 1; + shapedata.dim(1) = 2; + shapedata.dim(2) = 3; + shapedata.dim(3) = 4; + + loco::TensorShape tensor = shapedata.tensor_shape(); + + ASSERT_EQ(tensor.rank(), 4); + ASSERT_EQ(tensor.dim(0), 1); + ASSERT_EQ(tensor.dim(1), 2); + ASSERT_EQ(tensor.dim(2), 3); + ASSERT_EQ(tensor.dim(3), 4); +} + TEST(TensorFlowFrontend, shapeinferencedata_feature_set) { loco::FeatureShape feature; @@ -55,3 +74,23 @@ TEST(TensorFlowFrontend, shapeinferencedata_feature_get) ASSERT_EQ(feature.width(), 3); ASSERT_EQ(feature.depth(), 4); } + +TEST(TensorFlowFrontend, shapeinferencedata_filter_set) +{ + loco::FilterShape filter; + + filter.count() = 1; + filter.height() = 2; + filter.width() = 3; + filter.depth() = 4; + + moco::tf::ShapeInferenceData shapedata; + + shapedata.filter_shape(filter); + + ASSERT_EQ(shapedata.rank(), 4); + ASSERT_EQ(shapedata.dim(0), 1); + ASSERT_EQ(shapedata.dim(1), 2); + ASSERT_EQ(shapedata.dim(2), 3); + ASSERT_EQ(shapedata.dim(3), 4); +} -- 2.7.4