From 49d39a5b1cb8c020c7c256cbef7bb3be00f75544 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Shubham=20Gupta/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 23 Oct 2018 18:51:42 +0530 Subject: [PATCH] Adding assertions block for SpaceToBatch layer (#3269) This patch will add assertions for SpaceToBatch layer in visit function to handle exceptions Signed-off-by: shubham --- runtimes/pure_arm_compute/src/compilation.cc | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 334038b..b0dcf8f 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -3834,12 +3834,28 @@ void Planner::visit(const ::internal::tflite::op::SpaceToBatchND::Node &node) const ::internal::tflite::operand::Index block_size_index{node.param().block_size_index}; const ::internal::tflite::operand::Index padding_size_index{node.param().padding_size_index}; - // Currently, only 4D NHWC input/output op_context are supported. - // The 4D array need to have exactly 2 spatial dimensions. - // TODO: Support arbitrary dimension in SpaceToBatchND. + { // New block for assertions - assert(_ctx.at(input_index).shape().rank() == 4); - assert(_ctx.at(output_index).shape().rank() == 4); + // Currently, only 4D NHWC input/output op_context are supported. + // The 4D array need to have exactly 2 spatial dimensions. + // TODO: Support arbitrary dimension in SpaceToBatchND. + assert(_ctx.at(input_index).shape().rank() == 4); + assert(_ctx.at(output_index).shape().rank() == 4); + + assert(_ctx.at(padding_size_index).hasData() == true); + assert(_ctx.at(block_size_index).hasData() == true); + + const int32_t *block_size = + reinterpret_cast(_ctx.at(block_size_index).data().base()); + auto block_size_len = _ctx.at(block_size_index).data().size() / sizeof(int32_t); + + const auto &output_shape = _ctx.at(output_index).shape(); + const auto &input_shape = _ctx.at(input_index).shape(); + + assert(block_size_len == 2 && block_size[0] > 0 && block_size[1] > 0); + assert(output_shape.dim(3) == input_shape.dim(3)); + assert(output_shape.dim(0) == input_shape.dim(0) * (block_size[0] * block_size[1])); + } // Set Shape Constraints and TensorInfo _builder.addShapeConstr(output_index, -- 2.7.4