--- /dev/null
+#include <cmath>
+
+#include "nnc/core/IR/model/actions/ShapeInference.h"
+
+#include "nnc/core/IR/model/operations/fully_connected_op.h"
+#include "nnc/core/IR/model/operations/softmax_op.h"
+#include "nnc/core/IR/model/operations/capped_relu_op.h"
+#include "nnc/core/IR/model/operations/depthwise_conv2d_op.h"
+#include "nnc/core/IR/model/operations/conv_2d_op.h"
+#include "nnc/core/IR/model/operations/pool_op.h"
+#include "nnc/core/IR/model/operations/variable_op.h"
+#include "nnc/core/IR/model/operations/relu_op.h"
+#include "nnc/core/IR/model/operations/concat_op.h"
+#include "nnc/core/IR/model/operations/bias_add_op.h"
+#include "nnc/core/IR/model/operations/reshape_op.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace IR
+{
+namespace model
+{
+
+using nncc::core::ADT::tensor::Shape;
+
+std::vector<int> calculate2DPaddings(ops::PaddingType paddingType, const Shape& inShape,
+ const Shape& windowShape, const Shape& strides, Shape& outShape)
+{
+ auto inRank = inShape.rank();
+ // Assuming input tensor is 3-dimensional. Will support more general cases as needed.
+ assert(inRank == 3);
+ std::vector<int> paddings(3);
+
+ if (paddingType == ops::PaddingType::Same)
+ {
+ for (uint32_t d = 0; d < inRank - 1; ++d)
+ {
+ outShape.dim(d) = (inShape.dim(d) - 1) / strides.dim(d) + 1;
+ int pad_along_axis;
+ if (inShape.dim(d) % strides.dim(d) == 0)
+ {
+ pad_along_axis = std::max((int)windowShape.dim(d) - (int)strides.dim(d), 0);
+ }
+ else
+ {
+ pad_along_axis = std::max((int)(outShape.dim(d) - 1) * (int)strides.dim(d) +
+ (int)windowShape.dim(d) - (int)inShape.dim(d),
+ 0);
+ }
+ paddings[d] = pad_along_axis / 2;
+ }
+ }
+ else
+ {
+ for (uint32_t d = 0; d < inRank - 1; ++d)
+ {
+ outShape.dim(d) = (inShape.dim(d) - windowShape.dim(d)) / strides.dim(d) + 1;
+ paddings[d] = 0;
+ }
+ }
+
+ return paddings;
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::ConcatOp &op)
+{
+ fillInputShapes(node, op);
+
+ uint32_t axis = op.getAxis();
+ Shape outShape;
+ outShape.resize(op.getInputShape(0).rank());
+
+ for (uint32_t d = 0; d < outShape.rank(); ++d)
+ {
+ outShape.dim(d) = op.getInputShape(0).dim(d);
+ }
+ outShape.dim(axis) = 0;
+
+ for (uint32_t i = 0; i < op.getNumInputs(); ++i)
+ {
+ outShape.dim(axis) += op.getInputShape(i).dim(axis);
+ }
+
+ op.setOutputShape(0, outShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::Conv2DOp &op)
+{
+ fillInputShapes(node, op);
+
+ Shape outShape;
+ outShape.resize(3);
+ auto &strides = op.getStrides();
+ auto &kernel = op.getKernel();
+ auto &inShape = op.getInputShape(0);
+ auto &kernelShape = kernel.getShape();
+ uint32_t inRank = inShape.rank();
+
+ auto pads = calculate2DPaddings(op.getPaddingType(), inShape, kernelShape, strides, outShape);
+ for (size_t i = 0; i < pads.size(); ++i)
+ {
+ op.setPadding(i, pads[i]);
+ }
+
+ outShape.dim(inRank - 1) = kernelShape.dim(kernelShape.rank() - 1);
+ op.setOutputShape(0, outShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::VariableOp &op)
+{
+ (void)op;
+ (void)node;
+ // No need to do anything for inputs. These should be set by user
+}
+
+void ShapeInference::fillInputShapes(ADT::INode::Ref node, OpDescription &op)
+{
+ uint32_t i = 0;
+ for (auto &in : node->getPrevNodes())
+ {
+ const Shape &inShape = in.node->getOperation()->getOutputShape(in.index);
+ op.setInputShape(i++, inShape);
+ }
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::ReluOp &op)
+{
+ fillInputShapes(node, op);
+ op.setOutputShape(0, op.getInputShape(0));
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::SoftmaxOp &op)
+{
+ fillInputShapes(node, op);
+ op.setOutputShape(0, op.getInputShape(0));
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::PoolOp &op)
+{
+ fillInputShapes(node, op);
+
+ Shape outShape;
+ outShape.resize(3);
+ auto &strides = op.getStrides();
+ auto &windowShape = op.getWindowShape();
+ auto &inShape = op.getInputShape(0);
+ const uint32_t inRank = inShape.rank();
+
+ // Assuming input tensor is 3-dimensional. Will support more general cases when needed.
+ assert(inRank == 3);
+
+ auto pads = calculate2DPaddings(op.getPaddingType(), inShape, windowShape, strides, outShape);
+ for (uint32_t d = 0; d < inShape.rank(); ++d)
+ {
+ op.setPadding(d, pads[d]);
+ }
+ outShape.dim(inRank - 1) = inShape.dim(inRank - 1);
+ op.setOutputShape(0, outShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::FullyConnectedOp &op)
+{
+ fillInputShapes(node, op);
+ const Shape &inShape = op.getInputShape(0);
+ const Shape &wShape = op.getWeights().getShape();
+ const uint32_t weightsRank = wShape.rank();
+ const uint32_t inRank = inShape.rank();
+
+ assert(weightsRank >= 2);
+ assert(inRank == weightsRank);
+ assert(inShape.dim(inRank - 1) == wShape.dim(weightsRank - 2));
+ for (uint32_t i = 0; i < weightsRank - 2; ++i)
+ {
+ assert(wShape.dim(i) == inShape.dim(i));
+ }
+
+ Shape outShape = wShape;
+ outShape.dim(weightsRank - 1) = inShape.dim(weightsRank - 1);
+ outShape.dim(weightsRank - 2) = wShape.dim(weightsRank - 2);
+ op.setOutputShape(0, outShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::CappedReluOp &op)
+{
+ fillInputShapes(node, op);
+ op.setOutputShape(0, op.getInputShape(0));
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::DepthwiseConv2DOp &op)
+{
+ fillInputShapes(node, op);
+
+ Shape outShape;
+ outShape.resize(3);
+ auto &strides = op.getStrides();
+ auto &kernelShape = op.getKernel().getShape();
+ auto &inShape = op.getInputShape(0);
+ int inRank = inShape.rank();
+
+ // Assuming input tensor is 3-dimensional. Will support more general cases when needed.
+ assert(inRank == 3);
+ assert(inShape.dim(2) == kernelShape.dim(2));
+
+ auto pads = calculate2DPaddings(op.getPaddingType(), inShape, kernelShape, strides, outShape);
+ for (uint32_t d = 0; d < inShape.rank(); ++d)
+ {
+ op.setPadding(d, pads[d]);
+ }
+
+ outShape.dim(inRank - 1) = inShape.dim(inRank - 1);
+ op.setOutputShape(0, outShape);
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::BiasAddOp &op)
+{
+ fillInputShapes(node, op);
+ op.setOutputShape(0, op.getInputShape(0));
+}
+
+void ShapeInference::visit(ADT::INode::Ref node, ops::ReshapeOp &op)
+{
+ // Reshape should have it's output shape filled by importer/user
+ fillInputShapes(node, op);
+}
+
+} // namespace model
+} // namespace IR
+} // namespace core
+} // namespace contrib
+} // namespace nncc