Introduce 'broadcast' method (#1610)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 8 Jun 2018 05:43:31 +0000 (14:43 +0900)
committer서상민/동작제어Lab(SR)/Staff Engineer/삼성전자 <sangmin7.seo@samsung.com>
Fri, 8 Jun 2018 05:43:31 +0000 (14:43 +0900)
This commit introduces 'broadcast' method which computes the broadcasted
shape from two tensor shapes.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
include/support/tflite/TensorShapeUtils.h
libs/support/tflite/src/TensorShapeUtils.cpp [new file with mode: 0644]

index 959e232..0c45add 100644 (file)
@@ -41,6 +41,9 @@ std::vector<int32_t> as_dims(const nnfw::util::tensor::Shape &shape)
   return dims;
 }
 
+nnfw::util::tensor::Shape broadcast(const nnfw::util::tensor::Shape &lhs_shape,
+                                    const nnfw::util::tensor::Shape &rhs_shape);
+
 } // namespace tflite
 } // namespace support
 } // namespace nnfw
diff --git a/libs/support/tflite/src/TensorShapeUtils.cpp b/libs/support/tflite/src/TensorShapeUtils.cpp
new file mode 100644 (file)
index 0000000..611ba92
--- /dev/null
@@ -0,0 +1,51 @@
+#include "support/tflite/TensorShapeUtils.h"
+
+namespace nnfw
+{
+namespace support
+{
+namespace tflite
+{
+
+nnfw::util::tensor::Shape broadcast(const nnfw::util::tensor::Shape &lhs_shape,
+                                    const nnfw::util::tensor::Shape &rhs_shape)
+{
+  const uint32_t lhs_rank = lhs_shape.rank();
+  const uint32_t rhs_rank = rhs_shape.rank();
+  const uint32_t out_rank = std::max(lhs_rank, rhs_rank);
+
+  // TODO Simplify implementation
+  std::vector<int32_t> lhs_normalized_dims;
+  std::vector<int32_t> rhs_normalized_dims;
+
+  for (uint32_t n = 0; n < out_rank - lhs_rank; ++n)
+  {
+    lhs_normalized_dims.emplace_back(1);
+  }
+  for (uint32_t axis = 0; axis < lhs_rank; ++axis)
+  {
+    lhs_normalized_dims.emplace_back(lhs_shape.dim(axis));
+  }
+
+  for (uint32_t n = 0; n < out_rank - rhs_rank; ++n)
+  {
+    rhs_normalized_dims.emplace_back(1);
+  }
+  for (uint32_t axis = 0; axis < rhs_rank; ++axis)
+  {
+    rhs_normalized_dims.emplace_back(rhs_shape.dim(axis));
+  }
+
+  nnfw::util::tensor::Shape out_shape(out_rank);
+
+  for (uint32_t axis = 0; axis < out_rank; ++axis)
+  {
+    out_shape.dim(axis) = std::max(lhs_normalized_dims.at(axis), rhs_normalized_dims.at(axis));
+  }
+
+  return out_shape;
+}
+
+} // namespace tflite
+} // namespace support
+} // namespace nnfw