Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / include / util / Utils.h
index 505f5a9..6b6bc24 100644 (file)
 
 #define UNUSED_RELEASE(a) (void)(a)
 
-template <size_t from, size_t to, typename Enable = void> struct ForEachDimension
+template <size_t rest> struct ForEachDimension
 {
   template <typename L>
   static void unroll(const onert::ir::Shape &shape, onert::ir::Coordinates &coords,
                      L lambda_function)
   {
-    static_assert(from < to, "from must not be less than to");
-    assert(static_cast<int>(to) <= shape.rank());
-    const auto &d = shape.dim(from);
+    if (static_cast<int>(rest) > shape.rank())
+    {
+      ForEachDimension<rest - 1>::unroll(shape, coords, lambda_function);
+      return;
+    }
+
+    const auto axis = shape.rank() - rest;
+    const auto &d = shape.dim(axis);
 
     for (auto v = 0; v < d; v++)
     {
-      coords.set(from, v);
-      ForEachDimension<from + 1, to>::unroll(shape, coords, lambda_function);
+      coords.set(axis, v);
+      ForEachDimension<rest - 1>::unroll(shape, coords, lambda_function);
     }
   }
 };
 
-template <size_t from, size_t to>
-struct ForEachDimension<from, to, typename std::enable_if<from == to>::type>
+template <> struct ForEachDimension<0>
 {
   template <typename L>
   static void unroll(const onert::ir::Shape &shape, onert::ir::Coordinates &coords,
                      L lambda_function)
   {
     UNUSED_RELEASE(shape);
-    assert(static_cast<int>(to) <= shape.rank());
     lambda_function(coords);
   }
 };
 
 template <typename L> inline void ShapeLoop(const onert::ir::Shape &shape, L lambda_function)
 {
-  assert(shape.rank() > 0);
-  for (auto i = 0; i < shape.rank(); ++i)
+  int32_t rank = shape.rank();
+  assert(rank > 0);
+  for (int32_t i = 0; i < rank; ++i)
   {
     assert(shape.dim(i) > 0);
   }
 
   onert::ir::Coordinates coords;
-  switch (shape.rank())
+  if (rank == 0)
   {
-    case 0:
-      coords.set(0, 0);
-      ForEachDimension<0, 0>::unroll(shape, coords, lambda_function);
-      break;
-    case 1:
-      ForEachDimension<0, 1>::unroll(shape, coords, lambda_function);
-      break;
-    case 2:
-      ForEachDimension<0, 2>::unroll(shape, coords, lambda_function);
-      break;
-    case 3:
-      ForEachDimension<0, 3>::unroll(shape, coords, lambda_function);
-      break;
-    case 4:
-      ForEachDimension<0, 4>::unroll(shape, coords, lambda_function);
-      break;
-    case 5:
-      ForEachDimension<0, 5>::unroll(shape, coords, lambda_function);
-      break;
-    case 6:
-      ForEachDimension<0, 6>::unroll(shape, coords, lambda_function);
-      break;
-    default:
-      assert(false && "ShapeLoop, 1 <= Shape'rank <= 6");
-      break;
+    coords.set(0, 0);
   }
+  // TODO Change 6 to onert::ir::Shape::kMaxRank if onert::ir::Shape::kMaxRank is modified as a
+  // constant expression
+  ForEachDimension<6>::unroll(shape, coords, lambda_function);
 }
 #endif // __ONERT_UTIL_UTILS_H__