From e334f9e270f516d1ea4836dabca6abbaecdeca56 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Vishal=20Keshav/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 18 Oct 2018 17:14:46 +0530 Subject: [PATCH] Added assertion in Reshape (#3050) Assertion verifies if input can be reshaped to output Signed-off-by: Vishal Keshav --- runtimes/pure_arm_compute/src/compilation.cc | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index e66321d..01f66e6 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -2252,22 +2252,23 @@ void Planner::visit(const ::internal::tflite::op::Reshape::Node &node) const ::internal::tflite::operand::Index output_index{node.param().output_index}; const ::internal::tflite::operand::Index input_index{node.param().input_index}; - // NOTE The content of a tensor specified by shape_index should be aligned with - // output tensor shape - // TODO Check consistency of ouput shape - // TODO Re-enable this assert // assert((ifm_shape.C * ifm_shape.H * ifm_shape.W) == out_size); // TODO Should move to the place where the operand is handled, if it is possible. - _builder.addShapeConstr(output_index, - asTensorInfo(asTensorShape(_ctx.at(output_index).shape()), - _ctx.at(output_index).type(), _ctx.at(output_index).scale(), - _ctx.at(output_index).zeroPoint())); - _builder.addShapeConstr(input_index, - asTensorInfo(asTensorShape(_ctx.at(input_index).shape()), - _ctx.at(input_index).type(), _ctx.at(input_index).scale(), - _ctx.at(input_index).zeroPoint())); + + auto input_shape = asTensorShape(_ctx.at(input_index).shape()); + auto output_shape = asTensorShape(_ctx.at(output_index).shape()); + + assert(input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3] == + output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3]); + + _builder.addShapeConstr(output_index, asTensorInfo(output_shape, _ctx.at(output_index).type(), + _ctx.at(output_index).scale(), + _ctx.at(output_index).zeroPoint())); + _builder.addShapeConstr(input_index, asTensorInfo(input_shape, _ctx.at(input_index).type(), + _ctx.at(input_index).scale(), + _ctx.at(input_index).zeroPoint())); struct Param { -- 2.7.4