[moco-tf] make_shape_inference_data from T (#6980)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 28 Aug 2019 07:17:06 +0000 (16:17 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 28 Aug 2019 07:17:06 +0000 (16:17 +0900)
This will introduce template make_shape_inference_data() that creates a copy of ShapeInferenceData from src

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 33f3832..2d6355b 100644 (file)
@@ -65,6 +65,31 @@ template <class T> void copy_shape_values(const T *src, ShapeInferenceData *dst)
   }
 }
 
+/**
+ * @brief  Make copy of ShapeInferenceData from src
+ *
+ * @note   T can be ShapeInferenceData or loco::Node based class having shape
+ *         attributes like TFConst, COpCall and so on
+ */
+template <class T> std::unique_ptr<ShapeInferenceData> make_shape_inference_data(const T *src)
+{
+  assert(src != nullptr);
+
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+
+  uint32_t rank = src->rank();
+  shape_data->rank(rank);
+  for (uint32_t index = 0; index < rank; ++index)
+  {
+    if (src->dim(index).known())
+      shape_data->dim(index) = src->dim(index).value();
+    else
+      shape_data->dim(index).unset();
+  }
+
+  return std::move(shape_data);
+}
+
 std::unique_ptr<ShapeInferenceData> make_shape_inference_data(const loco::NodeShape &src)
 {
   auto shape_data = stdex::make_unique<ShapeInferenceData>();