From 55ebc861bafde178d5df087995cb74163e071d81 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=82=A8=EA=B6=81=EC=84=9D/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 3 Sep 2019 07:20:14 +0900 Subject: [PATCH] [moco-tf] Bugfix of FixShapeTransform of Squeeze (#7031) If `Squeeze` operation have `squeeze_dims` as `[-1,-2]` and input rank is 4, it should be converted to `[2, 3]` According to current implementation, the code erase -1 first and then put 3 in the `squeeze_dims` However, it causes some bug. If -1 is erased, iterator of `squeeze_dims` points -2. And then if 3 is inserted, the iterator points 3 because it is `Set`! Now, `squeeze_dims` is `[-2, 3]` and iterator try to iterate next element but there is no more element because 3 is the last element. Therefore this commit will fix this bug by introducing new `Set` Signed-off-by: Seok NamKoong --- compiler/moco-tf/src/Transforms/FixShapeTransform.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp index 817d9c8..bbf8c9a 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -1634,13 +1634,13 @@ bool fix_shape(moco::tf::TFSqueeze *node) } // Resolve negative squeeze dimension + std::set resolved_squeeze_dims; for (auto squeeze_dim : squeeze_dims) { if (squeeze_dim < 0) - { - squeeze_dims.erase(squeeze_dim); - squeeze_dims.insert(squeeze_dim + (int64_t)input_rank); - } + resolved_squeeze_dims.insert(squeeze_dim + (int64_t)input_rank); + else + resolved_squeeze_dims.insert(squeeze_dim); } // Remove squeeze dimensions only @@ -1648,7 +1648,7 @@ bool fix_shape(moco::tf::TFSqueeze *node) { assert(input_tensor_shape.dim(axis).known()); auto dim = input_tensor_shape.dim(axis).value(); - if (squeeze_dims.find((int64_t)axis) == squeeze_dims.cend()) + if (resolved_squeeze_dims.find((int64_t)axis) == resolved_squeeze_dims.cend()) { // Not squeeze dim node_shape.rank(++node_rank); -- 2.7.4