[moco-tf] Introduce as_feature_shape in FixShapeTransform (#6989)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 28 Aug 2019 06:53:03 +0000 (15:53 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 28 Aug 2019 06:53:03 +0000 (15:53 +0900)
This will introduce as_feature_shape() that returns FeatureShape from loco::NodeShape with TFDataLayout

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 701145b..8a74e92 100644 (file)
@@ -166,6 +166,50 @@ bool node_shape(const loco::Node *node, loco::NodeShape &nodeshape)
   return true;
 }
 
+loco::FeatureShape as_feature_shape(const loco::NodeShape &nodeshape,
+                                    const TFDataLayout &data_layout)
+{
+  if (nodeshape.domain() == loco::Domain::Feature)
+    return nodeshape.as<loco::FeatureShape>();
+
+  loco::FeatureShape feature_shape;
+
+  // only convert from tensor to feature
+  if (nodeshape.domain() != loco::Domain::Tensor)
+  {
+    throw std::runtime_error("as_feature_shape: domain is not tensor");
+  }
+
+  loco::TensorShape tensor_shape = nodeshape.as<loco::TensorShape>();
+
+  if (tensor_shape.rank() != 4)
+  {
+    throw std::runtime_error("as_feature_shape: rank is not 4");
+  }
+
+  if (data_layout == "NHWC")
+  {
+    feature_shape.count() = tensor_shape.dim(0);
+    feature_shape.height() = tensor_shape.dim(1);
+    feature_shape.width() = tensor_shape.dim(2);
+    feature_shape.depth() = tensor_shape.dim(3);
+  }
+  else if (data_layout == "NCHW")
+  {
+    feature_shape.count() = tensor_shape.dim(0);
+    feature_shape.depth() = tensor_shape.dim(1);
+    feature_shape.height() = tensor_shape.dim(2);
+    feature_shape.width() = tensor_shape.dim(3);
+  }
+  else
+  {
+    // TODO support for other data_layout if needed
+    throw std::runtime_error("as_feature_shape: only supports NHWC or NCHW");
+  }
+
+  return feature_shape;
+}
+
 struct FixPadContext
 {
   uint32_t input_height;