[nnc core] Add ShapeInference class (#452)
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@partner.samsung.com>
Tue, 3 Jul 2018 12:42:24 +0000 (16:42 +0400)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Tue, 3 Jul 2018 12:42:24 +0000 (21:42 +0900)
Add ShapeInference class

It is used to infer the shapes of intermediate graph nodes from the shapes of the graph's input nodes alone.

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/libs/core/include/nnc/core/IR/model/actions/ShapeInference.h [new file with mode: 0644]
contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp [new file with mode: 0644]

diff --git a/contrib/nnc/libs/core/include/nnc/core/IR/model/actions/ShapeInference.h b/contrib/nnc/libs/core/include/nnc/core/IR/model/actions/ShapeInference.h
new file mode 100644 (file)
index 0000000..3ae1474
--- /dev/null
@@ -0,0 +1,44 @@
+#ifndef _NNC_CORE_IR_MODEL_SHAPE_INFERENCE_
+#define _NNC_CORE_IR_MODEL_SHAPE_INFERENCE_
+
+#include "nnc/core/IR/model/visitor/visitor.h"
+#include "nnc/core/IR/model/graph/ir_node.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace IR
+{
+namespace model
+{
+
+using namespace nncc::contrib::core::IR::model;
+
+class ShapeInference : public Visitor {
+ public:
+  void visit(ADT::INode::Ref node, ops::ConcatOp &op) override;
+  void visit(ADT::INode::Ref node, ops::Conv2DOp &op) override;
+  void visit(ADT::INode::Ref node, ops::DepthwiseConv2DOp &op) override;
+  void visit(ADT::INode::Ref node, ops::ReluOp &op) override;
+  void visit(ADT::INode::Ref node, ops::SoftmaxOp &op) override;
+  void visit(ADT::INode::Ref node, ops::PoolOp &op) override;
+  void visit(ADT::INode::Ref node, ops::FullyConnectedOp &op) override;
+  void visit(ADT::INode::Ref node, ops::CappedReluOp &op) override;
+  void visit(ADT::INode::Ref node, ops::BiasAddOp &op) override;
+  void visit(ADT::INode::Ref node, ops::ReshapeOp &op) override;
+  void visit(ADT::INode::Ref node, ops::VariableOp &op) override;
+
+ protected:
+  void fillInputShapes(ADT::INode::Ref node, OpDescription &op);
+};
+
+} // namespace model
+} // namespace IR
+} // namespace core
+} // namespace contrib
+} // namespace nncc
+
+#endif //_NNC_CORE_IR_MODEL_SHAPE_INFERENCE_
diff --git a/contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp b/contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp
new file mode 100644 (file)
index 0000000..d3751f5
--- /dev/null
@@ -0,0 +1,234 @@
+#include "nnc/core/IR/model/actions/ShapeInference.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+
+#include "nnc/core/IR/model/operations/bias_add_op.h"
+#include "nnc/core/IR/model/operations/capped_relu_op.h"
+#include "nnc/core/IR/model/operations/concat_op.h"
+#include "nnc/core/IR/model/operations/conv_2d_op.h"
+#include "nnc/core/IR/model/operations/depthwise_conv2d_op.h"
+#include "nnc/core/IR/model/operations/fully_connected_op.h"
+#include "nnc/core/IR/model/operations/pool_op.h"
+#include "nnc/core/IR/model/operations/relu_op.h"
+#include "nnc/core/IR/model/operations/reshape_op.h"
+#include "nnc/core/IR/model/operations/softmax_op.h"
+#include "nnc/core/IR/model/operations/variable_op.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace IR
+{
+namespace model
+{
+
+using nncc::core::ADT::tensor::Shape;
+
+std::vector<int> calculate2DPaddings(ops::PaddingType paddingType, const Shape& inShape,
+    const Shape& windowShape, const Shape& strides, Shape& outShape)
+{
+  auto inRank = inShape.rank();
+  // Assuming input tensor is 3-dimensional. Will support more general cases as needed.
+  assert(inRank == 3);
+  std::vector<int> paddings(3);
+
+  if (paddingType == ops::PaddingType::Same)
+  {
+    for (uint32_t d = 0; d < inRank - 1; ++d)
+    {
+      outShape.dim(d) = (inShape.dim(d) - 1) / strides.dim(d) + 1;
+      int pad_along_axis;
+      if (inShape.dim(d) % strides.dim(d) == 0)
+      {
+        pad_along_axis = std::max((int)windowShape.dim(d) - (int)strides.dim(d), 0);
+      }
+      else
+      {
+        pad_along_axis = std::max((int)(outShape.dim(d) - 1) * (int)strides.dim(d) +
+                                  (int)windowShape.dim(d) - (int)inShape.dim(d),
+                                  0);
+      }
+      paddings[d] = pad_along_axis / 2;
+    }
+  }
+  else
+  {
+    for (uint32_t d = 0; d < inRank - 1; ++d)
+    {
+      outShape.dim(d) = (inShape.dim(d) - windowShape.dim(d)) / strides.dim(d) + 1;
+      paddings[d] = 0;
+    }
+  }
+
+  return paddings;
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::ConcatOp &op)
+{
+  fillInputShapes(node, op);
+
+  const uint32_t axis = op.getAxis();
+  const Shape &firstShape = op.getInputShape(0);
+
+  // All dimensions except the concat axis are taken from the first input.
+  Shape outShape;
+  outShape.resize(firstShape.rank());
+  for (uint32_t d = 0; d < outShape.rank(); ++d)
+    outShape.dim(d) = firstShape.dim(d);
+
+  // The concat axis is the sum of the inputs' extents along that axis.
+  outShape.dim(axis) = 0;
+  const uint32_t numInputs = op.getNumInputs();
+  for (uint32_t i = 0; i < numInputs; ++i)
+    outShape.dim(axis) += op.getInputShape(i).dim(axis);
+
+  op.setOutputShape(0, outShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::Conv2DOp &op)
+{
+  fillInputShapes(node, op);
+
+  Shape outShape;
+  outShape.resize(3);
+  auto &strides = op.getStrides();
+  auto &kernel = op.getKernel();
+  auto &inShape = op.getInputShape(0);
+  auto &kernelShape = kernel.getShape();
+  uint32_t inRank = inShape.rank();
+
+  auto pads = calculate2DPaddings(op.getPaddingType(), inShape, kernelShape, strides, outShape);
+  for (size_t i = 0; i < pads.size(); ++i)
+  {
+    op.setPadding(i, pads[i]);
+  }
+
+  outShape.dim(inRank - 1) = kernelShape.dim(kernelShape.rank() - 1);
+  op.setOutputShape(0, outShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::VariableOp &op)
+{
+  // Variable (input) nodes have their shapes set by the user/importer,
+  // so there is nothing to infer here.
+  (void)node;
+  (void)op;
+}
+
+void ShapeInference::fillInputShapes(ADT::INode::Ref node, OpDescription &op)
+{
+  // Each incoming edge names a producer node and the index of its output;
+  // copy that producer's already-inferred output shape into the matching
+  // input slot of this op.
+  uint32_t inputIdx = 0;
+  for (const auto &prev : node->getPrevNodes())
+  {
+    op.setInputShape(inputIdx, prev.node->getOperation()->getOutputShape(prev.index));
+    ++inputIdx;
+  }
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::ReluOp &op)
+{
+  fillInputShapes(node, op);
+  // Element-wise op: the output shape mirrors the input shape.
+  const Shape &inputShape = op.getInputShape(0);
+  op.setOutputShape(0, inputShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::SoftmaxOp &op)
+{
+  // Softmax does not change the shape of its (single) input.
+  fillInputShapes(node, op);
+  const Shape &inputShape = op.getInputShape(0);
+  op.setOutputShape(0, inputShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::PoolOp &op)
+{
+  fillInputShapes(node, op);
+
+  Shape outShape;
+  outShape.resize(3);
+  auto &strides = op.getStrides();
+  auto &windowShape = op.getWindowShape();
+  auto &inShape = op.getInputShape(0);
+  const uint32_t inRank = inShape.rank();
+
+  // Assuming input tensor is 3-dimensional. Will support more general cases when needed.
+  assert(inRank == 3);
+
+  auto pads = calculate2DPaddings(op.getPaddingType(), inShape, windowShape, strides, outShape);
+  for (uint32_t d = 0; d < inShape.rank(); ++d)
+  {
+    op.setPadding(d, pads[d]);
+  }
+  outShape.dim(inRank - 1) = inShape.dim(inRank - 1);
+  op.setOutputShape(0, outShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::FullyConnectedOp &op)
+{
+  fillInputShapes(node, op);
+  const Shape &inShape = op.getInputShape(0);
+  const Shape &wShape = op.getWeights().getShape();
+  const uint32_t weightsRank = wShape.rank();
+  const uint32_t inRank = inShape.rank();
+
+  assert(weightsRank >= 2);
+  assert(inRank == weightsRank);
+  assert(inShape.dim(inRank - 1) == wShape.dim(weightsRank - 2));
+  for (uint32_t i = 0; i < weightsRank - 2; ++i)
+  {
+    assert(wShape.dim(i) == inShape.dim(i));
+  }
+
+  Shape outShape = wShape;
+  outShape.dim(weightsRank - 1) = inShape.dim(weightsRank - 1);
+  outShape.dim(weightsRank - 2) = wShape.dim(weightsRank - 2);
+  op.setOutputShape(0, outShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::CappedReluOp &op)
+{
+  // Element-wise activation: shape passes through unchanged.
+  fillInputShapes(node, op);
+  op.setOutputShape(0, op.getInputShape(0));
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::DepthwiseConv2DOp &op)
+{
+  fillInputShapes(node, op);
+
+  Shape outShape;
+  outShape.resize(3);
+  auto &strides = op.getStrides();
+  auto &kernelShape = op.getKernel().getShape();
+  auto &inShape = op.getInputShape(0);
+  int inRank = inShape.rank();
+
+  // Assuming input tensor is 3-dimensional. Will support more general cases when needed.
+  assert(inRank == 3);
+  assert(inShape.dim(2) == kernelShape.dim(2));
+
+  auto pads = calculate2DPaddings(op.getPaddingType(), inShape, kernelShape, strides, outShape);
+  for (uint32_t d = 0; d < inShape.rank(); ++d)
+  {
+    op.setPadding(d, pads[d]);
+  }
+
+  outShape.dim(inRank - 1) = inShape.dim(inRank - 1);
+  op.setOutputShape(0, outShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::BiasAddOp &op)
+{
+  // Bias addition is element-wise along the last axis; shape is unchanged.
+  fillInputShapes(node, op);
+  op.setOutputShape(0, op.getInputShape(0));
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::ReshapeOp &op)
+{
+  // A reshape's output shape is supplied by the importer/user, so only the
+  // input shapes need to be propagated here.
+  fillInputShapes(node, op);
+}
+
+} // namespace model
+} // namespace IR
+} // namespace core
+} // namespace contrib
+} // namespace nncc