[moco-tf] introduce as_feature_shape (#5796)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 24 Jul 2019 02:15:28 +0000 (11:15 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 24 Jul 2019 02:15:28 +0000 (11:15 +0900)
* [moco-tf] introduce as_feature_shape

This will introduce as_feature_shape for ShapeInferenceData that will return converted feature shape if it's tensor shape with given data layout

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* use TFDataLayout

compiler/moco-tf/src/Annotations/ShapeInferenceData.cpp
compiler/moco-tf/src/Annotations/ShapeInferenceData.h
compiler/moco-tf/src/Annotations/ShapeInferenceData.test.cpp

index 5115526..e5b49a1 100644 (file)
@@ -132,5 +132,47 @@ void ShapeInferenceData::bias_shape(const loco_tobe::BiasShape &shape)
   dim(0) = shape;
 }
 
+loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata,
+                                    const TFDataLayout &data_layout)
+{
+  if (shapedata.domain() == loco::Domain::Feature)
+    return shapedata.feature_shape();
+
+  loco::FeatureShape feature_shape;
+
+  // only convert from tensor to feature
+  if (shapedata.domain() != loco::Domain::Tensor)
+  {
+    throw std::runtime_error("as_feature_shape: domain is not tensor");
+  }
+  if (shapedata.rank() != 4)
+  {
+    throw std::runtime_error("as_feature_shape: rank is not 4");
+  }
+
+  // TODO support for other data_layout if needed
+  if (data_layout != "NHWC" && data_layout != "NCHW")
+  {
+    throw std::runtime_error("as_feature_shape: only supports NHWC or NCHW");
+  }
+
+  if (data_layout == "NHWC")
+  {
+    feature_shape.count() = shapedata.dim(0);
+    feature_shape.height() = shapedata.dim(1);
+    feature_shape.width() = shapedata.dim(2);
+    feature_shape.depth() = shapedata.dim(3);
+  }
+  else
+  {
+    feature_shape.count() = shapedata.dim(0);
+    feature_shape.depth() = shapedata.dim(1);
+    feature_shape.height() = shapedata.dim(2);
+    feature_shape.width() = shapedata.dim(3);
+  }
+
+  return feature_shape;
+}
+
 } // namespace tf
 } // namespace moco
index afbe488..0174005 100644 (file)
@@ -34,6 +34,9 @@ namespace moco
 namespace tf
 {
 
+/// @note  Below alias may be introduced as separate class
+using TFDataLayout = std::string;
+
 /**
  * @brief ShapeInferenceData provides shape inference data tracking from the start(input)
  *
@@ -63,6 +66,9 @@ private:
   loco::Domain _domain{loco::Domain::Tensor};
 };
 
+loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata,
+                                    const TFDataLayout &data_layout);
+
 } // namespace tf
 } // namespace moco
 
index bf10203..5269cd2 100644 (file)
@@ -97,3 +97,25 @@ TEST(TensorFlowImport, shapeinferencedata_bias)
 
   ASSERT_EQ(bias_g, 3);
 }
+
+TEST(TensorFlowImport, shapeinferencedata_as_feature)
+{
+  loco::TensorShape tensor_s;
+
+  tensor_s.rank(4);
+  tensor_s.dim(0) = 1;
+  tensor_s.dim(1) = 2;
+  tensor_s.dim(2) = 3;
+  tensor_s.dim(3) = 4;
+
+  moco::tf::ShapeInferenceData shapedata;
+
+  shapedata.tensor_shape(tensor_s);
+
+  loco::FeatureShape feature_g = as_feature_shape(shapedata, "NHWC");
+
+  ASSERT_EQ(feature_g.count(), 1);
+  ASSERT_EQ(feature_g.height(), 2);
+  ASSERT_EQ(feature_g.width(), 3);
+  ASSERT_EQ(feature_g.depth(), 4);
+}