[nnc] Perform shape inference at construction time (#2399)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Thu, 29 Nov 2018 13:42:35 +0000 (16:42 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Thu, 29 Nov 2018 13:42:35 +0000 (16:42 +0300)
* Move shape inference functionality from separate class to constructors of individual operations;
* First dimension of input is no longer removed by the importers;
* Adjust ModelIR, soft backend and interpreter to correctly work with non-stripped first dimension;
* Minor coding style fixes and comments.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
62 files changed:
contrib/nnc/core/CMakeLists.txt
contrib/nnc/core/modelIR/Operation.cpp
contrib/nnc/core/modelIR/ShapeInference.cpp [deleted file]
contrib/nnc/core/modelIR/operations/ConcatOp.cpp [new file with mode: 0644]
contrib/nnc/core/modelIR/operations/Conv2DOp.cpp [new file with mode: 0644]
contrib/nnc/core/modelIR/operations/DeConv2DOp.cpp [new file with mode: 0644]
contrib/nnc/core/modelIR/operations/DepthwiseConv2DOp.cpp [new file with mode: 0644]
contrib/nnc/core/modelIR/operations/FullyConnectedOp.cpp [new file with mode: 0644]
contrib/nnc/core/modelIR/operations/PadOp.cpp [moved from contrib/nnc/passes/common_frontend/shape_helper.cpp with 50% similarity]
contrib/nnc/core/modelIR/operations/PoolOp.cpp [new file with mode: 0644]
contrib/nnc/core/modelIR/operations/SqueezeOp.cpp [new file with mode: 0644]
contrib/nnc/include/core/modelIR/Operation.h
contrib/nnc/include/core/modelIR/ShapeInference.h [deleted file]
contrib/nnc/include/core/modelIR/operations/BatchNormOp.h
contrib/nnc/include/core/modelIR/operations/BiasAddOp.h
contrib/nnc/include/core/modelIR/operations/CappedReluOp.h
contrib/nnc/include/core/modelIR/operations/ConcatOp.h
contrib/nnc/include/core/modelIR/operations/ConstantOp.h
contrib/nnc/include/core/modelIR/operations/Conv2DOp.h
contrib/nnc/include/core/modelIR/operations/Deconv2DOp.h
contrib/nnc/include/core/modelIR/operations/DepthwiseConv2DOp.h
contrib/nnc/include/core/modelIR/operations/DropoutOp.h
contrib/nnc/include/core/modelIR/operations/ElementwiseOp.h
contrib/nnc/include/core/modelIR/operations/EluOp.h
contrib/nnc/include/core/modelIR/operations/FullyConnectedOp.h
contrib/nnc/include/core/modelIR/operations/PadOp.h
contrib/nnc/include/core/modelIR/operations/PoolOp.h
contrib/nnc/include/core/modelIR/operations/ReduceFOp.h
contrib/nnc/include/core/modelIR/operations/ReluOp.h
contrib/nnc/include/core/modelIR/operations/ReshapeOp.h
contrib/nnc/include/core/modelIR/operations/ResizeOp.h
contrib/nnc/include/core/modelIR/operations/ScaleOp.h
contrib/nnc/include/core/modelIR/operations/SoftmaxOp.h
contrib/nnc/include/core/modelIR/operations/SqueezeOp.h
contrib/nnc/include/core/modelIR/operations/TanhOp.h
contrib/nnc/include/passes/common_frontend/shape_helper.h
contrib/nnc/passes/acl_soft_backend/AclCppGenerator.cpp
contrib/nnc/passes/caffe_frontend/caffe_importer.cpp
contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
contrib/nnc/passes/caffe_frontend/caffe_op_creator.h
contrib/nnc/passes/common_frontend/CMakeLists.txt
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/interpreter/interpreter_pass.cpp
contrib/nnc/passes/interpreter/ops/Depthwise_conv_2D.cpp
contrib/nnc/passes/interpreter/ops/Pool.cpp
contrib/nnc/passes/interpreter/ops/conv_2D.cpp
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
contrib/nnc/passes/soft_backend/BaseGenerator.cpp
contrib/nnc/passes/soft_backend/SBSerializer.cpp
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
contrib/nnc/tests/interpreter/graph_creator.cpp
contrib/nnc/tests/interpreter/op_info_util.h
contrib/nnc/tests/soft_backend/CompileCPP.cpp
contrib/nnc/unittests/core/Graph.cpp
contrib/nnc/unittests/core/ShapeInference.cpp
contrib/nnc/unittests/core/operation.cpp
contrib/nnc/unittests/soft_backend/CPPOperations.cpp
contrib/nnc/unittests/soft_backend/ModelAnalyzer.cpp
contrib/nnc/utils/caffe2_dot_dumper/model_dump.cpp
contrib/nnc/utils/caffe_dot_dumper/model_dump.cpp
contrib/nnc/utils/tflite_dot_dumper/sanity_check.cpp

index 3b69451..598092d 100644 (file)
@@ -1,11 +1,18 @@
-set(SOURCES "modelIR/Graph.cpp"
+set(SOURCES "modelIR/operations/ConcatOp.cpp"
+            "modelIR/operations/Conv2DOp.cpp"
+            "modelIR/operations/DeConv2DOp.cpp"
+            "modelIR/operations/DepthwiseConv2DOp.cpp"
+            "modelIR/operations/FullyConnectedOp.cpp"
+            "modelIR/operations/PadOp.cpp"
+            "modelIR/operations/PoolOp.cpp"
+            "modelIR/operations/SqueezeOp.cpp"
+            "modelIR/Graph.cpp"
             "modelIR/Index.cpp"
             "modelIR/ir_dot_builder.cpp"
             "modelIR/IrDotDumper.cpp"
             "modelIR/ir_dot_node_info.cpp"
             "modelIR/Operation.cpp"
             "modelIR/Shape.cpp"
-            "modelIR/ShapeInference.cpp"
             "modelIR/Tensor.cpp"
             "modelIR/TensorVariant.cpp"
             "modelIR/Visitor.cpp")
index dc501ed..90598fb 100644 (file)
@@ -56,13 +56,9 @@ const IODescriptor Operation::getOutput(std::size_t index) {
 }
 
 const Shape& Operation::getInputShape(std::size_t index) const {
-  assert(index < getNumInputs());
-  return _inputShapes.at(index);
-}
-
-void Operation::setInputShape(std::size_t index, const Shape& shape) {
-  assert(index < getNumInputs());
-  _inputShapes[index] = shape;
+  // Shape of the input is the shape of the connected output.
+  IODescriptor descriptor = _inputs.at(index);
+  return descriptor.op->getOutputShape(descriptor.index);
 }
 
 const Shape& Operation::getOutputShape(std::size_t index) const {
diff --git a/contrib/nnc/core/modelIR/ShapeInference.cpp b/contrib/nnc/core/modelIR/ShapeInference.cpp
deleted file mode 100644 (file)
index 34f48ca..0000000
+++ /dev/null
@@ -1,448 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cmath>
-#include <algorithm>
-
-#include "core/modelIR/ShapeInference.h"
-
-#include "core/modelIR/operations/FullyConnectedOp.h"
-#include "core/modelIR/operations/SoftmaxOp.h"
-#include "core/modelIR/operations/CappedReluOp.h"
-#include "core/modelIR/operations/DepthwiseConv2DOp.h"
-#include "core/modelIR/operations/ConstantOp.h"
-#include "core/modelIR/operations/Conv2DOp.h"
-#include "core/modelIR/operations/Deconv2DOp.h"
-#include "core/modelIR/operations/PoolOp.h"
-#include "core/modelIR/operations/VariableOp.h"
-#include "core/modelIR/operations/ReluOp.h"
-#include "core/modelIR/operations/EluOp.h"
-#include "core/modelIR/operations/ConcatOp.h"
-#include "core/modelIR/operations/BiasAddOp.h"
-#include "core/modelIR/operations/ReshapeOp.h"
-#include "core/modelIR/operations/ResizeOp.h"
-#include "core/modelIR/operations/BatchNormOp.h"
-#include "core/modelIR/operations/ScaleOp.h"
-#include "core/modelIR/operations/DropoutOp.h"
-#include "core/modelIR/operations/TanhOp.h"
-#include "core/modelIR/operations/ElementwiseOp.h"
-#include "core/modelIR/operations/SqueezeOp.h"
-#include "core/modelIR/operations/PadOp.h"
-#include "core/modelIR/operations/ReduceFOp.h"
-
-namespace nnc {
-namespace mir {
-
-using nnc::mir::Shape;
-
-template <class Op>
-void fillHWShapesForPaddedOperations(Op& op, const Shape& windowShape, Shape& outShape) {
-  auto& strides = op.getStrides();
-  auto& inShape = op.getInputShape(0);
-  auto inRank = inShape.rank();
-  outShape.resize(inRank);
-
-  ops::PaddingType pType = op.getPaddingType();
-  switch (pType) {
-    case ops::PaddingType::Same:
-      for (int32_t d = 0; d < inRank - 1; ++d) {
-        outShape.dim(d) = (inShape.dim(d) - 1) / strides.dim(d) + 1;
-        int pad_along_axis;
-        if (inShape.dim(d) % strides.dim(d) == 0) {
-          pad_along_axis = std::max(( int ) windowShape.dim(d) - ( int ) strides.dim(d), 0);
-        } else {
-          pad_along_axis = std::max(( int ) (outShape.dim(d) - 1) * ( int ) strides.dim(d) +
-                                    ( int ) windowShape.dim(d) - ( int ) inShape.dim(d),
-                                    0);
-        }
-        op.setPadding(d, pad_along_axis / 2);
-      }
-      break;
-    case ops::PaddingType::Valid:
-      for (int32_t d = 0; d < inRank - 1; ++d) {
-        op.setPadding(d, 0);
-      }
-      // FALLTHROUGH
-    case ops::PaddingType::Custom:
-      for (int32_t d = 0; d < inRank - 1; ++d) {
-        outShape.dim(d) =
-          (inShape.dim(d) + 2 * op.getPadding(d) - windowShape.dim(d)) / strides.dim(d) + 1;
-      }
-      break;
-    default:
-      assert(false && "invalid padding type");
-      break;
-  }
-  // For now padding for channels is not supported, initialize it with zero
-  op.setPadding(inRank - 1, 0);
-}
-
-void ShapeInference::visit(ops::ConcatOp& op) {
-  fillInputShapes(op);
-
-  int32_t axis = op.getAxis();
-  Shape outShape;
-  outShape.resize(op.getInputShape(0).rank());
-
-  for (int32_t d = 0; d < outShape.rank(); ++d) {
-    outShape.dim(d) = op.getInputShape(0).dim(d);
-  }
-  outShape.dim(axis) = 0;
-
-  for (size_t i = 0; i < op.getNumInputs(); ++i) {
-    outShape.dim(axis) += op.getInputShape(i).dim(axis);
-  }
-
-  op.setOutputShape(0, outShape);
-}
-
-void ShapeInference::visit(ops::Conv2DOp& op) {
-  fillInputShapes(op);
-
-  Shape outShape;
-  auto& kernel = op.getKernel();
-  auto& kernelShape = kernel.getShape();
-
-  fillHWShapesForPaddedOperations(op, kernelShape, outShape);
-
-  outShape.dim(outShape.rank() - 1) = kernelShape.dim(kernelShape.rank() - 1);
-  op.setOutputShape(0, outShape);
-}
-
-void ShapeInference::visit(ops::VariableOp&) {
-  // No need to do anything for inputs. These should be set by user
-}
-
-void ShapeInference::visit(ops::ConstantOp&) {
-}
-
-void ShapeInference::fillInputShapes(Operation& op) {
-  size_t i = 0;
-  for (auto& in : op.getPrevNodes()) {
-    const Shape& inShape = in.op->getOutputShape(in.index);
-    op.setInputShape(i++, inShape);
-  }
-}
-
-void ShapeInference::visit(ops::ReluOp& op) {
-  fillInputShapes(op);
-  op.setOutputShape(0, op.getInputShape(0));
-}
-
-void ShapeInference::visit(ops::ResizeOp& op) {
-  fillInputShapes(op);
-  const auto& in_s = op.getInputShape(0);
-  Shape out_s = in_s;
-  auto res_s = op.getResultShape();
-  const std::vector<float>& scales = op.getScales();
-
-  if (scales.size() > 0) {
-    assert(
-      in_s.rank() == static_cast<int32_t>(scales.size())
-      && "Scaling parameter incompatible with input shape");
-    for (int32_t i = 0; i < in_s.rank(); i++) {
-      out_s.dim(i) = (int32_t)lroundf(scales[i] * in_s.dim(i));
-    }
-  } else {
-    // Assume batch is cut off
-    assert(in_s.rank() == 3);
-    out_s.dim(0) = res_s.dim(0);
-    out_s.dim(1) = res_s.dim(1);
-    out_s.dim(2) = in_s.dim(2);
-    op.setScales({static_cast<float> (out_s.dim(0)) / in_s.dim(0),
-                  static_cast<float> (out_s.dim(1)) / in_s.dim(1), 1.0f});
-  }
-  op.setOutputShape(0, out_s);
-}
-
-void ShapeInference::visit(ops::SoftmaxOp& op) {
-  fillInputShapes(op);
-  op.setOutputShape(0, op.getInputShape(0));
-}
-
-void ShapeInference::visit(ops::PoolOp& op) {
-  fillInputShapes(op);
-
-  Shape outShape;
-  auto& windowShape = op.getWindowShape();
-  auto& inShape = op.getInputShape(0);
-  const int32_t inRank = inShape.rank();
-  // Assuming input tensor is 3-dimensional. Will support more general cases when needed.
-  assert(inRank == 3);
-
-  fillHWShapesForPaddedOperations(op, windowShape, outShape);
-
-  outShape.dim(inRank - 1) = inShape.dim(inRank - 1);
-  op.setOutputShape(0, outShape);
-}
-
-void ShapeInference::visit(ops::FullyConnectedOp& op) {
-  fillInputShapes(op);
-  const Shape& inShape = op.getInputShape(0);
-  const Shape& wShape = op.getWeights().getShape();
-  const int32_t weightsRank = wShape.rank();
-  const int32_t inRank = inShape.rank();
-
-  assert(weightsRank >= 2);
-  assert(inRank == weightsRank);
-  assert(inShape.dim(inRank - 1) == wShape.dim(weightsRank - 2));
-  ( void ) inRank;
-  for (int32_t i = 0; i < weightsRank - 2; ++i) {
-    assert(wShape.dim(i) == inShape.dim(i));
-  }
-
-  Shape outShape = wShape;
-  outShape.dim(weightsRank - 1) = wShape.dim(weightsRank - 1);
-  outShape.dim(weightsRank - 2) = inShape.dim(weightsRank - 2);
-  op.setOutputShape(0, outShape);
-}
-
-void ShapeInference::visit(ops::CappedReluOp& op) {
-  fillInputShapes(op);
-  op.setOutputShape(0, op.getInputShape(0));
-}
-
-void ShapeInference::visit(ops::DepthwiseConv2DOp& op) {
-  fillInputShapes(op);
-
-  Shape outShape;
-  auto& kernelShape = op.getKernel().getShape();
-  auto& inShape = op.getInputShape(0);
-  int inRank = inShape.rank();
-  int kernelRank = kernelShape.rank();
-
-  // Assuming input tensor is 3-dimensional. Will support more general cases when needed.
-  assert(inRank == 3);
-  assert(inShape.dim(2) == kernelShape.dim(2));
-
-  fillHWShapesForPaddedOperations(op, kernelShape, outShape);
-
-  outShape.dim(inRank - 1) = inShape.dim(inRank - 1) * kernelShape.dim(kernelRank - 1);
-  op.setOutputShape(0, outShape);
-}
-
-void ShapeInference::visit(ops::BiasAddOp& op) {
-  fillInputShapes(op);
-  op.setOutputShape(0, op.getInputShape(0));
-}
-
-void ShapeInference::visit(ops::ReshapeOp& op) {
-  // Reshape should have it's output shape filled by importer/user
-  fillInputShapes(op);
-  auto& inShape = op.getInputShape(0);
-  auto outShape = op.getOutputShape(0);
-
-  auto inElementsNum = inShape.numElements();
-  int32_t outElementsNum = 1;
-  //can't use num_elements due to -1 in input shape and Shape using unsigned ints for dimensions
-  for (int32_t d = 0; d < outShape.rank(); ++d) {
-    auto dim = outShape.dim(d);
-    if( dim != Shape::autoDim) {
-      outElementsNum *= dim;
-    }
-  }
-
-  for (int32_t d = 0; d < outShape.rank(); ++d) {
-    auto& dim = outShape.dim(d);
-    if( dim == Shape::autoDim ) {
-      dim = static_cast<int32_t>(inElementsNum / outElementsNum);
-    }
-  }
-
-  op.setOutputShape(0, outShape);
-}
-
-void ShapeInference::visit(ops::ScaleOp& op) {
-  fillInputShapes(op);
-  op.setOutputShape(0, op.getInputShape(0));
-}
-
-void ShapeInference::visit(ops::DropoutOp& op) {
-  fillInputShapes(op);
-  op.setOutputShape(0, op.getInputShape(0));
-}
-
-void ShapeInference::visit(ops::BatchNormOp& op) {
-  fillInputShapes(op);
-  op.setOutputShape(0, op.getInputShape(0));
-}
-
-void ShapeInference::visit(ops::DeConv2DOp& op) {
-  /**
-  see https://github.com/tensorflow/tensorflow/issues/2118
-   for reason why the output shape is what it is.
-
-   output = input * stride + filter - stride  # VALID
-   output = input * stride - stride + 1  # SAME
- */
-  fillInputShapes(op);
-
-  Shape out_shape;
-  Shape in_shape = op.getInputShape(0);
-  auto& kernel = op.getKernel();
-  auto& kernel_shape = kernel.getShape();
-
-  assert(kernel_shape.rank() == 4);
-  assert(in_shape.rank() == 3);
-  assert(kernel_shape.dim(3) == in_shape.dim(2));
-
-  auto pad_type = op.getPaddingType();
-  auto in_rank = in_shape.rank();
-  auto strides = op.getStrides();
-  out_shape.resize(in_rank);
-
-  switch (pad_type) {
-    case ops::PaddingType::Same:
-      for (int32_t d = 0; d < in_rank; ++d) {
-        out_shape.dim(d) = in_shape.dim(d) * strides.dim(d) + 1 - strides.dim(d);
-      }
-      break;
-    case ops::PaddingType::Valid:
-      for (int32_t d = 0; d < in_rank; ++d) {
-        out_shape.dim(d) = in_shape.dim(d) * strides.dim(d) + kernel_shape.dim(d) - strides.dim(d);
-      }
-      break;
-    case ops::PaddingType::Custom:
-      for (int32_t d = 0; d < in_rank - 1; ++d) {
-        out_shape.dim(d) =
-          (in_shape.dim(d) - 1) * strides.dim(d) - 2 * op.getPadding(d) + kernel_shape.dim(d);
-      }
-      break;
-    default: {
-      assert(false && "invalid padding type");
-    }
-  }
-  out_shape.dim(-1) = kernel_shape.dim(-2);
-  op.setOutputShape(0, out_shape);
-}
-
-void ShapeInference::visit(ops::EluOp& op) {
-  fillInputShapes(op);
-  op.setOutputShape(0, op.getInputShape(0));
-}
-
-void ShapeInference::visit(ops::TanhOp& op) {
-  fillInputShapes(op);
-  op.setOutputShape(0, op.getInputShape(0));
-}
-
-void ShapeInference::visit(ops::ElementwiseOp& op) {
-  fillInputShapes(op);
-  op.setOutputShape(0, op.getInputShape(0));
-}
-
-void ShapeInference::visit(ops::SqueezeOp& op) {
-  fillInputShapes(op);
-  assert(op.getNumInputs() == 1);
-
-  const auto& input_shape = op.getInputShape(0);
-  int32_t input_rank = input_shape.rank();
-  Shape output_shape;
-  int32_t output_rank = 0;
-
-  std::vector<int32_t> dims_to_squeeze;
-
-  if (op.getNumSqueezeDims() == 0) {
-    for (int32_t i = 0; i < input_rank; ++i) {
-      if (input_shape.dim(i) == 1) {
-        dims_to_squeeze.push_back(i);
-      }
-    }
-  } else {
-    dims_to_squeeze = op.getDimsToSqueeze();
-    std::sort(dims_to_squeeze.begin(), dims_to_squeeze.end());
-    dims_to_squeeze.erase(
-      std::unique(dims_to_squeeze.begin(), dims_to_squeeze.end()),
-      dims_to_squeeze.end()
-    );
-  }
-
-  if (dims_to_squeeze.size() == static_cast<size_t>(input_rank)) {
-    //Input shape have 1s in all dimensions, output shape is (1,)
-    op.setOutputShape(0, Shape{1});
-    return;
-  }
-
-  size_t squeezing_idx = 0;
-  output_shape.resize(input_rank - dims_to_squeeze.size());
-  for (int32_t i = 0; i < input_rank; ++i) {
-    if (squeezing_idx < dims_to_squeeze.size() && i == dims_to_squeeze[squeezing_idx]) {
-      if (input_shape.dim(i) != 1)
-        throw std::invalid_argument("All squeezed dimensions should have size 1");
-
-      squeezing_idx++;
-    } else {
-      output_shape.dim(output_rank++) = input_shape.dim(i);
-    }
-  }
-
-  op.setOutputShape(0, output_shape);
-}
-
-void ShapeInference::visit(ops::PadOp& op) {
-  /**
-  padded size of each dimension D of the output is:
-  paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1]
- */
-  fillInputShapes(op);
-
-  const Shape& in_shape = op.getInputShape(0);
-  Shape out_shape;
-
-  assert(in_shape.rank() == op.getNumDim());
-
-  int32_t num_dim = in_shape.rank();
-  out_shape.resize(num_dim);
-
-  for (int32_t dim = 0; dim < num_dim; dim++) {
-    std::pair<int32_t, int32_t> padding = op.getPaddingForDim(dim);
-    out_shape.dim(dim) = padding.first + in_shape.dim(dim) + padding.second;
-  }
-
-  op.setOutputShape(0, out_shape);
-}
-
-void ShapeInference::visit(ops::ReduceFOp& op) {
-  fillInputShapes(op);
-  assert(op.getNumInputs() == 1);
-
-  const auto& input_shape = op.getInputShape(0);
-  const auto& red_dims = op.getReductionDims();
-  Shape output_shape;
-  if (op.getKeepDims()) {
-    output_shape = input_shape;
-    for (auto red_axis: red_dims) {
-      output_shape.dim(red_axis) = 1;
-    }
-  } else {
-    std::vector<int32_t> out_dims;
-    out_dims.reserve(input_shape.rank() - op.getReductionDims().size());
-    auto red_axis = red_dims.begin();
-    for (int32_t axis_id = 0; axis_id < input_shape.rank(); axis_id++) {
-      if (axis_id == (*red_axis)) {
-        red_axis++;
-      } else {
-        out_dims.emplace_back(input_shape.dim(axis_id));
-      }
-    }
-    output_shape = Shape(out_dims);
-  }
-
-  op.setOutputShape(0, output_shape);
-}
-
-} // namespace mir
-} // namespace nnc
diff --git a/contrib/nnc/core/modelIR/operations/ConcatOp.cpp b/contrib/nnc/core/modelIR/operations/ConcatOp.cpp
new file mode 100644 (file)
index 0000000..cc617fa
--- /dev/null
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/operations/ConcatOp.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+void ConcatOp::inferOutputShapes() {
+  Shape output_shape(getInputShape(0));
+  output_shape.dim(_axis) = 0;
+  for (std::size_t i = 0; i < getNumInputs(); ++i)
+    output_shape.dim(_axis) += getInputShape(i).dim(_axis);
+  setOutputShape(0, output_shape);
+}
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
diff --git a/contrib/nnc/core/modelIR/operations/Conv2DOp.cpp b/contrib/nnc/core/modelIR/operations/Conv2DOp.cpp
new file mode 100644 (file)
index 0000000..aea018d
--- /dev/null
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/operations/Conv2DOp.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+void Conv2DOp::inferOutputShapes() {
+  auto& input_shape = getInputShape(0);
+  auto& kernel_shape = getKernel().getShape();
+  auto& strides = getStrides();
+  auto input_rank = input_shape.rank();
+
+  assert(input_rank == 4);
+
+  Shape output_shape;
+  output_shape.resize(input_rank);
+
+  switch (getPaddingType()) {
+    case ops::PaddingType::Same: {
+      output_shape.dim(1) = (input_shape.dim(1) - 1) / strides.dim(0) + 1;
+      output_shape.dim(2) = (input_shape.dim(2) - 1) / strides.dim(1) + 1;
+      int32_t pad_along_height = (input_shape.dim(1) % strides.dim(0) == 0)
+          ? std::max(kernel_shape.dim(0) - strides.dim(0), 0)
+          : std::max((output_shape.dim(1) - 1) * strides.dim(0) + kernel_shape.dim(0) - input_shape.dim(1), 0);
+      int32_t pad_along_width = (input_shape.dim(2) % strides.dim(0) == 0)
+          ? std::max(kernel_shape.dim(1) - strides.dim(1), 0)
+          : std::max((output_shape.dim(2) - 1) * strides.dim(1) + kernel_shape.dim(1) - input_shape.dim(2), 0);
+      _paddings.at(0) = pad_along_height / 2;
+      _paddings.at(1) = pad_along_width / 2;
+      break;
+    }
+    case ops::PaddingType::Valid:
+      _paddings.at(0) = 0;
+      _paddings.at(1) = 0;
+      // FALLTHROUGH
+    case ops::PaddingType::Custom:
+      output_shape.dim(1) = (input_shape.dim(1) + 2 * getPadding(0) - kernel_shape.dim(0)) / strides.dim(0) + 1;
+      output_shape.dim(2) = (input_shape.dim(2) + 2 * getPadding(1) - kernel_shape.dim(1)) / strides.dim(1) + 1;
+      break;
+    default:
+      assert(false && "invalid padding type");
+      break;
+  }
+  // For now padding for channels is not supported, initialize it with zero
+  _paddings.at(2) = 0;
+
+  output_shape.dim(0) = input_shape.dim(0);
+  output_shape.dim(3) = kernel_shape.dim(3);
+  setOutputShape(0, output_shape);
+}
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
diff --git a/contrib/nnc/core/modelIR/operations/DeConv2DOp.cpp b/contrib/nnc/core/modelIR/operations/DeConv2DOp.cpp
new file mode 100644 (file)
index 0000000..2ce2739
--- /dev/null
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/operations/Deconv2DOp.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+// See https://github.com/tensorflow/tensorflow/issues/2118
+// VALID: output = input * stride + filter - stride
+// SAME: output = input * stride - stride + 1
+void DeConv2DOp::inferOutputShapes() {
+  auto& input_shape = getInputShape(0);
+  auto& kernel_shape = getKernel().getShape();
+  auto& strides = getStrides();
+  auto input_rank = input_shape.rank();
+
+  assert(input_rank == 3);
+  assert(kernel_shape.rank() == 4);
+  assert(kernel_shape.dim(3) == input_shape.dim(2));
+
+  Shape output_shape;
+  output_shape.resize(input_rank);
+
+  switch (_paddingType) {
+    case ops::PaddingType::Same:
+      for (int32_t d = 0; d < input_rank; ++d)
+        output_shape.dim(d) = input_shape.dim(d) * strides.dim(d) - strides.dim(d) + 1;
+      break;
+    case ops::PaddingType::Valid:
+      for (int32_t d = 0; d < input_rank; ++d)
+        output_shape.dim(d) =
+            input_shape.dim(d) * strides.dim(d) + kernel_shape.dim(d) - strides.dim(d);
+      break;
+    case ops::PaddingType::Custom:
+      for (int32_t d = 0; d < input_rank - 1; ++d)
+        output_shape.dim(d) =
+            input_shape.dim(d) * strides.dim(d) + kernel_shape.dim(d) - strides.dim(d) -
+            2 * getPadding(d);
+      break;
+    default: {
+      assert(false && "invalid padding type");
+    }
+  }
+
+  output_shape.dim(-1) = kernel_shape.dim(-2);
+  setOutputShape(0, output_shape);
+}
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
diff --git a/contrib/nnc/core/modelIR/operations/DepthwiseConv2DOp.cpp b/contrib/nnc/core/modelIR/operations/DepthwiseConv2DOp.cpp
new file mode 100644 (file)
index 0000000..3129713
--- /dev/null
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/operations/DepthwiseConv2DOp.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+void DepthwiseConv2DOp::inferOutputShapes() {
+  auto& input_shape = getInputShape(0);
+  auto& kernel_shape = getKernel().getShape();
+  auto& strides = getStrides();
+  auto input_rank = input_shape.rank();
+
+  // Assuming input tensor is 3-dimensional. Will support more general cases when needed.
+  assert(input_rank == 4);
+  assert(input_shape.dim(3) == kernel_shape.dim(2));
+
+  Shape output_shape;
+  output_shape.resize(input_rank);
+
+  switch (getPaddingType()) {
+    case ops::PaddingType::Same: {
+      output_shape.dim(1) = (input_shape.dim(1) - 1) / strides.dim(0) + 1;
+      output_shape.dim(2) = (input_shape.dim(2) - 1) / strides.dim(1) + 1;
+      int32_t pad_along_height = (input_shape.dim(1) % strides.dim(0) == 0)
+          ? std::max(kernel_shape.dim(0) - strides.dim(0), 0)
+          : std::max((output_shape.dim(1) - 1) * strides.dim(0) + kernel_shape.dim(0) - input_shape.dim(1), 0);
+      int32_t pad_along_width = (input_shape.dim(2) % strides.dim(1) == 0)
+          ? std::max(kernel_shape.dim(1) - strides.dim(1), 0)
+          : std::max((output_shape.dim(2) - 1) * strides.dim(1) + kernel_shape.dim(1) - input_shape.dim(2), 0);
+      _paddings.at(0) = pad_along_height / 2;
+      _paddings.at(1) = pad_along_width / 2;
+      break;
+    }
+    case ops::PaddingType::Valid:
+      _paddings.at(0) = 0;
+      _paddings.at(1) = 0;
+      // FALLTHROUGH
+    case ops::PaddingType::Custom:
+      output_shape.dim(1) = (input_shape.dim(1) + 2 * getPadding(0) - kernel_shape.dim(0)) / strides.dim(0) + 1;
+      output_shape.dim(2) = (input_shape.dim(2) + 2 * getPadding(1) - kernel_shape.dim(1)) / strides.dim(1) + 1;
+      break;
+    default:
+      assert(false && "invalid padding type");
+      break;
+  }
+  // For now padding for channels is not supported, initialize it with zero
+  _paddings.at(2) = 0;
+
+  output_shape.dim(0) = input_shape.dim(0);
+  output_shape.dim(3) = input_shape.dim(3) * kernel_shape.dim(3);
+  setOutputShape(0, output_shape);
+}
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
diff --git a/contrib/nnc/core/modelIR/operations/FullyConnectedOp.cpp b/contrib/nnc/core/modelIR/operations/FullyConnectedOp.cpp
new file mode 100644 (file)
index 0000000..c4ffa6b
--- /dev/null
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/operations/FullyConnectedOp.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+void FullyConnectedOp::inferOutputShapes() {
+  auto& input_shape = getInputShape(0);
+  auto& weights_shape = getWeights().getShape();
+  auto input_rank = input_shape.rank();
+  auto weights_rank = weights_shape.rank();
+
+  assert(weights_rank >= 2);
+  assert(input_rank == weights_rank);
+  assert(input_shape.dim(input_rank - 1) == weights_shape.dim(weights_rank - 2));
+  (void)input_rank;
+  for (int32_t i = 0; i < weights_rank - 2; ++i)
+    assert(weights_shape.dim(i) == input_shape.dim(i));
+
+  Shape output_shape = weights_shape;
+  output_shape.dim(weights_rank - 1) = weights_shape.dim(weights_rank - 1);
+  output_shape.dim(weights_rank - 2) = input_shape.dim(weights_rank - 2);
+  setOutputShape(0, output_shape);
+}
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
  * limitations under the License.
  */
 
-#include <vector>
+#include "core/modelIR/operations/PadOp.h"
 
-#include "passes/common_frontend/shape_helper.h"
-#include "pass/PassException.h"
+namespace nnc {
+namespace mir {
+namespace ops {
 
+void PadOp::inferOutputShapes() {
+  const Shape& input_shape = getInputShape(0);
+  int32_t num_dims = input_shape.rank();
 
-namespace nnc
-{
+  assert(num_dims == getNumDim());
 
-void ShapeHelper::cutOffBatchDim(mir::Shape& shape)
-{
-  if (shape.dim(0) != 1)
-  {
-    throw PassException{"While attempting to cut off tensor batch dimension (first one),"
-                        "found that it is not 1. Check the model being imported, if the first"
-                        "dimension of the input is not 1, then it might be not batch, and the"
-                        "code needs some restructuring"};
+  Shape out_shape;
+  out_shape.resize(num_dims);
+  for (int32_t dim = 0; dim < num_dims; ++dim) {
+    std::pair<int32_t, int32_t> padding = getPaddingForDim(dim);
+    out_shape.dim(dim) = padding.first + input_shape.dim(dim) + padding.second;
   }
 
-  for (int32_t i = 0; i < shape.rank() - 1; ++i)
-  {
-    shape.dim(i) = shape.dim(i + 1);
-  }
-  shape.resize(shape.rank() - 1);
+  setOutputShape(0, out_shape);
 }
 
+} // namespace ops
+} // namespace mir
 } // namespace nnc
diff --git a/contrib/nnc/core/modelIR/operations/PoolOp.cpp b/contrib/nnc/core/modelIR/operations/PoolOp.cpp
new file mode 100644 (file)
index 0000000..457ee01
--- /dev/null
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/operations/PoolOp.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+void PoolOp::inferOutputShapes() {
+  auto& input_shape = getInputShape(0);
+  auto& window_shape = getWindowShape();
+  auto& strides = getStrides();
+  auto input_rank = input_shape.rank();
+
+  assert(input_rank == 4);
+
+  Shape output_shape;
+  output_shape.resize(input_rank);
+
+  switch (getPaddingType()) {
+    case ops::PaddingType::Same: {
+      output_shape.dim(1) = (input_shape.dim(1) - 1) / strides.dim(0) + 1;
+      output_shape.dim(2) = (input_shape.dim(2) - 1) / strides.dim(1) + 1;
+      int32_t pad_along_height = (input_shape.dim(1) % strides.dim(0) == 0)
+          ? std::max(window_shape.dim(0) - strides.dim(0), 0)
+          : std::max((output_shape.dim(1) - 1) * strides.dim(0) + window_shape.dim(0) - input_shape.dim(1), 0);
+      int32_t pad_along_width = (input_shape.dim(1) % strides.dim(0) == 0)
+          ? std::max(window_shape.dim(1) - strides.dim(1), 0)
+          : std::max((output_shape.dim(2) - 1) * strides.dim(1) + window_shape.dim(1) - input_shape.dim(2), 0);
+      _paddings.at(0) = pad_along_height / 2;
+      _paddings.at(1) = pad_along_width / 2;
+      break;
+    }
+    case ops::PaddingType::Valid:
+      _paddings.at(0) = 0;
+      _paddings.at(1) = 0;
+      // FALLTHROUGH
+    case ops::PaddingType::Custom:
+      output_shape.dim(1) = (input_shape.dim(1) + 2 * getPadding(0) - window_shape.dim(0)) / strides.dim(0) + 1;
+      output_shape.dim(2) = (input_shape.dim(2) + 2 * getPadding(1) - window_shape.dim(1)) / strides.dim(0) + 1;
+      break;
+    default:
+      assert(false && "invalid padding type");
+      break;
+  }
+  // For now padding for channels is not supported, initialize it with zero
+  _paddings.at(2) = 0;
+
+  output_shape.dim(0) = input_shape.dim(0);
+  output_shape.dim(3) = input_shape.dim(3);
+  setOutputShape(0, output_shape);
+}
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
diff --git a/contrib/nnc/core/modelIR/operations/SqueezeOp.cpp b/contrib/nnc/core/modelIR/operations/SqueezeOp.cpp
new file mode 100644 (file)
index 0000000..7e7e70e
--- /dev/null
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/operations/SqueezeOp.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+void SqueezeOp::inferOutputShapes() {
+  assert(getNumInputs() == 1);
+
+  const auto& input_shape = getInputShape(0);
+  int32_t input_rank = input_shape.rank();
+  Shape output_shape;
+  int32_t output_rank = 0;
+
+  std::vector<int32_t> dims_to_squeeze;
+
+  if (getNumSqueezeDims() == 0) {
+    for (int32_t i = 0; i < input_rank; ++i) {
+      if (input_shape.dim(i) == 1) {
+        dims_to_squeeze.push_back(i);
+      }
+    }
+  } else {
+    dims_to_squeeze = getDimsToSqueeze();
+    sort(dims_to_squeeze.begin(), dims_to_squeeze.end());
+    dims_to_squeeze.erase(
+        unique(dims_to_squeeze.begin(), dims_to_squeeze.end()),
+        dims_to_squeeze.end()
+    );
+  }
+
+  if (dims_to_squeeze.size() == static_cast<size_t>(input_rank)) {
+    //Input shape have 1s in all dimensions, output shape is (1,)
+    setOutputShape(0, Shape{1});
+    return;
+  }
+
+  size_t squeezing_idx = 0;
+  output_shape.resize(input_rank - dims_to_squeeze.size());
+  for (int32_t i = 0; i < input_rank; ++i) {
+    if (squeezing_idx < dims_to_squeeze.size() && i == dims_to_squeeze[squeezing_idx]) {
+      if (input_shape.dim(i) != 1)
+        throw std::invalid_argument("All squeezed dimensions should have size 1");
+
+      squeezing_idx++;
+    } else {
+      output_shape.dim(output_rank++) = input_shape.dim(i);
+    }
+  }
+
+  setOutputShape(0, output_shape);
+}
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
index ac8213e..f09f458 100644 (file)
@@ -65,13 +65,12 @@ public:
 
   const nnc::mir::Shape& getInputShape(std::size_t index) const;
   const nnc::mir::Shape& getOutputShape(std::size_t index) const;
-  void setInputShape(std::size_t index, const nnc::mir::Shape& shape);
-  void setOutputShape(std::size_t index, const nnc::mir::Shape& shape);
 
   void accept(IVisitor* v);
 
 protected:
   Operation(Type type, const std::vector<IODescriptor>& args);
+  void setOutputShape(std::size_t index, const nnc::mir::Shape& shape);
 
 private:
   Type _type;
@@ -81,7 +80,6 @@ private:
   std::size_t _num_outputs;
   std::vector<IODescriptor> _inputs;
   std::vector<Operation*> _outputs;
-  std::map<size_t, nnc::mir::Shape> _inputShapes;
   std::map<size_t, nnc::mir::Shape> _outputShapes;
 };
 
diff --git a/contrib/nnc/include/core/modelIR/ShapeInference.h b/contrib/nnc/include/core/modelIR/ShapeInference.h
deleted file mode 100644 (file)
index 5138455..0000000
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef _NNC_CORE_IR_MODEL_SHAPE_INFERENCE_
-#define _NNC_CORE_IR_MODEL_SHAPE_INFERENCE_
-#include <limits>
-
-#include "core/modelIR/Visitor.h"
-#include "core/modelIR/Operation.h"
-
-namespace nnc
-{
-namespace mir
-{
-
-class ShapeInference : public IVisitor {
-public:
-  void visit(ops::ConcatOp& op) override;
-  void visit(ops::ConstantOp& op) override;
-  void visit(ops::Conv2DOp& op) override;
-  void visit(ops::DepthwiseConv2DOp& op) override;
-  void visit(ops::ReluOp& op) override;
-  void visit(ops::SoftmaxOp& op) override;
-  void visit(ops::PoolOp& op) override;
-  void visit(ops::FullyConnectedOp& op) override;
-  void visit(ops::CappedReluOp& op) override;
-  void visit(ops::BiasAddOp& op) override;
-  void visit(ops::ReshapeOp& op) override;
-  void visit(ops::ResizeOp& op) override;
-  void visit(ops::VariableOp& op) override;
-  void visit(ops::ScaleOp& op) override;
-  void visit(ops::BatchNormOp& op) override;
-  void visit(ops::DropoutOp& op) override;
-  void visit(ops::TanhOp& op) override;
-  void visit(ops::ElementwiseOp& op) override;
-  void visit(ops::DeConv2DOp& op) override;
-  void visit(ops::EluOp& op) override;
-  void visit(ops::SqueezeOp& op) override;
-  void visit(ops::PadOp& op) override;
-  void visit(ops::ReduceFOp& op) override;
-
-protected:
-  void fillInputShapes(Operation& op);
-};
-
-} // namespace mir
-} // namespace nnc
-
-#endif //_NNC_CORE_IR_MODEL_SHAPE_INFERENCE_
index 66572dc..c15a9a2 100644 (file)
@@ -26,8 +26,11 @@ namespace ops {
 class BatchNormOp : public Operation {
 public:
   BatchNormOp(const IODescriptor& arg, float movingAvgFraction, float eps, bool spatial)
-    : Operation(Type::batchNorm, {arg}), _movingAvgFraction(movingAvgFraction), _eps(eps),
-      _spatial(spatial) {}
+      : Operation(Type::batchNorm, {arg}), _movingAvgFraction(movingAvgFraction), _eps(eps),
+        _spatial(spatial) {
+    // Infer output shape.
+    setOutputShape(0, getInputShape(0));
+  }
 
   /**
    * @return The epsilon value to use to avoid division by zero.
index dc42c9e..7c16d34 100644 (file)
@@ -27,7 +27,10 @@ namespace ops {
 class BiasAddOp : public Operation {
 public:
   BiasAddOp(const IODescriptor& arg, const TensorVariant& weights)
-    : Operation(Type::biasAdd, {arg}), _weights(weights) {}
+    : Operation(Type::biasAdd, {arg}), _weights(weights) {
+    // Infer output shape.
+    setOutputShape(0, getInputShape(0));
+  }
 
   const TensorVariant& getWeights() const { return _weights; }
 
index bbb5e34..df05ac0 100644 (file)
@@ -26,7 +26,10 @@ namespace ops {
 class CappedReluOp : public Operation {
 public:
   CappedReluOp(const IODescriptor& arg, float cap)
-    : Operation(Type::cappedReLU, {arg}), _cap(cap) {}
+    : Operation(Type::cappedReLU, {arg}), _cap(cap) {
+    // Infer output shape.
+    setOutputShape(0, getInputShape(0));
+  }
 
   float getCap() const { return _cap; }
 
index 19f2eda..90c7798 100644 (file)
@@ -29,7 +29,9 @@ namespace ops {
 class ConcatOp : public Operation {
 public:
   ConcatOp(const std::vector<IODescriptor>& args, int32_t axis)
-    : Operation(Type::concat, args), _axis(axis) {}
+      : Operation(Type::concat, args), _axis(axis) {
+    inferOutputShapes();
+  }
 
   int32_t getAxis() const {
     if (_axis < 0) {
@@ -43,6 +45,8 @@ public:
   }
 
 private:
+  void inferOutputShapes();
+
   /// @brief The axis along which to concatenate, may be negative to index from the end
   int32_t _axis;
 };
index 1ce3458..075ccad 100644 (file)
@@ -25,13 +25,14 @@ namespace ops {
 
 class ConstantOp : public Operation {
 public:
-  explicit ConstantOp(TensorVariant* value) : Operation(Type::constant, {}),
-                                                     _value(*value) {}
+  ConstantOp(const TensorVariant& value) : Operation(Type::constant, {}), _value(value) {
+    setOutputShape(0, _value.getShape());
+  }
 
-  const TensorVariant &getValue() const {return _value;}
+  const TensorVariant& getValue() const { return _value; }
 
 private:
-    TensorVariant _value;
+  TensorVariant _value;
 };
 
 } // namespace ops
index 5d10d31..536d7a7 100644 (file)
@@ -29,27 +29,46 @@ namespace ops {
 
 class Conv2DOp : public Operation {
 public:
-  Conv2DOp(const IODescriptor& arg, const TensorVariant& kernel, const Shape& strides,
-           PaddingType padding)
-    : Operation(Type::conv2D, {arg}), _kernel(kernel), _strides(strides), _padding(padding) {
-    _pads.resize(3);
+  Conv2DOp(const IODescriptor& arg,
+           const TensorVariant& kernel,
+           const Shape& strides,
+           const std::vector<int32_t>& paddings)
+      : Operation(Type::conv2D, {arg}),
+        _kernel(kernel),
+        _strides(strides),
+        _paddingType(PaddingType::Custom),
+        _paddings(paddings) {
+    inferOutputShapes();
+  }
+
+  Conv2DOp(const IODescriptor& arg,
+           const TensorVariant& kernel,
+           const Shape& strides,
+           PaddingType padding_type)
+      : Operation(Type::conv2D, {arg}),
+        _kernel(kernel),
+        _strides(strides),
+        _paddingType(padding_type),
+        _paddings(3) {
+    assert(_paddingType != PaddingType::Custom);
+    inferOutputShapes();
   }
 
   const TensorVariant& getKernel() const { return _kernel; }
 
   const Shape& getStrides() const { return _strides; }
 
-  PaddingType getPaddingType() const { return _padding; }
+  PaddingType getPaddingType() const { return _paddingType; }
 
-  int32_t getPadding(int32_t dim) const { return _pads[dim]; }
-
-  void setPadding(int32_t dim, int32_t pad) { _pads[dim] = pad; }
+  int32_t getPadding(int32_t dim) const { return _paddings[dim]; }
 
 private:
+  void inferOutputShapes();
+
   const TensorVariant _kernel;
   Shape _strides;
-  PaddingType _padding;
-  std::vector<int32_t> _pads;
+  PaddingType _paddingType;
+  std::vector<int32_t> _paddings;
 };
 
 } // namespace ops
index 0ce5e27..3816256 100644 (file)
@@ -27,27 +27,46 @@ namespace ops {
 
 class DeConv2DOp : public Operation {
 public:
-  DeConv2DOp(const IODescriptor& arg, const TensorVariant& kernel, const Shape& strides,
-             PaddingType padding)
-    : Operation(Type::deConv2D, {arg}), _kernel(kernel), _strides(strides), _padding(padding) {
-    _pads.resize(3);
+  DeConv2DOp(const IODescriptor& arg,
+             const TensorVariant& kernel,
+             const Shape& strides,
+             const std::vector<int32_t>& paddings)
+      : Operation(Type::deConv2D, {arg}),
+        _kernel(kernel),
+        _strides(strides),
+        _paddingType(PaddingType::Custom),
+        _paddings(paddings) {
+    inferOutputShapes();
+  }
+
+  DeConv2DOp(const IODescriptor& arg,
+             const TensorVariant& kernel,
+             const Shape& strides,
+             PaddingType padding_type)
+      : Operation(Type::deConv2D, {arg}),
+        _kernel(kernel),
+        _strides(strides),
+        _paddingType(padding_type),
+        _paddings(3) {
+    assert(_paddingType != PaddingType::Custom);
+    inferOutputShapes();
   }
 
   const TensorVariant& getKernel() const { return _kernel; }
 
   const Shape& getStrides() const { return _strides; }
 
-  PaddingType getPaddingType() const { return _padding; }
+  PaddingType getPaddingType() const { return _paddingType; }
 
-  int getPadding(int32_t dim) const { return _pads[dim]; }
-
-  void setPadding(int32_t dim, int pad) { _pads[dim] = pad; }
+  int getPadding(int32_t dim) const { return _paddings[dim]; }
 
 private:
+  void inferOutputShapes();
+
   const TensorVariant _kernel;
   Shape _strides;
-  PaddingType _padding;
-  std::vector<int32_t> _pads;
+  PaddingType _paddingType;
+  std::vector<int32_t> _paddings;
 };
 
 } // namespace ops
index e25a057..bf062ef 100644 (file)
@@ -29,27 +29,46 @@ namespace ops {
 
 class DepthwiseConv2DOp : public Operation {
 public:
-  DepthwiseConv2DOp(const IODescriptor& arg, const TensorVariant& kernel, const Shape& strides,
-                    PaddingType padding)
-    : Operation(Type::depthwiseConv, {arg}), _kernel(kernel), _strides(strides), _padding(padding) {
-    _pads.resize(_kernel.getShape().rank());
+  DepthwiseConv2DOp(const IODescriptor& arg,
+                    const TensorVariant& kernel,
+                    const Shape& strides,
+                    const std::vector<int32_t>& paddings)
+      : Operation(Type::depthwiseConv, {arg}),
+        _kernel(kernel),
+        _strides(strides),
+        _paddingType(PaddingType::Custom),
+        _paddings(paddings) {
+    inferOutputShapes();
+  }
+
+  DepthwiseConv2DOp(const IODescriptor& arg,
+                    const TensorVariant& kernel,
+                    const Shape& strides,
+                    PaddingType padding_type)
+      : Operation(Type::depthwiseConv, {arg}),
+        _kernel(kernel),
+        _strides(strides),
+        _paddingType(padding_type),
+        _paddings(_kernel.getShape().rank()) {
+    assert(_paddingType != PaddingType::Custom);
+    inferOutputShapes();
   }
 
   const TensorVariant& getKernel() const { return _kernel; }
 
   const Shape& getStrides() const { return _strides; }
 
-  PaddingType getPaddingType() const { return _padding; }
+  PaddingType getPaddingType() const { return _paddingType; }
 
-  int32_t getPadding(int32_t dim) const { return _pads[dim]; }
-
-  void setPadding(int32_t dim, int32_t pad) { _pads[dim] = pad; }
+  int32_t getPadding(int32_t dim) const { return _paddings[dim]; }
 
 private:
+  void inferOutputShapes();
+
   const TensorVariant _kernel;
   Shape _strides;
-  PaddingType _padding;
-  std::vector<int32_t> _pads;
+  PaddingType _paddingType;
+  std::vector<int32_t> _paddings;
 };
 
 } // namespace ops
index 5b48200..2a0c870 100644 (file)
@@ -25,7 +25,10 @@ namespace ops {
 
 class DropoutOp : public Operation {
 public:
-  DropoutOp(const IODescriptor& arg, float rate) : Operation(Type::dropout, {arg}), _rate(rate) {}
+  DropoutOp(const IODescriptor& arg, float rate) : Operation(Type::dropout, {arg}), _rate(rate) {
+    // Infer output shape.
+    setOutputShape(0, getInputShape(0));
+  }
 
   /**
    * @return The ratio of random dropout
index 19c0ae5..559cbee 100644 (file)
@@ -39,7 +39,11 @@ public:
    * @param num_inputs Number of inputs
    */
   ElementwiseOp(const std::vector<IODescriptor>& args, OpType op_type)
-    : Operation(Type::elementwise, args), _opType(op_type) {};
+    : Operation(Type::elementwise, args), _opType(op_type) {
+    // Infer output shape.
+    // TODO Check that all inputs have the same shape.
+    setOutputShape(0, getInputShape(0));
+  };
 
 private:
   OpType _opType;
index 9341eb2..d095914 100644 (file)
@@ -25,7 +25,9 @@ namespace ops {
 
 class EluOp : public Operation {
 public:
-  EluOp(const IODescriptor& arg, float alpha) : Operation(Type::ELU, {arg}), _alpha(alpha) {}
+  EluOp(const IODescriptor& arg, float alpha) : Operation(Type::ELU, {arg}), _alpha(alpha) {
+    setOutputShape(0, getInputShape(0));
+  }
 
   float getAlpha() const { return _alpha; }
 
index 65a4b78..33b3b92 100644 (file)
@@ -27,14 +27,20 @@ namespace ops {
 class FullyConnectedOp : public Operation {
 public:
   FullyConnectedOp(const IODescriptor& arg, const TensorVariant& weights)
-    : Operation(Type::fullyConnected, {arg}), _weights(weights) {}
+      : Operation(Type::fullyConnected, {arg}), _weights(weights) {
+    inferOutputShapes();
+  }
 
   const TensorVariant& getWeights() const { return _weights; }
 
 private:
+  void inferOutputShapes();
+
   TensorVariant _weights;
 };
 
+
+
 } // namespace ops
 } // namespace mir
 } // namespace nnc
index 83dba33..d4211f8 100644 (file)
@@ -41,7 +41,9 @@ public:
         const std::vector<std::pair<int32_t, int32_t>>& paddings,
         const Scalar& scalar_value)
       : Operation(Type::pad, {arg}), _numDims(numDims),
-        _paddings(paddings), _scalarValue(scalar_value) {}
+        _paddings(paddings), _scalarValue(scalar_value) {
+    inferOutputShapes();
+  }
 
   /**
    * @param dim Dimension number
@@ -60,6 +62,8 @@ public:
   Scalar getScalar() const { return _scalarValue; }
 
 private:
+  void inferOutputShapes();
+
   std::vector<std::pair<int32_t, int32_t>> _paddings;
   int32_t _numDims;
   Scalar _scalarValue;
index 6e860aa..2576e8a 100644 (file)
@@ -38,14 +38,39 @@ public:
     EMPTY // Consider that there are no elements outside of input shape
   };
 
-  PoolOp(const IODescriptor& arg, const Shape& windowShape, const Shape& strides,
-         PoolingType poolType, PaddingType padding, BorderType borderType)
-    : Operation(Type::pool, {arg}), _padding(padding), _poolingType(poolType),
-      _borderType(borderType), _windowShape(windowShape), _strides(strides) {
-    _pads.resize(_windowShape.rank());
+  PoolOp(const IODescriptor& arg,
+         const Shape& window_shape,
+         const Shape& strides,
+         PoolingType pooling_type,
+         const std::vector<int32_t>& paddings,
+         BorderType border_type)
+      : Operation(Type::pool, {arg}),
+        _windowShape(window_shape),
+        _strides(strides),
+        _poolingType(pooling_type),
+        _paddingType(PaddingType::Custom),
+        _paddings(paddings),
+        _borderType(border_type) {
+    inferOutputShapes();
   }
 
-  PaddingType getPaddingType() const { return _padding; }
+  PoolOp(const IODescriptor& arg,
+         const Shape& window_shape,
+         const Shape& strides,
+         PoolingType pooling_type,
+         PaddingType padding_type,
+         BorderType border_type)
+      : Operation(Type::pool, {arg}),
+        _windowShape(window_shape),
+        _strides(strides),
+        _poolingType(pooling_type),
+        _paddingType(padding_type),
+        _paddings(window_shape.rank()),
+        _borderType(border_type) {
+    assert(_paddingType != PaddingType::Custom);
+    inferOutputShapes();
+  }
+  PaddingType getPaddingType() const { return _paddingType; }
 
   BorderType getBorderType() const { return _borderType; }
 
@@ -55,20 +80,17 @@ public:
 
   const Shape& getStrides() const { return _strides; }
 
-  int32_t getPadding(int32_t dim) const { return _pads[dim]; }
-
-  void setPadding(int32_t dim, int32_t pad) {
-    assert(dim < (int32_t)_pads.size());
-    _pads[dim] = pad;
-  }
+  int32_t getPadding(int32_t dim) const { return _paddings[dim]; }
 
 private:
-  PaddingType _padding;
+  void inferOutputShapes();
+
+  PaddingType _paddingType;
   PoolingType _poolingType;
   BorderType _borderType;
   Shape _windowShape;
   Shape _strides;
-  std::vector<int32_t> _pads;
+  std::vector<int32_t> _paddings;
 };
 
 } // namespace ops
index 0e722a7..2c87673 100644 (file)
@@ -33,25 +33,50 @@ public:
   /**
    * @brief Reduces with (a,b) -> a + b / n where n is the size of dimension(s) being reduced
    * @param reduce_dims vector of ints denoting reduction dimensions. assume it is sorted
-   * @param keepDims whether to keep the original rank
-   * @param fT function to reduce the tensor with (should be associative)
+   * @param keep_dims whether to keep the original rank
+   * @param func_type function to reduce the tensor with (should be associative)
    */
-  explicit ReduceFOp(const IODescriptor& arg,
-                     const std::vector<int32_t>& reduce_dims, bool keepDims,
-                     FuncType fT) :
-    Operation(Type::reduceFOp, {arg}), _reduceDims(reduce_dims),
-    _keepDims(keepDims), _fT(fT) {};
+  ReduceFOp(const IODescriptor& arg,
+            const std::vector<int32_t>& reduce_dims,
+            bool keep_dims,
+            FuncType func_type)
+      : Operation(Type::reduceFOp, {arg}), _reduceDims(reduce_dims), _keepDims(keep_dims),
+        _funcType(func_type) {
+    // Infer output shapes.
+    const auto& input_shape = getInputShape(0);
+    const auto& red_dims = getReductionDims();
+    Shape output_shape;
+    if (getKeepDims()) {
+      output_shape = input_shape;
+      for (auto red_axis: red_dims) {
+        output_shape.dim(red_axis) = 1;
+      }
+    } else {
+      std::vector<int32_t> out_dims;
+      out_dims.reserve(input_shape.rank() - red_dims.size());
+      auto red_axis = red_dims.begin();
+      for (int32_t axis_id = 0; axis_id < input_shape.rank(); axis_id++) {
+        if (axis_id == (*red_axis)) {
+          red_axis++;
+        } else {
+          out_dims.emplace_back(input_shape.dim(axis_id));
+        }
+      }
+      output_shape = Shape(out_dims);
+    }
+
+    setOutputShape(0, output_shape);
+  };
 
   const std::vector<int32_t>& getReductionDims() { return _reduceDims; };
 
   bool getKeepDims() const { return _keepDims; };
 
-  FuncType getFuncType() const { return _fT; };
+  FuncType getFuncType() const { return _funcType; };
 private:
   std::vector<int32_t> _reduceDims;
   bool _keepDims;
-  FuncType _fT;
-
+  FuncType _funcType;
 };
 
 } // namespace ops
index 2ddfcb4..58208a6 100644 (file)
@@ -25,7 +25,10 @@ namespace ops {
 
 class ReluOp : public Operation {
 public:
-  explicit ReluOp(const IODescriptor& arg) : Operation(Type::ReLU, {arg}) {}
+  explicit ReluOp(const IODescriptor& arg) : Operation(Type::ReLU, {arg}) {
+    // Infer output shape.
+    setOutputShape(0, getInputShape(0));
+  }
 };
 
 } // namespace ops
index ab9604b..6164498 100644 (file)
@@ -23,7 +23,25 @@ namespace ops {
 class ReshapeOp : public Operation {
 public:
   ReshapeOp(const IODescriptor& arg, const Shape& shape) : Operation(Type::reshape, {arg}) {
-    setOutputShape(0, shape);
+    const Shape& input_shape = getInputShape(0);
+    auto output_shape = shape;
+
+    auto in_elements_num = input_shape.numElements();
+    int32_t out_elements_num = 1;
+    // Can't use num_elements due to -1 in input shape and Shape using unsigned ints for dimensions.
+    for (int32_t d = 0; d < output_shape.rank(); ++d) {
+      auto dim = output_shape.dim(d);
+      if (dim != Shape::autoDim)
+        out_elements_num *= dim;
+    }
+
+    for (int32_t d = 0; d < output_shape.rank(); ++d) {
+      auto& dim = output_shape.dim(d);
+      if (dim == Shape::autoDim)
+        dim = static_cast<int32_t>(in_elements_num / out_elements_num);
+    }
+
+    setOutputShape(0, output_shape);
   }
 };
 
index b18a7f0..2b1b1b5 100644 (file)
@@ -20,6 +20,7 @@
 #include "core/modelIR/Operation.h"
 #include "core/modelIR/Shape.h"
 #include <vector>
+#include <cmath>
 
 namespace nnc {
 namespace mir {
@@ -32,26 +33,38 @@ public:
     nearestNeighbor, // TODO: BICUBIC and BILINEAR
   };
 
-  explicit ResizeOp(const IODescriptor& arg, ResizeMethod mode, const std::vector<float>& scales) :
-    Operation(Type::resizeIm, {arg}), _mode(mode), _scales(scales),
-    _resultShape({}) {}
+  ResizeOp(const IODescriptor& arg, ResizeMethod mode, const std::vector<float>& scales)
+      : Operation(Type::resizeIm, {arg}), _mode(mode), _scales(scales) {
+    // Infer output shape based on given scales.
+    auto& input_shape = getInputShape(0);
+    assert(input_shape.rank() == 4 && _scales.size() == 4);
+    Shape output_shape;
+    output_shape.resize(input_shape.rank());
 
-  explicit ResizeOp(const IODescriptor& arg, ResizeMethod mode, const Shape& shape) :
-    Operation(Type::resizeIm, {arg}), _mode(mode),
-    _scales({}), _resultShape(shape) {}
+    for (int32_t i = 0; i < input_shape.rank(); ++i) {
+      output_shape.dim(i) = (int32_t)lroundf(_scales.at(i) * input_shape.dim(i));
+    }
+
+    setOutputShape(0, output_shape);
+  }
+
+  ResizeOp(const IODescriptor& arg, ResizeMethod mode, const Shape& output_shape)
+      : Operation(Type::resizeIm, {arg}), _mode(mode) {
+    // Calculate scales based on given shape.
+    auto& input_shape = getInputShape(0);
+    assert(input_shape.rank() == 4 && output_shape.rank() == 4);
+    setOutputShape(0, output_shape);
+    _scales = {1.0f, static_cast<float>(output_shape.dim(1)) / input_shape.dim(1),
+               static_cast<float>(output_shape.dim(2)) / input_shape.dim(2), 1.0f};
+  }
 
   /** @return The resize mode */
   ResizeMethod getMode() const { return _mode; }
 
-  const Shape& getResultShape() const { return _resultShape; }
-
   const std::vector<float>& getScales() const { return _scales; }
 
-  void setScales(const std::vector<float>& scales) { _scales = scales; }
-
 private:
   std::vector<float> _scales;
-  Shape _resultShape;
   ResizeMethod _mode;
 };
 
index 891358e..0fd9ca2 100644 (file)
@@ -26,7 +26,10 @@ namespace ops {
 class ScaleOp : public Operation {
 public:
   ScaleOp(const IODescriptor& arg, const TensorVariant& weights)
-    : Operation(Type::scale, {arg}), _weights(weights) {}
+      : Operation(Type::scale, {arg}), _weights(weights) {
+    // Infer output shape.
+    setOutputShape(0, getInputShape(0));
+  }
 
   /**
    * @return The input 1-dimensional scale tensor.
index 8d1e6a8..edafcef 100644 (file)
@@ -28,7 +28,9 @@ namespace ops {
  */
 class SoftmaxOp : public Operation {
 public:
-  SoftmaxOp(const IODescriptor& arg, int32_t axis) : Operation(Type::softmax, {arg}), _axis(axis) {}
+  SoftmaxOp(const IODescriptor& arg, int32_t axis) : Operation(Type::softmax, {arg}), _axis(axis) {
+    setOutputShape(0, getInputShape(0));
+  }
 
   int32_t getAxis() const {
     if (_axis < 0) {
index a2b28dc..b17db47 100644 (file)
@@ -18,6 +18,7 @@
 #define _NNC_CORE_IR_MODEL_SQUEEZE_OP_H_
 
 #include "core/modelIR/Operation.h"
+#include <algorithm>
 
 namespace nnc {
 namespace mir {
@@ -26,7 +27,12 @@ namespace ops {
 class SqueezeOp : public Operation {
 public:
   SqueezeOp(const IODescriptor& arg, const std::vector<int32_t>& dims_to_squeeze)
-    : Operation(Type::squeeze, {arg}), _dims_to_squeeze(dims_to_squeeze) {}
+    : Operation(Type::squeeze, {arg}), _dims_to_squeeze(dims_to_squeeze) {
+    // Infer output shape.
+    inferOutputShapes();
+  }
+
+  void inferOutputShapes();
 
   int32_t getNumSqueezeDims() const { return static_cast<int32_t>(_dims_to_squeeze.size()); }
 
index d5a4aa2..0e00a06 100644 (file)
@@ -25,7 +25,10 @@ namespace ops {
 
 class TanhOp : public Operation {
 public:
-  explicit TanhOp(const IODescriptor& arg) : Operation(Type::tanh, {arg}) {}
+  explicit TanhOp(const IODescriptor& arg) : Operation(Type::tanh, {arg}) {
+    // Infer output shape.
+    setOutputShape(0, getInputShape(0));
+  }
 };
 
 } // namespace ops
index ec1ac22..186d2fc 100644 (file)
@@ -27,8 +27,6 @@ class ShapeHelper
 public:
   template<typename Iterable>
   static mir::Shape createShape(const Iterable &iter, std::size_t);
-
-  static void cutOffBatchDim(mir::Shape& shape);
 };
 
 template<typename Iterable>
index d07c227..0fbe1ad 100644 (file)
@@ -17,7 +17,6 @@
 #include "passes/acl_soft_backend/AclCppGenerator.h"
 #include "passes/acl_soft_backend/AclCppOpGenerator.h"
 #include "passes/acl_soft_backend/AclCppException.h"
-#include "core/modelIR/ShapeInference.h"
 #include "option/Options.h"
 
 #include <fstream>
@@ -61,10 +60,6 @@ PassData AclCppCodeGenerator::run(PassData data) {
   if (par_out.fail())
     throw AclCppException("Can not open parameter output file: " + par_path);
 
-  // Inference shapes in the computation graph.
-  mir::ShapeInference si;
-  g->accept(&si);
-
   ArtifactGeneratorCppCode code_gen(code_out);
   ArtifactGeneratorCppDecl decl_gen(decl_out);
 
index c8cd6fe..b5df136 100644 (file)
@@ -162,12 +162,14 @@ void CaffeImporter::collectUnsupportedOp(const LayerParameter& lp) {
     case CaffeOpType::scale:
     case CaffeOpType::dropout:
     case CaffeOpType::split:
-    case CaffeOpType::concat:
     case CaffeOpType::eltwise:
     case CaffeOpType::ELU:
     case CaffeOpType::tanh:
       // No checks
       break;
+    case CaffeOpType::concat:
+      _opCreator->checkConcat(lp.concat_param(), _problemsOpSet);
+      break;
     case CaffeOpType::deconvolution:
     case CaffeOpType::convolution:
       _opCreator->checkConvolution(lp.convolution_param(), _problemsOpSet);
@@ -253,6 +255,7 @@ std::vector<std::shared_ptr<IrTensor>> CaffeImporter::createOpParams(const Layer
     if (lp.has_convolution_param() && blob.shape().dim_size() == 4) {
       // TODO support non default channel axis
       assert(lp.convolution_param().axis() == 1 && "assuming channel axis number set to default");
+      // Input x Output x Height x Width -> Height x Width x Output x Input
       params.emplace_back(transposeTensor<2, 3, 1, 0>(tensor));
     } else if (lp.has_inner_product_param() && blob.shape().dim_size() == 2) {
       params.emplace_back(transposeTensor<1, 0>(tensor));
index e968054..952a039 100644 (file)
@@ -223,19 +223,22 @@ std::vector<IODescriptor> CaffeOpCreator::convertInput(const LayerParameter& lay
 
   assert((num_shapes == 1 || num_shapes == num_inputs) && "Unsupported number of shapes.");
 
-  for (int i = 0; i < num_inputs; ++i)
-  {
+  for (int i = 0; i < num_inputs; ++i) {
     const auto& blob_name = layer.top(i);
     const auto& blob_shape = params.shape(num_shapes == 1 ? 0 : i);
     Shape shape = ShapeHelper::createShape(blob_shape.dim(), blob_shape.dim_size());
-    ShapeHelper::cutOffBatchDim(shape);
-    // WARNING! Temporary solution! Assuming that every 4D input will be used for a convolution,
-    // so we change every 4D input from Caffe NCHW to Model IR HWC (batch is cut off earlier).
-    // TODO: Implement a more consistent way of handling shapes within the model.
-    if (shape.rank() == 3)
-      shape = Shape{shape.dim(1), shape.dim(2), shape.dim(0)};
-    auto variable = createOp<ops::VariableOp>(shape);
-    variable->setName(blob_name);
+
+    // TODO For now we only support convolutional networks. The input data have already been
+    // transformed from Caffe NCHW format to ModelIR NHWC; reflect the changes in the IR.
+    assert(shape.rank() == 4);
+    shape = Shape{shape.dim(0), shape.dim(2), shape.dim(3), shape.dim(1)};
+
+    // TODO Remove this limitation.
+    assert(shape.dim(0) == 1);
+
+    // FIXME We cannot use CaffeOpCreator::createOp here, because we have to set name instantly.
+    // Otherwise interpreter backend won't work.
+    auto variable = _graph->create<ops::VariableOp>(blob_name, shape);
     descriptors.push_back(variable->getOutput(0));
   }
 
@@ -257,8 +260,14 @@ std::vector<IODescriptor>
 CaffeOpCreator::convertConvolution(const std::vector<IODescriptor>& inputs,
                                    const std::vector<std::shared_ptr<IrTensor>>& params,
                                    const caffe::ConvolutionParameter& opts) {
-  ops::PaddingType pad_type = ops::PaddingType::Custom;
-  Shape stride_shape = getConvStride(opts);
+  Shape strides = getConvStride(opts);
+
+  // Set pads
+  int pad_h = opts.has_pad_h() ? opts.pad_h() : 0;
+  int pad_w = opts.has_pad_w() ? opts.pad_w() : 0;
+  if (opts.pad_size() == 1)
+    pad_h = pad_w = opts.pad(0);
+  std::vector<int32_t> paddings{pad_h, pad_w, 0};
 
   std::shared_ptr<IrTensor> unfolded_tensor = params[0];
   Operation* conv2d;
@@ -270,32 +279,13 @@ CaffeOpCreator::convertConvolution(const std::vector<IODescriptor>& inputs,
     // This is depthwise convolution
     // TODO handle properly kernel with layer multiplier
     std::shared_ptr<IrTensor> transposed_tensor = mir::transposeTensor<0, 1, 3, 2>(params[0]);
-    conv2d = createOp<ops::DepthwiseConv2DOp>(inputs[0], *transposed_tensor, stride_shape,
-                                              pad_type);
+    conv2d = createOp<ops::DepthwiseConv2DOp>(inputs[0], *transposed_tensor, strides, paddings);
   } else {
     if (num_groups != 1) {
       // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
       unfolded_tensor = fixGroupedKernel(opts.group(), params[0]);
     }
-    conv2d = createOp<ops::Conv2DOp>(inputs[0], *unfolded_tensor, stride_shape, pad_type);
-  }
-
-  // Set pads
-  int pad_h = opts.has_pad_h() ? opts.pad_h() : 0;
-  int pad_w = opts.has_pad_w() ? opts.pad_w() : 0;
-  if (opts.pad_size() == 1)
-    pad_h = pad_w = opts.pad(0);
-
-  if (is_depthwise) {
-    auto op = static_cast<ops::DepthwiseConv2DOp*>(conv2d);
-    op->setPadding(0, pad_h);
-    op->setPadding(1, pad_w);
-    op->setPadding(2, 0);
-  } else {
-    auto op = static_cast<ops::Conv2DOp*>(conv2d);
-    op->setPadding(0, pad_h);
-    op->setPadding(1, pad_w);
-    op->setPadding(2, 0);
+    conv2d = createOp<ops::Conv2DOp>(inputs[0], *unfolded_tensor, strides, paddings);
   }
 
   // bias_term is optional (so might not be present) and defaults to true
@@ -345,11 +335,20 @@ CaffeOpCreator::convertInnerProduct(const std::vector<IODescriptor>& inputs,
   }
 }
 
+void CaffeOpCreator::checkConcat(const caffe::ConcatParameter& opts,
+                                 std::set<std::string>& problemsOpSet) {
+  if (opts.axis() != 1)
+    problemsOpSet.insert("Concat: unsupported axis");
+}
+
 std::vector<IODescriptor>
 CaffeOpCreator::convertConcat(const std::vector<IODescriptor>& inputs,
                               const caffe::ConcatParameter& opts) {
-  auto result = createOp<ops::ConcatOp>(inputs, getAxisValue(opts));
-  return {result->getOutput(0)};
+  // NCHW -> NHWC
+  assert(opts.axis() == 1);
+  int32_t axis = 3;
+  auto concat = createOp<ops::ConcatOp>(inputs, axis);
+  return {concat->getOutput(0)};
 }
 
 void CaffeOpCreator::checkPooling(const PoolingParameter& opts,
@@ -369,9 +368,16 @@ std::vector<IODescriptor>
 CaffeOpCreator::convertPooling(const std::vector<IODescriptor>& inputs,
                                const caffe::PoolingParameter& opts) {
   Shape window_shape = getPoolWindowShape(opts);
+
+  // Set pads
+  int pad_h = opts.has_pad_h() ? opts.pad_h() : 0;
+  int pad_w = opts.has_pad_w() ? opts.pad_w() : 0;
+  if (opts.has_pad())
+    pad_h = pad_w = opts.pad();
+  std::vector<int32_t> paddings{pad_h, pad_w, 1};
+
   ops::PoolOp::PoolingType pool_type = getPoolingType(opts);
-  ops::PaddingType pad_type = ops::PaddingType::Custom;
-  Shape stride = getPoolStride(opts);
+  Shape strides = getPoolStride(opts);
   ops::PoolOp::BorderType border_type;
   switch (pool_type) {
     case ops::PoolOp::PoolingType::AVG:
@@ -385,26 +391,25 @@ CaffeOpCreator::convertPooling(const std::vector<IODescriptor>& inputs,
       assert(false);
   }
 
-  auto pooling = createOp<ops::PoolOp>(inputs[0], window_shape, stride, pool_type, pad_type,
+  auto pooling = createOp<ops::PoolOp>(inputs[0], window_shape, strides, pool_type, paddings,
                                        border_type);
 
-  // Set pads
-  auto op = static_cast<ops::PoolOp*>(pooling);
-  int pad_h = opts.has_pad_h() ? opts.pad_h() : 0;
-  int pad_w = opts.has_pad_w() ? opts.pad_w() : 0;
-  if (opts.has_pad())
-    pad_h = pad_w = opts.pad();
-  op->setPadding(0, pad_h);
-  op->setPadding(1, pad_w);
-  op->setPadding(2, 0);
-
   return {pooling->getOutput(0)};
 }
 
 std::vector<IODescriptor>
 CaffeOpCreator::convertSoftmax(const std::vector<IODescriptor>& inputs,
                                const caffe::SoftmaxParameter& opts) {
-  auto softmax = createOp<ops::SoftmaxOp>(inputs[0], getAxisValue(opts));
+  auto input = inputs[0];
+  auto& input_shape = input.op->getOutputShape(input.index);
+  // Workaround until we've got Transpose operation.
+  assert(input_shape.rank() == 4 || input_shape.rank() == 2);
+  if (input_shape.rank() == 4) {
+    assert(input_shape.dim(0) == 1);
+    Shape new_shape{input_shape.dim(1), input_shape.dim(2), input_shape.dim(3)};
+    input = createOp<ops::ReshapeOp>(input, new_shape)->getOutput(0);
+  }
+  auto softmax = createOp<ops::SoftmaxOp>(input, getAxisValue(opts));
   return {softmax->getOutput(0)};
 }
 
@@ -513,26 +518,21 @@ std::vector<IODescriptor>
 CaffeOpCreator::convertDeconvolution(const std::vector<IODescriptor>& inputs,
                                      const std::vector<std::shared_ptr<IrTensor>>& params,
                                      const caffe::ConvolutionParameter& opts) {
-  ops::PaddingType pad_type = ops::PaddingType::Custom;
-  Shape stride_shape = getStride(opts);
+  Shape strides = getStride(opts);
+
+  // Set pads
+  int pad_h = opts.has_pad_h() ? opts.pad_h() : 0;
+  int pad_w = opts.has_pad_w() ? opts.pad_w() : 0;
+  if (opts.pad_size())
+    pad_h = pad_w = opts.pad(0);
+  std::vector<int32_t> paddings{pad_h, pad_w, 0};
 
   std::shared_ptr<IrTensor> unfolded_tensor = params[0];
   if (opts.group() != 1) {
     // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
     unfolded_tensor = fixGroupedKernel(opts.group(), params[0]);
   }
-  auto deconv2d = createOp<ops::DeConv2DOp>(inputs[0], *unfolded_tensor, stride_shape, pad_type);
-
-  // Set pads
-  auto op = static_cast<ops::DeConv2DOp*>(deconv2d);
-
-  int pad_h = opts.has_pad_h() ? opts.pad_h() : 0;
-  int pad_w = opts.has_pad_w() ? opts.pad_w() : 0;
-  if (opts.pad_size())
-      pad_h = pad_w = opts.pad(0);
-  op->setPadding(0, pad_h);
-  op->setPadding(1, pad_w);
-  op->setPadding(2, 0);
+  auto deconv2d = createOp<ops::DeConv2DOp>(inputs[0], *unfolded_tensor, strides, paddings);
 
   // bias_term is optional (so might not be present) and defaults to true
   if (!opts.has_bias_term() || opts.bias_term()) {
index 607495b..3988da2 100644 (file)
@@ -94,6 +94,8 @@ public:
   std::vector<mir::IODescriptor> convertSplit(const std::vector<mir::IODescriptor>& inputs,
                                               const caffe::LayerParameter& lp);
 
+  void checkConcat(const caffe::ConcatParameter& opts, std::set<std::string>&);
+
   void checkConvolution(const caffe::ConvolutionParameter& layer, std::set<std::string>&);
 
   void checkInnerProduct(const caffe::InnerProductParameter& opts, std::set<std::string>&);
index ab2b230..84df465 100644 (file)
@@ -3,8 +3,7 @@
 ##########################################
 
 set(COMMON_SOURCES
-        model_allocation.cpp
-        shape_helper.cpp)
+        model_allocation.cpp)
 
 add_library(nn_import_common STATIC ${COMMON_SOURCES})
 set_target_properties(nn_import_common PROPERTIES POSITION_INDEPENDENT_CODE ON)
index eb0ed0e..4b7e6ff 100644 (file)
@@ -293,14 +293,15 @@ void NNInterpreter::visit(ops::ResizeOp& op) {
   mapByName(&op);
   auto operand = op.getPrevNodes()[0];
   Tensor<float> input(var(operand.op->getId())[operand.index]);
-  assert(input.getShape().rank() == 3 && "Must be rank 3 (for now)");
+  assert(input.getShape().rank() == 4 && "Must be rank 4 (for now)");
   switch (op.getMode()) {
     case ops::ResizeOp::ResizeMethod::nearestNeighbor: {
       auto scales = op.getScales();
       var(op.getId()) = Fill<float>(op.getOutputShape(0), [&scales, &input, &op](const Index& id) {
         const Index in_idx = {static_cast<int> (lroundf(scales[0] * id.at(0))),
                               static_cast<int> (lroundf(scales[1] * id.at(1))),
-                              static_cast<int> (lroundf(scales[2] * id.at(2)))};
+                              static_cast<int> (lroundf(scales[2] * id.at(2))),
+                              static_cast<int> (lroundf(scales[3] * id.at(3)))};
         return input.at(in_idx);
       })();
       break;
index 374e261..58365cd 100644 (file)
@@ -36,7 +36,6 @@
 #include "passes/interpreter/Interpreter.h"
 #include "passes/interpreter/InterpreterPass.h"
 
-#include "core/modelIR/ShapeInference.h"
 #include "core/modelIR/Graph.h"
 
 #include "core/modelIR/ShapeRange.h"
@@ -93,12 +92,8 @@ PassData InterpreterPass::run(PassData data) {
   auto g = static_cast<Graph*>(data);
   assert(g);
 
-  ShapeInference shape_inference;
-
   NNInterpreter interpreter;
 
-  g->accept(&shape_inference);
-
   // Check ops
   const auto& inputs = g->collectInputs();
   assert(inputs.size() == 1 && "Interpreter doesn't support networks with multiple input nodes");
index a712fae..e58b20f 100644 (file)
@@ -34,10 +34,19 @@ std::vector<TensorVariant> DepthwiseConv2D::operator()()
   Index pads({_op.getPadding(0), _op.getPadding(1), 0u});
 
   Shape outShape = res.getShape();
+  // Assume batch size == 1 and strip it off.
+  assert(outShape.dim(0) == 1);
+  outShape = {outShape.dim(1), outShape.dim(2), outShape.dim(3)};
+
   outShape.dim(2) = 1;
   ShapeRange outRange(outShape);
 
-  ShapeRange inRange(_input.getShape());
+  Shape inShape = _input.getShape();
+  // Assume batch size == 1 and strip it off.
+  assert(inShape.dim(0) == 1);
+  inShape = {inShape.dim(1), inShape.dim(2), inShape.dim(3)};
+
+  ShapeRange inRange(inShape);
 
   Index inIdx;
   inIdx.resize(outShape.rank());
@@ -48,16 +57,21 @@ std::vector<TensorVariant> DepthwiseConv2D::operator()()
 
   for (auto &outIdx : outRange)
   {
+    // Take into account stripped off batch dimension.
+    Index tmp_out_index{0, outIdx.at(0), outIdx.at(1), outIdx.at(2)};
+
     for (auto &kIdx : ShapeRange(kernelShape))
     {
       translate(inIdx, outIdx, kIdx, strides, pads);
 
       if (inRange.contains(inIdx))
       {
-        auto in = _input.at(inIdx);
+        // Take into account stripped off batch dimension.
+        Index tmp_in_index{0, inIdx.at(0), inIdx.at(1), inIdx.at(2)};
+        auto in = _input.at(tmp_in_index);
         auto b = _kernel.at(kIdx);
-        Index outIdxK = outIdx;
-        outIdxK.at(2) = kIdx.at(2) * channelMultiplier + kIdx.at(channelMultiplierDim);
+        Index outIdxK = tmp_out_index;
+        outIdxK.at(3) = kIdx.at(2) * channelMultiplier + kIdx.at(channelMultiplierDim);
         resAccessor.at(outIdxK) += in * b;
       }
     }
@@ -71,12 +85,12 @@ DepthwiseConv2D::DepthwiseConv2D(const TensorVariant &input, const DepthwiseConv
       _padding(op.getPaddingType()), _out_shape(op.getOutputShape(0)), _op(op)
 {
 
-  assert(_op.getInputShape(0).rank() == 3);
-  assert(input.getShape().rank() == 3);
+  assert(_op.getInputShape(0).rank() == 4);
+  assert(input.getShape().rank() == 4);
   assert(_kernel.getShape().rank() == 4);
   assert(_strides.dim(2) == 1);
   assert(_op.getPadding(2) == 0);
-  assert(_kernel.getShape().dim(2) == _input.getShape().dim(2));
+  assert(_kernel.getShape().dim(2) == _input.getShape().dim(3));
 }
 
 } // namespace nnc
index 60644e8..0948a79 100644 (file)
@@ -30,7 +30,8 @@ using namespace mir::ops;
 
 Pool::Pool(const TensorVariant &_input, const PoolOp &op) : _op(op), _input(_input)
 {
-  assert(op.getWindowShape().rank() == _input.getShape().rank());
+  assert(_input.getShape().rank() == 4);
+  assert(op.getWindowShape().rank() == 3);
 }
 
 std::vector<TensorVariant> Pool::operator()()
@@ -38,8 +39,9 @@ std::vector<TensorVariant> Pool::operator()()
   auto res = allocate_tensor(_op.getOutputShape(0));
   Tensor<float> resAccessor(res);
 
-  Shape strides({_op.getStrides().dim(0), _op.getStrides().dim(1), 1u});
-  Index pads({_op.getPadding(0), _op.getPadding(1), 0u});
+  Shape window_shape{1, _op.getWindowShape().dim(0), _op.getWindowShape().dim(1), 1};
+  Shape strides{1, _op.getStrides().dim(0), _op.getStrides().dim(1), 1};
+  Index pads{0, _op.getPadding(0), _op.getPadding(1), 0};
 
   const Shape &outShape = resAccessor.getShape();
   ShapeRange outRange(outShape);
@@ -67,7 +69,7 @@ std::vector<TensorVariant> Pool::operator()()
   {
     float out = initialValue;
     size_t avgDenominator = 0;
-    for (auto &kIdx : ShapeRange(_op.getWindowShape()))
+    for (auto& kIdx : ShapeRange(window_shape))
     {
       translate(inIdx, outIdx, kIdx, strides, pads);
 
index 8d74e89..51dc70c 100644 (file)
@@ -45,11 +45,19 @@ std::vector<TensorVariant> Conv2D::operator()()
   Index pads({_op.getPadding(0), _op.getPadding(1), 0u});
 
   Shape outShape = resAccesor.getShape();
+  // Assume batch size == 1 and strip it off.
+  assert(outShape.dim(0) == 1);
+  outShape = {outShape.dim(1), outShape.dim(2), outShape.dim(3)};
+
   outShape.dim(2) = 1;
   ShapeRange outRange(outShape);
 
-  const Shape& inShape = _input.getShape();
-  ShapeRange inRange(_input.getShape());
+  Shape inShape = _input.getShape();
+  // Assume batch size == 1 and strip it off.
+  assert(inShape.dim(0) == 1);
+  inShape = {inShape.dim(1), inShape.dim(2), inShape.dim(3)};
+
+  ShapeRange inRange(inShape);
 
   Shape kShape = _kernel.getShape();
   int32_t numKernels = kShape.dim(3);
@@ -61,19 +69,23 @@ std::vector<TensorVariant> Conv2D::operator()()
 
   for (auto &outIdx : outRange)
   {
+    // Take into account stripped off batch dimension.
+    Index tmp_out_index{0, outIdx.at(0), outIdx.at(1), outIdx.at(2)};
+
     for (auto& kernelIdx : kernelRange)
     {
       translate(inputIdx, outIdx, kernelIdx, _strides, pads);
-
       if (inRange.contains(inputIdx))
       {
         auto kernelRegion = _kernel.getRegion(kernelIdx);
         assert( kernelRegion.size() == numKernels );
 
-        auto outRegion = resAccesor.getRegion(outIdx);
+        auto outRegion = resAccesor.getRegion(tmp_out_index);
         assert( outRegion.size() == numKernels );
 
-        auto in = _input.at(inputIdx);
+        // Take into account stripped off batch dimension.
+        Index tmp_in_index{0, inputIdx.at(0), inputIdx.at(1), inputIdx.at(2)};
+        auto in = _input.at(tmp_in_index);
 
         for (int32_t kernelIndex = 0; kernelIndex < numKernels; ++kernelIndex)
         {
@@ -91,8 +103,8 @@ Conv2D::Conv2D(const TensorVariant &input, const Conv2DOp &op)
       _padding(op.getPaddingType()), _out_shape(op.getOutputShape(0)), _op(op)
 {
 
-  assert(_op.getInputShape(0).rank() == 3);
-  assert(input.getShape().rank() == 3);
+  assert(_op.getInputShape(0).rank() == 4);
+  assert(input.getShape().rank() == 4);
   assert(_kernel.getShape().rank() == 4);
   assert(_strides.dim(2) == 1);
   assert(_op.getPadding(2) == 0);
index 5d4e65e..a2f0cad 100644 (file)
@@ -20,7 +20,6 @@
 #include <iostream>
 
 #include "core/modelIR/IrDotDumper.h"
-#include "core/modelIR/ShapeInference.h"
 #include "core/modelIR/operations/ConstantOp.h"
 #include "core/modelIR/Operation.h"
 #include "core/modelIR/Shape.h"
@@ -150,10 +149,9 @@ void ONNXImporterImpl::createGraphInputs() {
       mir::Shape input_shape = ShapeHelper::createShape(onnx_tensor->dims(),
                                                    static_cast<size_t>(onnx_tensor->dims_size()));
       _inputTensors[name] = createTensor(onnx_tensor, input_shape);
-      auto constant = _graph->create<mir::ops::ConstantOp>(name, _inputTensors[name].get());
+      auto constant = _graph->create<mir::ops::ConstantOp>(name, *_inputTensors[name].get());
       _tensorNameToPrevMirOp[name] = constant;
       constants.insert(constant);
-      constant->setOutputShape(0, input_shape);
     } else {
       // We're dealing with graph input
       auto onnx_input_shape = input.type().tensor_type().shape();
@@ -162,17 +160,14 @@ void ONNXImporterImpl::createGraphInputs() {
         assert(onnx_input_shape.dim(i).has_dim_value());
         shape_vector[i] = onnx_input_shape.dim(i).dim_value();
       }
-      mir::Shape input_shape(shape_vector);
-      ShapeHelper::cutOffBatchDim(input_shape);
-      // TODO: Temporary solution! Assuming that every 4D input will be used for a convolution,
-      // so we change every 4D input from ONNX NCHW to Model IR HWC (batch is cut off earlier).
-      // TODO: Implement a more consistent way of handling shapes within the model.
-      // FIXME: it works for 2D pictures only
-      if (input_shape.rank() == 3)
-        input_shape = mir::Shape{input_shape.dim(1), input_shape.dim(2), input_shape.dim(0)};
+      mir::Shape shape(shape_vector);
+      // TODO For now we only support convolutional networks. The input data have already been
+      // transformed from ONNX NCHW format to ModelIR NHWC; reflect the changes in the IR.
+      if (shape.rank() == 4)
+        shape = mir::Shape{shape.dim(0), shape.dim(1), shape.dim(2), shape.dim(0)};
 
       // TODO: Temporary solution!
-      auto node = _graph->create<mir::ops::VariableOp>(name, input_shape);
+      auto node = _graph->create<mir::ops::VariableOp>(name, shape);
       _tensorNameToPrevMirOp[name] = node;
     }
   }
index 4a78076..fad06c6 100644 (file)
@@ -17,7 +17,6 @@
 #include "passes/soft_backend/BaseGenerator.h"
 #include "ModelAnalyzer.h"
 #include "SBSerializer.h"
-#include "core/modelIR/ShapeInference.h"
 #include "option/Options.h"
 #include "pass/Pass.h"
 #include "pass/PassData.h"
@@ -115,9 +114,6 @@ PassData BaseCodeGenerator::run(PassData data)
   auto g = static_cast<mir::Graph *>(data);
   assert(g);
 
-  // inference shapes
-  mir::ShapeInference si;
-  g->accept(&si);
   // visit and analyze graph
   ModelAnalyzer ma;
   ma.analyze(g);
index 5426fe5..c97a784 100644 (file)
@@ -169,7 +169,7 @@ void Serializer::visit(ops::Conv2DOp& op) {
   // serialize strides
   serializeShape(op.getStrides());
   // serialize pads
-  int32_t padsRank = op.getInputShape(0).rank();
+  int32_t padsRank = 2; // op.getInputShape(0).rank();
   serializePads(op, padsRank);
   // serialize output shape
   serializeShape(op.getOutputShape(0));
@@ -183,7 +183,7 @@ void Serializer::visit(ops::DepthwiseConv2DOp& op) {
   // serialize strides
   serializeShape(op.getStrides());
   // serialize pads
-  int32_t padsRank = kernel.getShape().rank();
+  int32_t padsRank = 2; // kernel.getShape().rank();
   serializePads(op, padsRank);
   // serialize output shape
   serializeShape(op.getOutputShape(0));
@@ -204,7 +204,7 @@ void Serializer::visit(ops::PoolOp& op) {
   // serialize strindes
   serializeShape(op.getStrides());
   // serialize pads
-  int32_t padsRank = windowShape.rank();
+  int32_t padsRank = 2; // windowShape.rank();
   serializePads(op, padsRank);
   // serialize border type
   PoolBorderType borderType;
@@ -313,7 +313,7 @@ void Serializer::visit(mir::ops::DeConv2DOp& op) {
   // serialize strides
   serializeShape(op.getStrides());
   // serialize pads
-  int32_t padsRank = op.getInputShape(0).rank();
+  int32_t padsRank = 2; // op.getInputShape(0).rank();
   serializePads(op, padsRank);
   // serialize output shape
   serializeShape(op.getOutputShape(0));
index ae190d7..b647401 100644 (file)
@@ -143,9 +143,10 @@ void TfliteImporter::walkSubGraph(const SubGraph* s) {
   for (auto i : *s->inputs()) {
     const Tensor* t = (*s->tensors())[i];
     Shape inputShape = ShapeHelper::createShape(*t->shape(), t->shape()->size());
-    // So far we assume that if the first dimension is equal to 1,
-    // then it is the batch dimension and should be ignored
-    ShapeHelper::cutOffBatchDim(inputShape);
+
+    // TODO Remove this limitation.
+    assert(inputShape.dim(0) == 1);
+
     auto node = _graph->create<mir::ops::VariableOp>(t->name()->c_str(), inputShape);
     _opsForTensorsTheyOutput[i] = node;
   }
index 0162a74..4b7bfd7 100644 (file)
@@ -53,12 +53,10 @@ void TFLiteOpCreator::checkConv2D(const Conv2DOptions* opts,
 }
 
 std::vector<mir::Operation*> TFLiteOpCreator::convertConv2D(InputOps inputs, InputParams params,
-                                                       const Conv2DOptions* opts) {
+                                                            const Conv2DOptions* opts) {
+  Shape strides{opts->stride_h(), opts->stride_w(), 1};
   auto outputs = createOp<ops::Conv2DOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0),
-                                         *params[0],
-                                         Shape{static_cast<int32_t>(opts->stride_h()),
-                                               static_cast<int32_t>(opts->stride_w()), 1},
-                                         paddingMap[opts->padding()]);
+                                         *params[0], strides, paddingMap[opts->padding()]);
   return createOp<ops::BiasAddOp>(opts->fused_activation_function(), outputs[0]->getOutput(0),
                                   *params[1]);
 }
@@ -68,14 +66,12 @@ void TFLiteOpCreator::checkDepthwiseConv2D(const DepthwiseConv2DOptions* opts,
   checkActivationType(opts->fused_activation_function(), problems_op_set);
 }
 
-std::vector<mir::Operation*> TFLiteOpCreator::convertDepthwiseConv2D(InputOps inputs,
-                                                                     InputParams params,
-                                                                     const DepthwiseConv2DOptions* opts) {
+std::vector<mir::Operation*>
+TFLiteOpCreator::convertDepthwiseConv2D(InputOps inputs, InputParams params,
+                                        const DepthwiseConv2DOptions* opts) {
+  Shape strides{opts->stride_h(), opts->stride_w(), 1};
   auto outputs = createOp<ops::DepthwiseConv2DOp>(ActivationFunctionType_NONE,
-                                                  inputs[0]->getOutput(0),
-                                                  *params[0],
-                                                  Shape{static_cast<int32_t>(opts->stride_h()),
-                                                        static_cast<int32_t>(opts->stride_w()), 1},
+                                                  inputs[0]->getOutput(0), *params[0], strides,
                                                   paddingMap[opts->padding()]);
   return createOp<ops::BiasAddOp>(opts->fused_activation_function(), outputs[0]->getOutput(0),
                                   *params[1]);
@@ -92,8 +88,8 @@ std::vector<mir::Operation*> TFLiteOpCreator::convertConcatenation(InputOps inpu
   std::vector<IODescriptor> descriptors;
   for (auto i : inputs)
     descriptors.push_back(i->getOutput(0));
-  // Decrementing axis to account for the unnecessary batch dimension
-  return createOp<ops::ConcatOp>(opts->fused_activation_function(), descriptors, opts->axis() - 1);
+
+  return createOp<ops::ConcatOp>(opts->fused_activation_function(), descriptors, opts->axis());
 }
 
 void TFLiteOpCreator::checkPool2D(const Pool2DOptions* opts,
@@ -103,31 +99,29 @@ void TFLiteOpCreator::checkPool2D(const Pool2DOptions* opts,
 
 std::vector<mir::Operation*> TFLiteOpCreator::convertMaxPool2D(InputOps inputs, InputParams params,
                                                                const Pool2DOptions* opts) {
+  Shape window_shape{opts->filter_height(), opts->filter_width(), 1};
+  Shape strides{opts->stride_h(), opts->stride_w(), 1};
   return createOp<ops::PoolOp>(opts->fused_activation_function(), inputs[0]->getOutput(0),
-                               Shape{static_cast<int32_t>(opts->filter_height()),
-                                     static_cast<int32_t>(opts->filter_width()), 1},
-                               Shape{static_cast<int32_t>(opts->stride_h()),
-                                     static_cast<int32_t>(opts->stride_w()), 1},
-                               ops::PoolOp::PoolingType::MAX, paddingMap[opts->padding()],
-                               ops::PoolOp::BorderType::EMPTY);
+                               window_shape, strides, ops::PoolOp::PoolingType::MAX,
+                               paddingMap[opts->padding()], ops::PoolOp::BorderType::EMPTY);
 }
 
 std::vector<mir::Operation*> TFLiteOpCreator::convertAveragePool2D(InputOps inputs,
                                                                    InputParams params,
                                                                    const Pool2DOptions* opts) {
+  Shape window_shape{opts->filter_height(), opts->filter_width(), 1};
+  Shape strides{opts->stride_h(), opts->stride_w(), 1};
   return createOp<ops::PoolOp>(opts->fused_activation_function(), inputs[0]->getOutput(0),
-                               Shape{static_cast<int32_t>(opts->filter_height()),
-                                     static_cast<int32_t>(opts->filter_width()), 1},
-                               Shape{static_cast<int32_t>(opts->stride_h()),
-                                     static_cast<int32_t>(opts->stride_w()), 1},
-                               ops::PoolOp::PoolingType::AVG, paddingMap[opts->padding()],
-                               ops::PoolOp::BorderType::EMPTY);
+                               window_shape, strides, ops::PoolOp::PoolingType::AVG,
+                               paddingMap[opts->padding()], ops::PoolOp::BorderType::EMPTY);
 }
 
 std::vector<mir::Operation*> TFLiteOpCreator::createSoftmax(InputOps inputs, InputParams params,
                                                             const SoftmaxOptions* opts) {
-  // -1 represents last one dimension
-  return createOp<ops::SoftmaxOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0), -1);
+  // Softmax in TFLite is always 2-D.
+  assert(inputs[0]->getOutputShape(0).rank() == 2);
+  int32_t axis = 1;
+  return createOp<ops::SoftmaxOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0), axis);
 }
 
 std::vector<mir::Operation*> TFLiteOpCreator::convertReshape(InputOps inputs, InputParams params,
@@ -144,29 +138,30 @@ std::vector<mir::Operation*> TFLiteOpCreator::convertReshape(InputOps inputs, In
 std::vector<mir::Operation*>
 TFLiteOpCreator::createTransposeConv(InputOps& inputs, InputParams& params,
                                      const ::tflite::TransposeConvOptions* opts) {
+  Shape strides{opts->stride_h(), opts->stride_w(), 1};
   return createOp<ops::DeConv2DOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0), *params[1],
-                                   Shape{static_cast<int32_t>(opts->stride_h()),
-                                         static_cast<int32_t>(opts->stride_w()), 1},
-                                   paddingMap[opts->padding()]);
+                                   strides, paddingMap[opts->padding()]);
 }
 
-std::vector<mir::Operation*> TFLiteOpCreator::convertResizeNN(
-  InputOps& inputs, InputParams& params,
-  const ::tflite::ResizeNearestNeighborOptions* opts) {
+std::vector<mir::Operation*>
+TFLiteOpCreator::convertResizeNN(InputOps& inputs, InputParams& params,
+                                 const ::tflite::ResizeNearestNeighborOptions* opts) {
   // TODO support aligned corners
   assert(!opts->align_corners() && "Aligned corners not currently supported");
 
+  auto& input_shape = inputs[0]->getOutputShape(0);
+  assert(input_shape.rank() == 4);
   mir::Tensor<int> out_shapes = mir::Tensor<int>(*params[0].get());
-  std::vector<int> res_shape;
-  for (const auto& i : mir::ShapeRange(out_shapes.getShape()))
-    res_shape.push_back(out_shapes.at(i));
-  res_shape.push_back(Shape::autoDim);
-  // assume no batch
+  Shape res_shape;
+  res_shape.resize(4);
+  res_shape.dim(0) = input_shape.dim(0);
+  res_shape.dim(1) = out_shapes.at(Index{0});
+  res_shape.dim(2) = out_shapes.at(Index{1});
+  res_shape.dim(3) = input_shape.dim(3);
   return createOp<ops::ResizeOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0),
-                                 ops::ResizeOp::ResizeMethod::nearestNeighbor, Shape(res_shape));
+                                 ops::ResizeOp::ResizeMethod::nearestNeighbor, res_shape);
 }
 
-
 std::vector<mir::Operation*>
 TFLiteOpCreator::createAdd(InputOps& inputs, InputParams&, const ::tflite::AddOptions* opts) {
   std::vector<IODescriptor> descriptors;
@@ -211,18 +206,12 @@ std::vector<mir::Operation*> TFLiteOpCreator::convertReducer(InputOps inputs, In
   auto tensor = mir::Tensor<int>(*params.at(0));
   std::vector<int32_t> axes;
 
-  // When batch is no longer being cut off, remove this:
-  int axis_correction = 0;
-  if (inputs[0]->getOutputShape(0).dim(0) != 1) {
-    axis_correction = 1;
-  }
-
   if (params.at(0)->getShape().rank() == 0) {
     // TODO: Dangerous black magic (Default construced Index is 0 dim, as is 0 dim Tensor)
-    axes.push_back(tensor.at(Index()) - axis_correction);
+    axes.push_back(tensor.at(Index()));
   } else {
     for (const auto& i: mir::ShapeRange(tensor.getShape())) {
-      axes.emplace_back(tensor.at(i) - axis_correction);
+      axes.emplace_back(tensor.at(i));
     }
   }
   return createOp<ops::ReduceFOp>(
index 011700a..f08185f 100644 (file)
@@ -29,7 +29,6 @@
 #include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 
-#include "core/modelIR/ShapeInference.h"
 #include "passes/common_frontend/shape_helper.h"
 
 #include "op_info_generated.h"
@@ -153,9 +152,5 @@ std::unique_ptr<Graph> make_graph(const opinfo::OperatorInfo* opInfo) {
   // Mark outputs
   g->markOutput(opNode);
 
-  // Run shape inference
-  ShapeInference shapeInferencer;
-  g->accept(&shapeInferencer);
-
   return g;
 }
index 25c0b38..e79f3bb 100644 (file)
@@ -24,8 +24,6 @@
 #include "core/modelIR/operations/common.h"
 #include "core/modelIR/operations/PoolOp.h"
 
-#include "core/modelIR/ShapeInference.h"
-
 #include "op_info_generated.h"
 #include "passes/common_frontend/shape_helper.h"
 #include "graph_creator.h"
index ba49119..edf30f2 100644 (file)
@@ -33,7 +33,6 @@
 #include "core/modelIR/operations/ReluOp.h"
 #include "core/modelIR/operations/VariableOp.h"
 
-#include "core/modelIR/ShapeInference.h"
 #include "passes/soft_backend/CPPGenerator.h"
 
 // This header generated and contains array with test_main.def contents
@@ -52,9 +51,6 @@ void fillGraph(Graph &g)
   Operation* inputOp = g.create<ops::VariableOp>("in", inputShape);
   Operation* outputOp = g.create<ops::ReluOp>("out", inputOp->getOutput(0));
   g.markOutput(outputOp);
-
-  ShapeInference shapeInferencer;
-  g.accept(&shapeInferencer);
 }
 
 static void checkFileExists(const string &path)
index 02734e6..173b3fc 100644 (file)
@@ -32,7 +32,7 @@ public:
 TEST(Graph, ReplaceInputs) {
   auto g = new Graph;
 
-  auto n1 = g->create<ops::VariableOp>("op1", Shape{});
+  auto n1 = g->create<ops::VariableOp>("op1", Shape{1});
   auto n2 = g->create<ops::ReluOp>("op2", n1->getOutput(0));
   auto n3 = g->create<ops::ReluOp>("op3", n2->getOutput(0));
   auto n4 = g->create<ops::ReluOp>("op4", n2->getOutput(0));
@@ -40,7 +40,6 @@ TEST(Graph, ReplaceInputs) {
                                      std::vector<IODescriptor>{n3->getOutput(0), n4->getOutput(0)},
                                      0);
 
-  n4->setOutputShape(0, Shape{});
   g->replaceInputNodes({"op1", "op4"});
 
   std::stringstream ss;
@@ -58,7 +57,7 @@ TEST(Graph, ReplaceOutputs) {
 
   auto g = new Graph;
 
-  auto n1 = g->create<ops::VariableOp>("op1", Shape{});
+  auto n1 = g->create<ops::VariableOp>("op1", Shape{1});
   auto n2 = g->create<ops::ReluOp>("op2", n1->getOutput(0));
   auto n3 = g->create<ops::ReluOp>("op3", n2->getOutput(0));
   auto n4 = g->create<ops::ReluOp>("op4", n2->getOutput(0));
@@ -81,7 +80,6 @@ TEST(Graph, ReplaceOutputNodeWithInput) {
 
   g->markOutput(n2);
 
-  n2->setOutputShape(0, Shape{});
   auto in2 = g->replaceWithInputNode(n2);
 
   std::vector<Operation*> expectedInputs{in2, n1};
index e6a3eda..98a67fb 100644 (file)
@@ -15,7 +15,6 @@
  */
 
 #include "core/modelIR/Graph.h"
-#include "core/modelIR/ShapeInference.h"
 #include "core/modelIR/operations/ReshapeOp.h"
 #include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
@@ -30,7 +29,6 @@ using namespace nnc::mir;
 
 TEST(ShapeInferenceTest, ReshapeAutoDimension) {
   Graph g;
-  ShapeInference si;
 
   Shape input_shape{10, 2, 5};
   Shape expected_shape{10, 1, 10};
@@ -38,52 +36,40 @@ TEST(ShapeInferenceTest, ReshapeAutoDimension) {
 
   auto input = g.create<ops::VariableOp>("input", input_shape);
   auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{10, 1, Shape::autoDim});
-  op->setInputShape(0, input_shape);
-
-  si.visit(*dynamic_cast<ops::ReshapeOp*>(op));
 
   ASSERT_EQ(expected_shape, op->getOutputShape(0));
 }
 
 TEST(ShapeInferenceTest, ResizeWithShape) {
   Graph g;
-  ShapeInference si;
-
-  Shape result_shape{10, 10, 3};
 
-  auto input = g.create<ops::VariableOp>("input", Shape{5, 5, 3});
+  Shape result_shape{2, 10, 10, 3};
 
-  auto op = g.create<ops::ResizeOp>(
-    "Resize", input->getOutput(0), ops::ResizeOp::ResizeMethod::nearestNeighbor,
-    Shape{10, 10, Shape::autoDim}
-  );
+  auto input = g.create<ops::VariableOp>("input", Shape{1, 5, 5, 3});
 
-  g.accept(&si);
+  auto op = g.create<ops::ResizeOp>("Resize", input->getOutput(0),
+                                    ops::ResizeOp::ResizeMethod::nearestNeighbor, result_shape);
 
   ASSERT_EQ(result_shape, op->getOutputShape(0));
 }
 
 TEST(ShapeInferenceTest, ResizeWithScale) {
   Graph g;
-  ShapeInference si;
 
-  Shape result_shape{30, 10, 3};
+  Shape result_shape{1, 30, 10, 3};
 
-  auto input = g.create<ops::VariableOp>("input", Shape{5, 5, 3});
+  auto input = g.create<ops::VariableOp>("input", Shape{1, 5, 5, 3});
 
   auto op = g.create<ops::ResizeOp>(
     "Resize", input->getOutput(0), ops::ResizeOp::ResizeMethod::nearestNeighbor,
-    std::vector<float>{6, 2, 1}
+    std::vector<float>{1, 6, 2, 1}
   );
 
-  g.accept(&si);
-
   ASSERT_EQ(result_shape, op->getOutputShape(0));
 }
 
 TEST(ShapeInferenceTest, ReduceChangeRank) {
   Graph g;
-  ShapeInference si;
 
   Shape resultShape{10, 10};
 
@@ -91,31 +77,24 @@ TEST(ShapeInferenceTest, ReduceChangeRank) {
 
   auto n = g.create<ops::ReduceFOp>("reduce", input->getOutput(0), std::vector<int32_t>{1, 3},
                                     false, ops::ReduceFOp::FuncType::mean);
-  n->setInputShape(0, Shape{10, 2, 10, 9});
-
-  g.accept(&si);
-
+  
   ASSERT_EQ(resultShape, n->getOutputShape(0));
 }
 
 TEST(ShapeInferenceTest, ReshapeAutoDimensionShrink) {
   Graph g;
-  ShapeInference si;
 
   Shape input_shape{10, 2, 10};
   Shape result_shape_shrink{10, 20};
 
   auto input = g.create<ops::VariableOp>("input", input_shape);
   auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0), Shape{10, Shape::autoDim});
-  op->setInputShape(0, input_shape);
 
-  si.visit(*dynamic_cast<ops::ReshapeOp*>(op));
   ASSERT_EQ(result_shape_shrink, op->getOutputShape(0));
 }
 
 TEST(ShapeInferenceTest, ReshapeAutoDimensionExpand) {
   Graph g;
-  ShapeInference si;
 
   Shape input_shape{10, 2, 10};
   Shape result_shape_expand{5, 10, 2, 2};
@@ -123,15 +102,12 @@ TEST(ShapeInferenceTest, ReshapeAutoDimensionExpand) {
   auto input = g.create<ops::VariableOp>("input", input_shape);
   auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0),
                                      Shape{5, Shape::autoDim, 2, 2});
-  op->setInputShape(0, input_shape);
 
-  si.visit(*dynamic_cast<ops::ReshapeOp*>(op));
   ASSERT_EQ(result_shape_expand, op->getOutputShape(0));
 }
 
 TEST(ShapeInferenceTest, SqueezeTestAllDims) {
   Graph g;
-  ShapeInference si;
 
   Shape input_shape{1, 2, 1, 4};
   Shape expected_shape{2, 4};
@@ -139,14 +115,11 @@ TEST(ShapeInferenceTest, SqueezeTestAllDims) {
   auto input = g.create<ops::VariableOp>("input", input_shape);
   auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", input->getOutput(0), std::vector<int32_t>{});
 
-  g.accept(&si);
-
   ASSERT_EQ(sq1->getOutputShape(0), expected_shape);
 }
 
 TEST(ShapeInferenceTest, SqueezeTestSpecificDims) {
   Graph g;
-  ShapeInference si;
 
   Shape input_shape{1, 2, 1, 4};
   Shape expected_shape{1, 2, 4};
@@ -154,14 +127,11 @@ TEST(ShapeInferenceTest, SqueezeTestSpecificDims) {
   auto input = g.create<ops::VariableOp>("input", input_shape);
   auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", input->getOutput(0), std::vector<int32_t>{2});
 
-  g.accept(&si);
-
   ASSERT_EQ(sq1->getOutputShape(0), expected_shape);
 }
 
 TEST(ShapeInferenceTest, SqueezeTestScalarResult) {
   Graph g;
-  ShapeInference si;
 
   Shape input_shape{1, 1, 1, 1};
   Shape expected_shape{1};
@@ -169,7 +139,5 @@ TEST(ShapeInferenceTest, SqueezeTestScalarResult) {
   auto input = g.create<ops::VariableOp>("input", input_shape);
   auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", input->getOutput(0), std::vector<int32_t>{});
 
-  g.accept(&si);
-
   ASSERT_EQ(sq1->getOutputShape(0), expected_shape);
 }
index d26d6b0..a22a41d 100644 (file)
@@ -38,50 +38,41 @@ TEST(Operation, ConnectionTest) {
 }
 
 TEST(Operation, InputOutputShapeTest) {
-  Shape inShape{1,2,3};
-  Shape outShape{3,2,1};
+  Shape input_shape{1,2,3};
 
-  ops::VariableOp input(Shape{});
+  ops::VariableOp input(input_shape);
   ops::SoftmaxOp op(input.getOutput(0), 0);
-  op.setInputShape(0, inShape);
-  op.setOutputShape(0, outShape);
 
-  ASSERT_EQ(inShape, op.getInputShape(0));
-  ASSERT_EQ(outShape, op.getOutputShape(0));
+  ASSERT_EQ(input_shape, input.getOutputShape(0));
+  ASSERT_EQ(input_shape, op.getInputShape(0));
 }
 
 TEST(Operation, SoftmaxAxisTest) {
-  Shape inShape{1,2,3};
+  Shape input_shape{1,2,3};
 
-  ops::VariableOp input(Shape{});
+  ops::VariableOp input(input_shape);
 
   ops::SoftmaxOp op_1(input.getOutput(0), 1);
-  op_1.setInputShape(0, inShape);
   ASSERT_EQ(op_1.getAxis(), 1);
 
   ops::SoftmaxOp op_n1(input.getOutput(0), -1);
-  op_n1.setInputShape(0, inShape);
   ASSERT_EQ(op_n1.getAxis(), 2);
 
   ops::SoftmaxOp op_n3(input.getOutput(0), -3);
-  op_n3.setInputShape(0, inShape);
   ASSERT_EQ(op_n3.getAxis(), 0);
 }
 
 TEST(Operation, ConcatAxisTest) {
   Shape inShape{1,2,3};
 
-  ops::VariableOp input1(Shape{}), input2(Shape{});
+  ops::VariableOp input1(inShape), input2(inShape);
 
   ops::ConcatOp op_1({input1.getOutput(0), input2.getOutput(0)}, 1);
-  op_1.setInputShape(0, inShape);
   ASSERT_EQ(op_1.getAxis(), 1);
 
   ops::ConcatOp op_n1({input1.getOutput(0), input2.getOutput(0)}, -1);
-  op_n1.setInputShape(0, inShape);
   ASSERT_EQ(op_n1.getAxis(), 2);
 
   ops::ConcatOp op_n3({input1.getOutput(0), input2.getOutput(0)}, -3);
-  op_n3.setInputShape(0, inShape);
   ASSERT_EQ(op_n3.getAxis(), 0);
 }
index 05c50ff..d871db1 100644 (file)
@@ -74,7 +74,6 @@
 #include "core/modelIR/Graph.h"
 #include "core/modelIR/ShapeRange.h"
 
-#include "core/modelIR/ShapeInference.h"
 #include "passes/interpreter/Interpreter.h"
 
 #include "gtest/gtest.h"
@@ -119,9 +118,6 @@ mir::Operation* fillGraph(mir::Graph& g,
   // Mark outputs
   g.markOutput(op);
 
-  // Run shape inference
-  mir::ShapeInference shapeInferencer;
-  g.accept(&shapeInferencer);
   return op;
 }
 
@@ -490,7 +486,7 @@ TEST(cpp_operations_test, conv2d)
           for (iT strideH = 1; strideH <= 3; ++strideH)
             for (iT strideW = 1; strideW <= 3; ++strideW)
             {
-              vector<int> inputShapeData{5, 7, static_cast<int>(inputC)};  // HWC
+              vector<int> inputShapeData{1, 5, 7, static_cast<int>(inputC)};  // NHWC
               mir::Shape kernelShape{kernelH, kernelW, inputC, outputC}; // HWCN
               mir::Shape strides{strideH, strideW, 1};
               vector<unique_ptr<mir::TensorVariant>> inputNTensors(1);
@@ -522,7 +518,7 @@ TEST(cpp_operations_tests, depthwise_conv)
           for (iT strideH = 1; strideH <= 3; ++strideH)
             for (iT multiplier = 1; multiplier <= 2; ++multiplier)
             {
-              vector<int> inputShapeData{5, 7, static_cast<int>(channels)};  // HWC
+              vector<int> inputShapeData{1, 5, 7, static_cast<int>(channels)};  // NHWC
               mir::Shape kernelShape{kernelH, kernelW, channels, multiplier}; // HWCN
               mir::Shape strides{strideH, strideW, 1};
               vector<unique_ptr<mir::TensorVariant>> inputNTensors(1);
@@ -568,10 +564,10 @@ static void genericPoolTest(Func testFunc, const vector<irOps::PoolOp::BorderTyp
         for (iT strideH = 1; strideH <= 3; ++strideH)
           for (iT strideW = 1; strideW <= 3; ++strideW)
           {
-            vector<int> shapeData{5, 7, static_cast<int>(channels)};
+            vector<int> shapeData{1, 5, 7, static_cast<int>(channels)};
             mir::Shape windowShape{windowH, windowW, 1};
             mir::Shape strides{strideH, strideW, 1};
-            auto padT = irOps::PaddingType::Valid;
+            auto padT = irOps::PaddingType::Same;
             Tensor aInputTensor;
             vector<unique_ptr<mir::TensorVariant>> inputNTensors(1);
             fillTensors(inputNTensors[0], aInputTensor, shapeData, 1.0f);
index 12dfa08..9033c38 100644 (file)
@@ -18,7 +18,6 @@
 #include "core/modelIR/Graph.h"
 #include "core/modelIR/operations/ReluOp.h"
 #include "core/modelIR/operations/ConcatOp.h"
-#include "core/modelIR/ShapeInference.h"
 
 #include <gtest/gtest.h>
 
@@ -54,9 +53,6 @@ TEST(ModelAnalyzer, linearization) {
                                                                               tail2->getOutput(0)},
                                             0);
 
-  ShapeInference si;
-  g.accept(&si);
-
   // Check that layout is desired
   ModelAnalyzer ma;
   ma.analyze(&g);
index 3dfca22..ba4e86a 100644 (file)
@@ -21,7 +21,6 @@
 #include "passes/caffe2_frontend/caffe2_importer.h"
 #include "core/modelIR/Graph.h"
 #include "core/modelIR/IrDotDumper.h"
-#include "core/modelIR/ShapeInference.h"
 #include "pass/PassException.h"
 
 using namespace nnc;
@@ -37,9 +36,7 @@ int main(int argc, const char **argv) {
   try {
     importer.import();
     IrDotDumper dotDumper;
-    ShapeInference inf;
     auto g = static_cast<Graph *>(importer.createIR());
-    g->accept(&inf);
     g->accept(&dotDumper);
     dotDumper.writeDot(std::cout);
   }
index 7bffc6e..997d657 100644 (file)
@@ -21,7 +21,6 @@
 #include "passes/caffe_frontend/caffe_importer.h"
 #include "core/modelIR/Graph.h"
 #include "core/modelIR/IrDotDumper.h"
-#include "core/modelIR/ShapeInference.h"
 #include "pass/PassException.h"
 
 using namespace nnc;
@@ -35,9 +34,7 @@ int main(int argc, const char **argv) {
   try {
     importer.import();
     IrDotDumper dotDumper;
-    ShapeInference inf;
     auto g = importer.createIR();
-    g->accept(&inf);
     g->accept(&dotDumper);
     dotDumper.writeDot(std::cout);
   }
index 31ac07a..481da52 100644 (file)
@@ -22,7 +22,6 @@
 #include "passes/tflite_frontend/tflite_importer.h"
 #include "core/modelIR/Graph.h"
 #include "core/modelIR/IrDotDumper.h"
-#include "core/modelIR/ShapeInference.h"
 
 using namespace nnc;
 using namespace nnc::mir;
@@ -36,13 +35,10 @@ int main(int argc, const char **argv) {
   try {
     importer.import();
     IrDotDumper dotDumper;
-    mir::ShapeInference inf;
     auto g = importer.createIR();
-    g->accept(&inf);
     g->accept(&dotDumper);
     dotDumper.writeDot(std::cout);
-  }
-  catch (PassException &e) {
+  } catch (PassException &e) {
     std::cout << "Error: " << e.what() << std::endl;
     return -1;
   }