[moco-tf] introduce make_shape_inference_data (#6960)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 27 Aug 2019 08:13:57 +0000 (17:13 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 27 Aug 2019 08:13:57 +0000 (17:13 +0900)
This will introduce make_shape_inference_data() that makes ShapeInferenceData from NodeShape shape information

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 428114b..f560ab5 100644 (file)
@@ -27,6 +27,7 @@
 #include "Dialect/TFNodes.h"
 
 #include <loco.h>
+#include <loco/IR/NodeShape.h>
 #include <moco/Log.h>
 #include <stdex/Memory.h>
 #include <plier/tf/Convert.h>
@@ -70,6 +71,39 @@ template <class T> void copy_shape_values(const T *src, ShapeInferenceData *dst)
   }
 }
 
+std::unique_ptr<ShapeInferenceData> make_shape_inference_data(const loco::NodeShape &src)
+{
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+
+  switch (src.domain())
+  {
+    case loco::Domain::Tensor:
+      shape_data->tensor_shape(src.as<loco::TensorShape>());
+      break;
+
+    case loco::Domain::Feature:
+      shape_data->feature_shape(src.as<loco::FeatureShape>());
+      break;
+
+    case loco::Domain::Filter:
+      shape_data->filter_shape(src.as<loco::FilterShape>());
+      break;
+
+    case loco::Domain::DepthwiseFilter:
+      shape_data->depthwisefilter_shape(src.as<loco::DepthwiseFilterShape>());
+      break;
+
+    case loco::Domain::Bias:
+      shape_data->bias_shape(src.as<loco::BiasShape>());
+      break;
+
+    default:
+      throw std::runtime_error("Unsupported Domain in make_shape_inference_data");
+  }
+
+  return std::move(shape_data);
+}
+
 /**
  * @brief  Copy ShapeInferenceData from loco::Node pointer src to dst
  */