Adding assertions block for SpaceToBatch layer (#3269)
authorShubham Gupta/System SW /SRI-Bangalore/Engineer/삼성전자 <shub98.gupta@samsung.com>
Tue, 23 Oct 2018 13:21:42 +0000 (18:51 +0530)
committer샤란/System SW /SRI-Bangalore/Staff Engineer/삼성전자 <sharan.allur@samsung.com>
Tue, 23 Oct 2018 13:21:42 +0000 (18:51 +0530)
This patch will add assertions for SpaceToBatch layer in visit function to handle exceptions

Signed-off-by: shubham <shub98.gupta@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc

index 334038b..b0dcf8f 100644 (file)
@@ -3834,12 +3834,28 @@ void Planner::visit(const ::internal::tflite::op::SpaceToBatchND::Node &node)
   const ::internal::tflite::operand::Index block_size_index{node.param().block_size_index};
   const ::internal::tflite::operand::Index padding_size_index{node.param().padding_size_index};
 
-  // Currently, only 4D NHWC input/output op_context are supported.
-  // The 4D array need to have exactly 2 spatial dimensions.
-  // TODO: Support arbitrary dimension in SpaceToBatchND.
+  { // New block for assertions
 
-  assert(_ctx.at(input_index).shape().rank() == 4);
-  assert(_ctx.at(output_index).shape().rank() == 4);
+    // Currently, only 4D NHWC input/output op_context are supported.
+    // The 4D array need to have exactly 2 spatial dimensions.
+    // TODO: Support arbitrary dimension in SpaceToBatchND.
+    assert(_ctx.at(input_index).shape().rank() == 4);
+    assert(_ctx.at(output_index).shape().rank() == 4);
+
+    assert(_ctx.at(padding_size_index).hasData() == true);
+    assert(_ctx.at(block_size_index).hasData() == true);
+
+    const int32_t *block_size =
+        reinterpret_cast<const int32_t *>(_ctx.at(block_size_index).data().base());
+    auto block_size_len = _ctx.at(block_size_index).data().size() / sizeof(int32_t);
+
+    const auto &output_shape = _ctx.at(output_index).shape();
+    const auto &input_shape = _ctx.at(input_index).shape();
+
+    assert(block_size_len == 2 && block_size[0] > 0 && block_size[1] > 0);
+    assert(output_shape.dim(3) == input_shape.dim(3));
+    assert(output_shape.dim(0) == input_shape.dim(0) * (block_size[0] * block_size[1]));
+  }
 
   // Set Shape Constraints and TensorInfo
   _builder.addShapeConstr(output_index,