#include <arm_compute/runtime/CL/CLScheduler.h>
#include "util/feature/nchw/View.h"
+#include "backend/acl_cl/Swizzle.h"
#include "util/logging.h"
namespace
assert(axis < 4);
+ // TODO Handle when axis is negative
+ assert(axis >= 0);
+
// This map converts NHWC to NCHW(reversed)
// NHWC -> WHCN
- static const uint32_t axis_map[] = {3, 1, 0, 2};
- _axis = axis_map[axis];
+ _axis = ToARMComputeAxis(output_alloc->num_dimensions(), axis).value();
// TODO Support Quant8
_input_type = OperandType::TENSOR_FLOAT32;