#include <arm_compute/runtime/CL/functions/CLBatchToSpaceLayer.h>
#include <arm_compute/runtime/CL/functions/CLCopy.h>
#include <arm_compute/runtime/CL/functions/CLEmbeddingLookup.h>
#include <arm_compute/runtime/CL/functions/CLExp.h>
#include <arm_compute/runtime/CL/functions/CLNeg.h>
#include <arm_compute/runtime/CL/functions/CLNormalizationLayerEx.h>
#include <arm_compute/runtime/CL/functions/CLSquaredDifference.h>
// NOTE(review): this span still contains unresolved diff markers ('-'/'+' prefixed
// lines) and elided patch context; the file cannot compile until the hunk is
// applied and cleaned up.
// Rank checks: BatchToSpaceND here operates on 4-D input/output tensors.
assert(_ctx.at(input_index).shape().rank() == 4);
assert(_ctx.at(output_index).shape().rank() == 4);
// Added by the patch: block_size must be a 1-D tensor operand with constant data.
+ assert(_ctx.at(block_size_index).shape().rank() == 1);
assert(_ctx.at(block_size_index).hasData() == true);
// NOTE(review): truncated line — the right-hand side of this declaration was lost
// from the hunk (presumably a cast of the block_size operand's constant data base
// pointer, as reconstructed at the later `block_size_base` line — confirm).
const int32_t *block_size =
const auto &output_shape = _ctx.at(output_index).shape();
const auto &input_shape = _ctx.at(input_index).shape();
// Removed by the patch: the old eager shape validation against the raw
// `block_size` pointer (element count == 2, positive factors, and the
// output-vs-input dimension relations). The same checks reappear below,
// rewritten in terms of `block_size_base`.
-
- assert((_ctx.at(block_size_index).data().size() / sizeof(int32_t)) == 2 && block_size[0] > 0 &&
- block_size[1] > 0);
- {
- assert(output_shape.dim(3) == input_shape.dim(3));
- assert(output_shape.dim(1) == input_shape.dim(1) * block_size[0]);
- assert(output_shape.dim(2) == input_shape.dim(2) * block_size[1]);
- assert(output_shape.dim(0) == input_shape.dim(0) / (block_size[0] * block_size[1]));
- }
// Added by the patch: block_size is now treated as a real tensor operand, so its
// shape is needed to register a shape constraint for it below.
+ const auto &block_size_shape = _ctx.at(block_size_index).shape();
// Set Shape Constraints and TensorInfo
// Register the input tensor's shape/type/quantization with the allocation builder.
_builder.addShapeConstr(
input_index, asTensorInfo(asTensorShape(input_shape, false), _ctx.at(input_index).type(),
_ctx.at(input_index).scale(), _ctx.at(input_index).zeroPoint()));
// Added by the patch: the block_size operand now gets its own backing CL tensor,
// since CLBatchToSpaceLayer consumes it as an ICLTensor rather than a host pointer.
+ _builder.addShapeConstr(block_size_index, asTensorInfo(asTensorShape(block_size_shape),
+ _ctx.at(block_size_index).type(),
+ _ctx.at(block_size_index).scale(),
+ _ctx.at(block_size_index).zeroPoint()));
+
+ // initializer for block_size
+ {
+ const auto block_size_base =
+ reinterpret_cast<const int32_t *>(_ctx.at(block_size_index).data().base());
+
+ // Same shape relations the removed hunk asserted, now via block_size_base:
+ // channels preserved, H/W multiplied by the block factors, batch divided by
+ // their product.
+ assert(output_shape.dim(3) == input_shape.dim(3));
+ assert(output_shape.dim(1) == input_shape.dim(1) * block_size_base[0]);
+ assert(output_shape.dim(2) == input_shape.dim(2) * block_size_base[1]);
+ assert(output_shape.dim(0) == input_shape.dim(0) / (block_size_base[0] * block_size_base[1]));
+ // block_size must be an int32 tensor holding exactly two positive factors.
+ assert(_ctx.at(block_size_index).type() == ANEURALNETWORKS_TENSOR_INT32);
+
+ assert((_ctx.at(block_size_index).data().size() / sizeof(int32_t)) == 2 &&
+ block_size_base[0] > 0 && block_size_base[1] > 0);
+
+ // Copy the constant block factors into the CL tensor, swapping element order
+ // (into[0] = from[1]) — presumably converting the model's (H, W) ordering to
+ // the layout ACL's CLBatchToSpaceLayer expects; TODO(review) confirm against
+ // the ACL kernel's documentation.
+ auto initializer = [block_size_base](::arm_compute::ITensor &tensor) {
+ const int32_t *from = reinterpret_cast<const int32_t *>(block_size_base);
+ int32_t *into = reinterpret_cast<int32_t *>(tensor.ptr_to_element({0}));
+ into[0] = from[1];
+ into[1] = from[0];
+ };
+ _builder.addInitializer(block_size_index, initializer);
+ }
+
// Construct operation parameters
// Captured by value into the stage lambda below; holds operand indices (as plain
// ints) so allocations can be looked up at execution-build time.
struct Param
{
int output_index;
int input_index;
// Added by the patch: index of the block_size tensor operand, needed to fetch
// its CL allocation inside the stage.
+ int block_size_index;
const int32_t *block_size;
int32_t rank;
};
// NOTE(review): the declaration of `param` (e.g. `Param param;`) is elided patch
// context between the struct and these assignments.
param.output_index = output_index.asInt();
param.input_index = input_index.asInt();
+ param.block_size_index = block_size_index.asInt();
param.block_size = block_size;
param.rank = _ctx.at(input_index).shape().rank();
// Stage: resolves operand allocations and appends the BatchToSpaceND function to
// the execution builder. NOTE(review): this hunk has elided context (the body of
// the USE_SIMPLE_BATCH_TO_SPACE_ND branch is missing between the `if` and the
// closing `}`) and still contains unresolved '-'/'+' diff lines.
auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index});
auto input_alloc = ctx.at(::internal::tflite::operand::Index{param.input_index});
// Added by the patch: look up the block_size tensor's allocation so it can be
// passed to CLBatchToSpaceLayer::configure as a tensor argument.
+ auto block_size_alloc = ctx.at(::internal::tflite::operand::Index{param.block_size_index});
// NOTE SimpleBatchToSpaceND is quite slow, but may be useful for debugging
if (from_env<bool>(std::getenv("USE_SIMPLE_BATCH_TO_SPACE_ND")))
}
else if (::internal::arm_compute::isGpuMode())
{
// Patch swaps the deprecated CLBatchToSpaceND (host-pointer block size) for
// ACL's CLBatchToSpaceLayer, whose configure signature is
// (input, block_shape_tensor, output) — note the argument-order change.
- auto fn = nnfw::cpp14::make_unique<::arm_compute::CLBatchToSpaceND>();
+ auto fn = nnfw::cpp14::make_unique<::arm_compute::CLBatchToSpaceLayer>();
- fn->configure(CAST_CL(input_alloc), CAST_CL(output_alloc), param.block_size);
+ fn->configure(CAST_CL(input_alloc), CAST_CL(block_size_alloc), CAST_CL(output_alloc));
builder.append("BatchToSpaceND", std::move(fn));
}
else
// TODO Enable NEON Support
throw std::runtime_error("Not supported, yet");
}
-
};
_builder.addStage(stage);