#define UNUSED_RELEASE(a) (void)(a)
-template <size_t from, size_t to, typename Enable = void> struct ForEachDimension
+template <size_t rest> struct ForEachDimension
{
template <typename L>
static void unroll(const onert::ir::Shape &shape, onert::ir::Coordinates &coords,
L lambda_function)
{
- static_assert(from < to, "from must not be less than to");
- assert(static_cast<int>(to) <= shape.rank());
- const auto &d = shape.dim(from);
+ if (static_cast<int>(rest) > shape.rank())
+ {
+ ForEachDimension<rest - 1>::unroll(shape, coords, lambda_function);
+ return;
+ }
+
+ const auto axis = shape.rank() - rest;
+ const auto &d = shape.dim(axis);
for (auto v = 0; v < d; v++)
{
- coords.set(from, v);
- ForEachDimension<from + 1, to>::unroll(shape, coords, lambda_function);
+ coords.set(axis, v);
+ ForEachDimension<rest - 1>::unroll(shape, coords, lambda_function);
}
}
};
-template <size_t from, size_t to>
-struct ForEachDimension<from, to, typename std::enable_if<from == to>::type>
+template <> struct ForEachDimension<0>
{
template <typename L>
static void unroll(const onert::ir::Shape &shape, onert::ir::Coordinates &coords,
L lambda_function)
{
UNUSED_RELEASE(shape);
- assert(static_cast<int>(to) <= shape.rank());
lambda_function(coords);
}
};
template <typename L> inline void ShapeLoop(const onert::ir::Shape &shape, L lambda_function)
{
- assert(shape.rank() > 0);
- for (auto i = 0; i < shape.rank(); ++i)
+ int32_t rank = shape.rank();
+ assert(rank > 0);
+ for (int32_t i = 0; i < rank; ++i)
{
assert(shape.dim(i) > 0);
}
onert::ir::Coordinates coords;
- switch (shape.rank())
+ if (rank == 0)
{
- case 0:
- coords.set(0, 0);
- ForEachDimension<0, 0>::unroll(shape, coords, lambda_function);
- break;
- case 1:
- ForEachDimension<0, 1>::unroll(shape, coords, lambda_function);
- break;
- case 2:
- ForEachDimension<0, 2>::unroll(shape, coords, lambda_function);
- break;
- case 3:
- ForEachDimension<0, 3>::unroll(shape, coords, lambda_function);
- break;
- case 4:
- ForEachDimension<0, 4>::unroll(shape, coords, lambda_function);
- break;
- case 5:
- ForEachDimension<0, 5>::unroll(shape, coords, lambda_function);
- break;
- case 6:
- ForEachDimension<0, 6>::unroll(shape, coords, lambda_function);
- break;
- default:
- assert(false && "ShapeLoop, 1 <= Shape'rank <= 6");
- break;
+ coords.set(0, 0);
}
+ // TODO Change 6 to onert::ir::Shape::kMaxRank if onert::ir::Shape::kMaxRank is modified as a
+ // constant expression
+ ForEachDimension<6>::unroll(shape, coords, lambda_function);
}
#endif // __ONERT_UTIL_UTILS_H__