From 7f34d63f40c895f096102e320e3935a30570722a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Vladimir=20Plazun/AI=20Tools=20Lab=20/SRR/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 3 Jul 2018 16:42:24 +0400 Subject: [PATCH] [nnc core] Add ShapeInference class (#452) Add ShapeInference class Used to infer shapes of intermediate nodes based only on input shapes Signed-off-by: Vladimir Plazun --- .../nnc/core/IR/model/actions/ShapeInference.h | 44 ++++ .../src/core/IR/model/actions/ShapeInference.cpp | 234 +++++++++++++++++++++ 2 files changed, 278 insertions(+) create mode 100644 contrib/nnc/libs/core/include/nnc/core/IR/model/actions/ShapeInference.h create mode 100644 contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp diff --git a/contrib/nnc/libs/core/include/nnc/core/IR/model/actions/ShapeInference.h b/contrib/nnc/libs/core/include/nnc/core/IR/model/actions/ShapeInference.h new file mode 100644 index 0000000..3ae1474 --- /dev/null +++ b/contrib/nnc/libs/core/include/nnc/core/IR/model/actions/ShapeInference.h @@ -0,0 +1,44 @@ +#ifndef _NNC_CORE_IR_MODEL_SHAPE_INFERENCE_ +#define _NNC_CORE_IR_MODEL_SHAPE_INFERENCE_ + +#include "nnc/core/IR/model/visitor/visitor.h" +#include "nnc/core/IR/model/graph/ir_node.h" + +namespace nncc +{ +namespace contrib +{ +namespace core +{ +namespace IR +{ +namespace model +{ + +using namespace nncc::contrib::core::IR::model; + +class ShapeInference : public Visitor { + public: + void visit(ADT::INode::Ref node, ops::ConcatOp &op) override; + void visit(ADT::INode::Ref node, ops::Conv2DOp &op) override; + void visit(ADT::INode::Ref node, ops::DepthwiseConv2DOp &op) override; + void visit(ADT::INode::Ref node, ops::ReluOp &op) override; + void visit(ADT::INode::Ref node, ops::SoftmaxOp &op) override; + void visit(ADT::INode::Ref node, ops::PoolOp &op) override; + void visit(ADT::INode::Ref node, ops::FullyConnectedOp &op) override; + void visit(ADT::INode::Ref node, ops::CappedReluOp &op) override; + void visit(ADT::INode::Ref node, ops::BiasAddOp &op) override; + void visit(ADT::INode::Ref node, ops::ReshapeOp &op) override; + void visit(ADT::INode::Ref node, ops::VariableOp &op) override; + + protected: + void fillInputShapes(ADT::INode::Ref node, OpDescription &op); +}; + +} // namespace model +} // namespace IR +} // namespace core +} // namespace contrib +} // namespace nncc + +#endif //_NNC_CORE_IR_MODEL_SHAPE_INFERENCE_ diff --git a/contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp b/contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp new file mode 100644 index 0000000..d3751f5 --- /dev/null +++ b/contrib/nnc/libs/core/src/core/IR/model/actions/ShapeInference.cpp @@ -0,0 +1,234 @@ +#include + +#include "nnc/core/IR/model/actions/ShapeInference.h" + +#include "nnc/core/IR/model/operations/fully_connected_op.h" +#include "nnc/core/IR/model/operations/softmax_op.h" +#include "nnc/core/IR/model/operations/capped_relu_op.h" +#include "nnc/core/IR/model/operations/depthwise_conv2d_op.h" +#include "nnc/core/IR/model/operations/conv_2d_op.h" +#include "nnc/core/IR/model/operations/pool_op.h" +#include "nnc/core/IR/model/operations/variable_op.h" +#include "nnc/core/IR/model/operations/relu_op.h" +#include "nnc/core/IR/model/operations/concat_op.h" +#include "nnc/core/IR/model/operations/bias_add_op.h" +#include "nnc/core/IR/model/operations/reshape_op.h" + +namespace nncc +{ +namespace contrib +{ +namespace core +{ +namespace IR +{ +namespace model +{ + +using nncc::core::ADT::tensor::Shape; + +std::vector calculate2DPaddings(ops::PaddingType paddingType, const Shape& inShape, + const Shape& windowShape, const Shape& strides, Shape& outShape) +{ + auto inRank = inShape.rank(); + // Assuming input tensor is 3-dimensional. Will support more general cases as needed. + assert(inRank == 3); + std::vector paddings(3); + + if (paddingType == ops::PaddingType::Same) + { + for (uint32_t d = 0; d < inRank - 1; ++d) + { + outShape.dim(d) = (inShape.dim(d) - 1) / strides.dim(d) + 1; + int pad_along_axis; + if (inShape.dim(d) % strides.dim(d) == 0) + { + pad_along_axis = std::max((int)windowShape.dim(d) - (int)strides.dim(d), 0); + } + else + { + pad_along_axis = std::max((int)(outShape.dim(d) - 1) * (int)strides.dim(d) + + (int)windowShape.dim(d) - (int)inShape.dim(d), + 0); + } + paddings[d] = pad_along_axis / 2; + } + } + else + { + for (uint32_t d = 0; d < inRank - 1; ++d) + { + outShape.dim(d) = (inShape.dim(d) - windowShape.dim(d)) / strides.dim(d) + 1; + paddings[d] = 0; + } + } + + return paddings; +} + +void ShapeInference::visit(ADT::INode::Ref node, ops::ConcatOp &op) +{ + fillInputShapes(node, op); + + uint32_t axis = op.getAxis(); + Shape outShape; + outShape.resize(op.getInputShape(0).rank()); + + for (uint32_t d = 0; d < outShape.rank(); ++d) + { + outShape.dim(d) = op.getInputShape(0).dim(d); + } + outShape.dim(axis) = 0; + + for (uint32_t i = 0; i < op.getNumInputs(); ++i) + { + outShape.dim(axis) += op.getInputShape(i).dim(axis); + } + + op.setOutputShape(0, outShape); +} + +void ShapeInference::visit(ADT::INode::Ref node, ops::Conv2DOp &op) +{ + fillInputShapes(node, op); + + Shape outShape; + outShape.resize(3); + auto &strides = op.getStrides(); + auto &kernel = op.getKernel(); + auto &inShape = op.getInputShape(0); + auto &kernelShape = kernel.getShape(); + uint32_t inRank = inShape.rank(); + + auto pads = calculate2DPaddings(op.getPaddingType(), inShape, kernelShape, strides, outShape); + for (size_t i = 0; i < pads.size(); ++i) + { + op.setPadding(i, pads[i]); + } + + outShape.dim(inRank - 1) = kernelShape.dim(kernelShape.rank() - 1); + op.setOutputShape(0, outShape); +} + +void ShapeInference::visit(ADT::INode::Ref node, ops::VariableOp &op) +{ + (void)op; + (void)node; + // No need to do anything for inputs. These should be set by user +} + +void ShapeInference::fillInputShapes(ADT::INode::Ref node, OpDescription &op) +{ + uint32_t i = 0; + for (auto &in : node->getPrevNodes()) + { + const Shape &inShape = in.node->getOperation()->getOutputShape(in.index); + op.setInputShape(i++, inShape); + } +} + +void ShapeInference::visit(ADT::INode::Ref node, ops::ReluOp &op) +{ + fillInputShapes(node, op); + op.setOutputShape(0, op.getInputShape(0)); +} + +void ShapeInference::visit(ADT::INode::Ref node, ops::SoftmaxOp &op) +{ + fillInputShapes(node, op); + op.setOutputShape(0, op.getInputShape(0)); +} + +void ShapeInference::visit(ADT::INode::Ref node, ops::PoolOp &op) +{ + fillInputShapes(node, op); + + Shape outShape; + outShape.resize(3); + auto &strides = op.getStrides(); + auto &windowShape = op.getWindowShape(); + auto &inShape = op.getInputShape(0); + const uint32_t inRank = inShape.rank(); + + // Assuming input tensor is 3-dimensional. Will support more general cases when needed. + assert(inRank == 3); + + auto pads = calculate2DPaddings(op.getPaddingType(), inShape, windowShape, strides, outShape); + for (uint32_t d = 0; d < inShape.rank(); ++d) + { + op.setPadding(d, pads[d]); + } + outShape.dim(inRank - 1) = inShape.dim(inRank - 1); + op.setOutputShape(0, outShape); +} + +void ShapeInference::visit(ADT::INode::Ref node, ops::FullyConnectedOp &op) +{ + fillInputShapes(node, op); + const Shape &inShape = op.getInputShape(0); + const Shape &wShape = op.getWeights().getShape(); + const uint32_t weightsRank = wShape.rank(); + const uint32_t inRank = inShape.rank(); + + assert(weightsRank >= 2); + assert(inRank == weightsRank); + assert(inShape.dim(inRank - 1) == wShape.dim(weightsRank - 2)); + for (uint32_t i = 0; i < weightsRank - 2; ++i) + { + assert(wShape.dim(i) == inShape.dim(i)); + } + + Shape outShape = wShape; + outShape.dim(weightsRank - 1) = inShape.dim(weightsRank - 1); + outShape.dim(weightsRank - 2) = wShape.dim(weightsRank - 2); + op.setOutputShape(0, outShape); +} + +void ShapeInference::visit(ADT::INode::Ref node, ops::CappedReluOp &op) +{ + fillInputShapes(node, op); + op.setOutputShape(0, op.getInputShape(0)); +} + +void ShapeInference::visit(ADT::INode::Ref node, ops::DepthwiseConv2DOp &op) +{ + fillInputShapes(node, op); + + Shape outShape; + outShape.resize(3); + auto &strides = op.getStrides(); + auto &kernelShape = op.getKernel().getShape(); + auto &inShape = op.getInputShape(0); + int inRank = inShape.rank(); + + // Assuming input tensor is 3-dimensional. Will support more general cases when needed. + assert(inRank == 3); + assert(inShape.dim(2) == kernelShape.dim(2)); + + auto pads = calculate2DPaddings(op.getPaddingType(), inShape, kernelShape, strides, outShape); + for (uint32_t d = 0; d < inShape.rank(); ++d) + { + op.setPadding(d, pads[d]); + } + + outShape.dim(inRank - 1) = inShape.dim(inRank - 1); + op.setOutputShape(0, outShape); +} + +void ShapeInference::visit(ADT::INode::Ref node, ops::BiasAddOp &op) +{ + fillInputShapes(node, op); + op.setOutputShape(0, op.getInputShape(0)); +} + +void ShapeInference::visit(ADT::INode::Ref node, ops::ReshapeOp &op) +{ + // Reshape should have it's output shape filled by importer/user + fillInputShapes(node, op); +} + +} // namespace model +} // namespace IR +} // namespace core +} // namespace contrib +} // namespace nncc -- 2.7.4