[moco-tf] Introduce as_node_shape (#7476)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 16 Sep 2019 22:00:23 +0000 (07:00 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 16 Sep 2019 22:00:23 +0000 (07:00 +0900)
This will introduce as_node_shape() method in FixShapeTransform that returns loco::NodeShape from ShapeInferenceData

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 5be31ac..d7ab271 100644 (file)
@@ -132,6 +132,29 @@ std::unique_ptr<ShapeInferenceData> make_shape_inference_data(const loco::NodeSh
   return std::move(shape_data);
 }
 
+loco::NodeShape as_node_shape(const ShapeInferenceData *shapedata)
+{
+  switch (shapedata->domain())
+  {
+    case loco::Domain::Tensor:
+      return loco::NodeShape({shapedata->tensor_shape()});
+
+    case loco::Domain::Feature:
+      return loco::NodeShape({shapedata->feature_shape()});
+
+    case loco::Domain::Filter:
+      return loco::NodeShape({shapedata->filter_shape()});
+
+    case loco::Domain::DepthwiseFilter:
+      return loco::NodeShape({shapedata->depthwisefilter_shape()});
+
+    case loco::Domain::Bias:
+      return loco::NodeShape({shapedata->bias_shape()});
+  }
+
+  throw std::runtime_error("Unsupported Domain in as_node_shape");
+}
+
 /**
  * @brief  Copy ShapeInferenceData from loco::Node pointer src to dst
  */