[moco-tf] introduce node_shape method (#6962)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 28 Aug 2019 01:08:56 +0000 (10:08 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 28 Aug 2019 01:08:56 +0000 (10:08 +0900)
* [moco-tf] introduce node_shape method

This will introduce node_shape() method in FixShapeTransform that returns NodeShape from Node as a common shape information for dialect independent

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* header for shape_known

compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index f560ab5..828f928 100644 (file)
@@ -28,6 +28,7 @@
 
 #include <loco.h>
 #include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
 #include <moco/Log.h>
 #include <stdex/Memory.h>
 #include <plier/tf/Convert.h>
@@ -126,6 +127,52 @@ bool copy_shapedata(const loco::Node *src, loco::Node *dst)
   return true;
 }
 
+/**
+ * @note  While in shape inference, Node maybe Canonical, TF dialect or other dialects
+ *        This will provide common loco::NodeShape as shape information
+ */
+bool node_shape(const loco::Node *node, loco::NodeShape &nodeshape)
+{
+  if (loco::shape_known(node))
+  {
+    nodeshape = loco::shape_get(node);
+    return true;
+  }
+
+  auto shapedata = node->annot<ShapeInferenceData>();
+  if (shapedata == nullptr)
+  {
+    return false;
+  }
+
+  switch (shapedata->domain())
+  {
+    case loco::Domain::Tensor:
+      nodeshape.set(shapedata->tensor_shape());
+      break;
+
+    case loco::Domain::Feature:
+      nodeshape.set(shapedata->feature_shape());
+      break;
+
+    case loco::Domain::Filter:
+      nodeshape.set(shapedata->filter_shape());
+      break;
+
+    case loco::Domain::DepthwiseFilter:
+      nodeshape.set(shapedata->depthwisefilter_shape());
+      break;
+
+    case loco::Domain::Bias:
+      nodeshape.set(shapedata->bias_shape());
+      break;
+
+    default:
+      throw std::runtime_error("Unsupported Domain in node_shape()");
+  }
+  return true;
+}
+
 struct FixPadContext
 {
   uint32_t input_height;