Support automatic dimension size in Reshape operation (#790)
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@partner.samsung.com>
Wed, 25 Jul 2018 12:07:34 +0000 (15:07 +0300)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Wed, 25 Jul 2018 12:07:34 +0000 (21:07 +0900)
Update ShapeInference to support auto dimension( -1 ) in reshape output shape

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp
contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.test.cpp [new file with mode: 0644]

index d3751f5..2b97da4 100644 (file)
@@ -1,4 +1,5 @@
 #include <cmath>
+#include <limits>
 
 #include "nnc/core/IR/model/actions/ShapeInference.h"
 
@@ -225,6 +226,32 @@ void ShapeInference::visit(ADT::INode::Ref node, ops::ReshapeOp &op)
 {
   // Reshape should have it's output shape filled by importer/user
   fillInputShapes(node, op);
+  auto& inShape = op.getInputShape(0);
+  auto outShape = op.getOutputShape(0);
+
+  //-1 in shape underflows to UINT32_MAX
+  const auto autoDimension  = std::numeric_limits<uint32_t>::max();
+
+  assert(inShape.rank() == outShape.rank());
+
+  auto inElementsNum = num_elements(inShape);
+  uint32_t outElementsNum = 1;
+  //can't use num_elements due to -1 in input shape and Shape using unsigned ints for dimensions
+  for( uint32_t d = 0; d < outShape.rank(); ++d ) {
+    auto dim = outShape.dim(d);
+    if( dim != autoDimension ) {
+      outElementsNum *= dim;
+    }
+  }
+
+  for( uint32_t d = 0; d < outShape.rank(); ++d ) {
+    auto& dim = outShape.dim(d);
+    if( dim ==  autoDimension ) {
+      dim = static_cast<uint32_t>(inElementsNum / outElementsNum);
+    }
+  }
+
+  op.setOutputShape(0, outShape);
 }
 
 } // namespace model
diff --git a/contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.test.cpp b/contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.test.cpp
new file mode 100644 (file)
index 0000000..1ba2251
--- /dev/null
@@ -0,0 +1,28 @@
+#include "nnc/core/IR/model/graph/graph.h"
+#include "nnc/core/IR/model/actions/ShapeInference.h"
+#include "nnc/core/IR/model/operations/reshape_op.h"
+#include "nncc/core/ADT/tensor/Shape.h"
+
+#include "gtest/gtest.h"
+
+using namespace nncc::contrib::core::IR::model;
+using nncc::core::ADT::tensor::Shape;
+
+TEST(ShapeInferenceTest, ReshapeAutoDimension) {
+  Graph g;
+  ShapeInference si;
+
+  Shape resultShape{ 10, 1, 10 };
+
+  auto input = g.create<ops::VariableOp>("input");
+  input->getOperation()->setOutputShape(0, Shape{ 10, 2, 5} );
+
+  auto n = g.create<ops::ReshapeOp>("reshape");
+  n->getOperation()->setInputShape( 0, Shape{10, 2, 5} );
+  n->getOperation()->setOutputShape(0, Shape{10, 1, static_cast<uint32_t >(-1)} );
+  n->connectInputTo(0, input->getOutput(0));
+
+  si.visit(n, *static_cast<ops::ReshapeOp*>(n->getOperation()));
+
+  ASSERT_EQ(resultShape, n->getOperation()->getOutputShape(0));
+}