[moco-tf] Introduce as_tensor_shape with FeatureShape (#5973)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 29 Jul 2019 22:22:42 +0000 (07:22 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 29 Jul 2019 22:22:42 +0000 (07:22 +0900)
This will introduce ShapeInferenceData helper to set as TensorShape with FeatureShape and TFDataLayout

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Annotations/ShapeInferenceData.cpp
compiler/moco-tf/src/Annotations/ShapeInferenceData.h
compiler/moco-tf/src/Annotations/ShapeInferenceData.test.cpp

index e1338d7..14705f2 100644 (file)
@@ -132,6 +132,35 @@ void ShapeInferenceData::bias_shape(const loco_tobe::BiasShape &shape)
   dim(0) = shape;
 }
 
+void as_tensor_shape(ShapeInferenceData &shapedata, const loco::FeatureShape &feature_shape,
+                     const TFDataLayout &data_layout)
+{
+  loco::TensorShape tensor_shape;
+
+  tensor_shape.rank(4);
+  if (data_layout == "NHWC")
+  {
+    tensor_shape.dim(0) = feature_shape.count();
+    tensor_shape.dim(1) = feature_shape.height();
+    tensor_shape.dim(2) = feature_shape.width();
+    tensor_shape.dim(3) = feature_shape.depth();
+  }
+  else if (data_layout == "NCHW")
+  {
+    tensor_shape.dim(0) = feature_shape.count();
+    tensor_shape.dim(1) = feature_shape.depth();
+    tensor_shape.dim(2) = feature_shape.height();
+    tensor_shape.dim(3) = feature_shape.width();
+  }
+  else
+  {
+    // TODO support for other data_layout if needed
+    throw std::runtime_error("as_tensor_shape: only supports NHWC or NCHW");
+  }
+
+  shapedata.tensor_shape(tensor_shape);
+}
+
 loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata,
                                     const TFDataLayout &data_layout)
 {
index 164db93..ddacf53 100644 (file)
@@ -66,6 +66,9 @@ private:
   loco::Domain _domain{loco::Domain::Tensor};
 };
 
+void as_tensor_shape(ShapeInferenceData &shapedata, const loco::FeatureShape &shape,
+                     const TFDataLayout &data_layout);
+
 loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata,
                                     const TFDataLayout &data_layout);
 
index 616d0d8..486a935 100644 (file)
@@ -98,6 +98,30 @@ TEST(TensorFlowImport, shapeinferencedata_bias)
   ASSERT_EQ(bias_g, 3);
 }
 
+TEST(TensorFlowImport, shapeinferencedata_as_tensor_set)
+{
+  loco::FeatureShape feature_s;
+
+  feature_s.count() = 1;
+  feature_s.height() = 2;
+  feature_s.width() = 3;
+  feature_s.depth() = 4;
+
+  moco::tf::ShapeInferenceData shapedata;
+
+  as_tensor_shape(shapedata, feature_s, "NHWC");
+
+  loco::TensorShape tensor_g;
+
+  tensor_g = shapedata.tensor_shape();
+
+  ASSERT_EQ(tensor_g.rank(), 4);
+  ASSERT_EQ(tensor_g.dim(0), 1);
+  ASSERT_EQ(tensor_g.dim(1), 2);
+  ASSERT_EQ(tensor_g.dim(2), 3);
+  ASSERT_EQ(tensor_g.dim(3), 4);
+}
+
 TEST(TensorFlowImport, shapeinferencedata_as_feature)
 {
   loco::TensorShape tensor_s;