[moco/tf] tensor and filter for ShapeInferenceData (#3837)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 18 Jun 2019 05:11:55 +0000 (14:11 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 18 Jun 2019 05:11:55 +0000 (14:11 +0900)
This will add getter of tensor and setter of filter for ShapeInferenceData

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.cpp
contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.h
contrib/moco/lib/frontend/tf/src/Annotations/ShapeInferenceData.test.cpp

index 25053e2..e788c7f 100644 (file)
@@ -21,6 +21,22 @@ namespace moco
 namespace tf
 {
 
+loco::TensorShape ShapeInferenceData::tensor_shape(void) const
+{
+  loco::TensorShape shape;
+
+  shape.rank(rank());
+  for (uint32_t r = 0; r < rank(); ++r)
+  {
+    if (dim(r).known())
+      shape.dim(r) = loco::make_dimension(dim(r).value());
+    else
+      shape.dim(r).unset();
+  }
+
+  return shape;
+}
+
 loco::FeatureShape ShapeInferenceData::feature_shape(void) const
 {
   loco::FeatureShape shape;
@@ -46,5 +62,14 @@ void ShapeInferenceData::feature_shape(const loco::FeatureShape &shape)
   dim(3) = shape.depth();
 }
 
+void ShapeInferenceData::filter_shape(const loco::FilterShape &shape)
+{
+  rank(4);
+  dim(0) = shape.count();
+  dim(1) = shape.height();
+  dim(2) = shape.width();
+  dim(3) = shape.depth();
+}
+
 } // namespace tf
 } // namespace moco
index 8ea7dad..62abd1a 100644 (file)
@@ -29,7 +29,7 @@ namespace tf
 /**
  * @brief ShapeInferenceData provides shape inference data tracking from the start(input)
  *
- * @note  For Feature, NHWC is used for shape layout
+ * @note  For Feature and Filter, NHWC is used for shape layout
  */
 class ShapeInferenceData : public loco::NodeAnnotation,
                            public loco::NodeMixin<loco::NodeTrait::TensorShape>
@@ -38,9 +38,11 @@ public:
   ~ShapeInferenceData(){};
 
 public:
+  loco::TensorShape tensor_shape(void) const;
   loco::FeatureShape feature_shape(void) const;
 
   void feature_shape(const loco::FeatureShape &shape);
+  void filter_shape(const loco::FilterShape &shape);
 };
 
 } // namespace tf
index fccf29f..7808724 100644 (file)
 
 #include <gtest/gtest.h>
 
+TEST(TensorFlowFrontend, shapeinferencedata_tensor_get)
+{
+  moco::tf::ShapeInferenceData shapedata;
+
+  shapedata.rank(4);
+  shapedata.dim(0) = 1;
+  shapedata.dim(1) = 2;
+  shapedata.dim(2) = 3;
+  shapedata.dim(3) = 4;
+
+  loco::TensorShape tensor = shapedata.tensor_shape();
+
+  ASSERT_EQ(tensor.rank(), 4);
+  ASSERT_EQ(tensor.dim(0), 1);
+  ASSERT_EQ(tensor.dim(1), 2);
+  ASSERT_EQ(tensor.dim(2), 3);
+  ASSERT_EQ(tensor.dim(3), 4);
+}
+
 TEST(TensorFlowFrontend, shapeinferencedata_feature_set)
 {
   loco::FeatureShape feature;
@@ -55,3 +74,23 @@ TEST(TensorFlowFrontend, shapeinferencedata_feature_get)
   ASSERT_EQ(feature.width(), 3);
   ASSERT_EQ(feature.depth(), 4);
 }
+
+TEST(TensorFlowFrontend, shapeinferencedata_filter_set)
+{
+  loco::FilterShape filter;
+
+  filter.count() = 1;
+  filter.height() = 2;
+  filter.width() = 3;
+  filter.depth() = 4;
+
+  moco::tf::ShapeInferenceData shapedata;
+
+  shapedata.filter_shape(filter);
+
+  ASSERT_EQ(shapedata.rank(), 4);
+  ASSERT_EQ(shapedata.dim(0), 1);
+  ASSERT_EQ(shapedata.dim(1), 2);
+  ASSERT_EQ(shapedata.dim(2), 3);
+  ASSERT_EQ(shapedata.dim(3), 4);
+}