[nnc] Added reducemean to modelir (#2234)
authorАндрей Шедько/AI Tools Lab /SRR/Assistant Engineer/삼성전자 <a.shedko@partner.samsung.com>
Thu, 29 Nov 2018 11:43:29 +0000 (14:43 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Thu, 29 Nov 2018 11:43:29 +0000 (14:43 +0300)
- Added ReduceMean to ModelIR (Mean in tflite, not present in caffe)
The functionality allows one to support other reduction operations
rather effortlessly (in interpreter).

- Added tests for shape Inference.

Signed-off-by: Andrei Shedko <a.shedko@partner.samsung.com>
24 files changed:
contrib/nnc/core/modelIR/Index.cpp
contrib/nnc/core/modelIR/IrDotDumper.cpp
contrib/nnc/core/modelIR/Operation.cpp
contrib/nnc/core/modelIR/ShapeInference.cpp
contrib/nnc/include/core/modelIR/IrDotDumper.h
contrib/nnc/include/core/modelIR/ShapeInference.h
contrib/nnc/include/core/modelIR/operations/ReduceFOp.h [new file with mode: 0644]
contrib/nnc/include/core/modelIR/operations/operations.lst.h
contrib/nnc/include/passes/acl_soft_backend/AclCppOpGenerator.h
contrib/nnc/include/passes/interpreter/Interpreter.h
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/interpreter/ops/DeConv2D.h
contrib/nnc/passes/interpreter/ops/Reduce.h
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.h
contrib/nnc/passes/soft_backend/SBSerializer.cpp
contrib/nnc/passes/soft_backend/SBSerializer.h
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.h
contrib/nnc/unittests/core/ShapeInference.cpp
contrib/nnc/unittests/tflite_frontend/CMakeLists.txt
contrib/nnc/unittests/tflite_frontend/test_data/gen_test.py

index d6df7b5..36e0921 100644 (file)
 
 #include <algorithm>
 
-namespace nnc
-{
-namespace mir
-{
+namespace nnc {
+namespace mir {
 
-Index::Index(std::initializer_list<int32_t>&& l) : _indices{l} {}
-
-Index::Index(std::vector<int32_t>&& vec) : _indices(vec) {}
+Index::Index(std::initializer_list<int32_t>&& l) : _indices{l} {
+  // DO NOTHING
+}
 
 int32_t Index::rank(void) const { return _indices.size(); }
-Index &Index::resize(int32_t size)
-{
+
+Index& Index::resize(int32_t size) {
   _indices.resize(size);
   return *this;
 }
 
-Index &Index::fill(int32_t index)
-{
+Index& Index::fill(int32_t index) {
   std::fill(_indices.begin(), _indices.end(), index);
   return (*this);
 }
 
 int32_t &Index::at(int32_t axis) { return _indices[(axis < 0) ? (_indices.size() + axis) : axis]; }
-int32_t Index::at(int32_t axis) const { return _indices[(axis < 0) ? (_indices.size() + axis) : axis]; }
+int32_t Index::at(int32_t axis) const {
+  return _indices[(axis < 0) ? (_indices.size() + axis) : axis];
+}
 
-std::ostream &operator<<(std::ostream &s, const Index &sh)
-{
+std::ostream& operator<<(std::ostream& s, const Index& sh) {
   s << "[ ";
-  for (int32_t i = 0; i < sh.rank(); ++i)
-  {
-    if (i != 0 )
+  for (int32_t i = 0; i < sh.rank(); ++i) {
+    if (i != 0)
       s << ", ";
     s << sh.at(i);
   }
index df6c0ab..558dc05 100644 (file)
@@ -232,6 +232,17 @@ void mir::IrDotDumper::visit(ops::PadOp& op) {
   dotBuilder.updateWithOp(&op, node_info);
 }
 
+void IrDotDumper::visit(ops::ReduceFOp& op) {
+  auto node_info = DotIrNodeInfo().withType("ReduceFOp", op.getName())
+    .withInShapes(getInputShapes(op))
+    .withOutShapes(getOutputShapes(op))
+    .withShape("Reduction dims", Shape(op.getReductionDims())) // appropriated shape to dims
+    .withMisc("Keep Dims", op.getKeepDims())
+    .withMisc("OPType", (float) op.getFuncType());
+
+  dotBuilder.updateWithOp(&op, node_info);
+}
+
 void IrDotDumper::visit(ops::ResizeOp& op) {
   auto node_info = DotIrNodeInfo().withType("Resize", op.getName())
     .withInShapes(getInputShapes(op))
@@ -244,3 +255,4 @@ void IrDotDumper::visit(ops::ResizeOp& op) {
 
 } // namespace mir
 } // namespace nnc
+
index 8936bca..dc501ed 100644 (file)
@@ -37,8 +37,7 @@
 #include "core/modelIR/operations/SqueezeOp.h"
 #include "core/modelIR/operations/ReshapeOp.h"
 #include "core/modelIR/operations/PadOp.h"
-
-#include <cassert>
+#include "core/modelIR/operations/ReduceFOp.h"
 
 namespace nnc {
 namespace mir {
@@ -84,6 +83,8 @@ void Operation::accept(IVisitor* v) {
       break;
 #include "core/modelIR/operations/operations.lst.h"
 #undef HANDLE_OP
+    default:
+      assert(false && "OP not defined!");
   }
 }
 
index 1252404..34f48ca 100644 (file)
 #include "core/modelIR/operations/ElementwiseOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
 #include "core/modelIR/operations/PadOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
 
-namespace nnc
-{
-namespace mir
-{
+namespace nnc {
+namespace mir {
 
 using nnc::mir::Shape;
 
-template<class Op>
-void fillHWShapesForPaddedOperations(Op& op, const Shape &windowShape, Shape &outShape)
-{
-  auto &strides = op.getStrides();
-  auto &inShape = op.getInputShape(0);
+template <class Op>
+void fillHWShapesForPaddedOperations(Op& op, const Shape& windowShape, Shape& outShape) {
+  auto& strides = op.getStrides();
+  auto& inShape = op.getInputShape(0);
   auto inRank = inShape.rank();
   outShape.resize(inRank);
 
   ops::PaddingType pType = op.getPaddingType();
-  switch (pType)
-  {
-  case ops::PaddingType::Same:
-    for (int32_t d = 0; d < inRank - 1; ++d)
-    {
-      outShape.dim(d) = (inShape.dim(d) - 1) / strides.dim(d) + 1;
-      int pad_along_axis;
-      if (inShape.dim(d) % strides.dim(d) == 0)
-      {
-        pad_along_axis = std::max((int)windowShape.dim(d) - (int)strides.dim(d), 0);
+  switch (pType) {
+    case ops::PaddingType::Same:
+      for (int32_t d = 0; d < inRank - 1; ++d) {
+        outShape.dim(d) = (inShape.dim(d) - 1) / strides.dim(d) + 1;
+        int pad_along_axis;
+        if (inShape.dim(d) % strides.dim(d) == 0) {
+          pad_along_axis = std::max(( int ) windowShape.dim(d) - ( int ) strides.dim(d), 0);
+        } else {
+          pad_along_axis = std::max(( int ) (outShape.dim(d) - 1) * ( int ) strides.dim(d) +
+                                    ( int ) windowShape.dim(d) - ( int ) inShape.dim(d),
+                                    0);
+        }
+        op.setPadding(d, pad_along_axis / 2);
       }
-      else
-      {
-        pad_along_axis = std::max((int)(outShape.dim(d) - 1) * (int)strides.dim(d) +
-                                  (int)windowShape.dim(d) - (int)inShape.dim(d),
-                                  0);
+      break;
+    case ops::PaddingType::Valid:
+      for (int32_t d = 0; d < inRank - 1; ++d) {
+        op.setPadding(d, 0);
       }
-      op.setPadding(d, pad_along_axis / 2);
-    }
-    break;
-  case ops::PaddingType::Valid:
-    for (int32_t d = 0; d < inRank - 1; ++d)
-    {
-      op.setPadding(d, 0);
-    }
-    // FALLTHROUGH
-  case ops::PaddingType::Custom:
-    for (int32_t d = 0; d < inRank - 1; ++d)
-    {
-      outShape.dim(d) = (inShape.dim(d) + 2*op.getPadding(d) - windowShape.dim(d)) / strides.dim(d) + 1;
-    }
-    break;
-  default:
-    assert(false && "invalid padding type");
-    break;
+      // FALLTHROUGH
+    case ops::PaddingType::Custom:
+      for (int32_t d = 0; d < inRank - 1; ++d) {
+        outShape.dim(d) =
+          (inShape.dim(d) + 2 * op.getPadding(d) - windowShape.dim(d)) / strides.dim(d) + 1;
+      }
+      break;
+    default:
+      assert(false && "invalid padding type");
+      break;
   }
   // For now padding for channels is not supported, initialize it with zero
   op.setPadding(inRank - 1, 0);
@@ -105,14 +97,12 @@ void ShapeInference::visit(ops::ConcatOp& op) {
   Shape outShape;
   outShape.resize(op.getInputShape(0).rank());
 
-  for (int32_t d = 0; d < outShape.rank(); ++d)
-  {
+  for (int32_t d = 0; d < outShape.rank(); ++d) {
     outShape.dim(d) = op.getInputShape(0).dim(d);
   }
   outShape.dim(axis) = 0;
 
-  for (size_t i = 0; i < op.getNumInputs(); ++i)
-  {
+  for (size_t i = 0; i < op.getNumInputs(); ++i) {
     outShape.dim(axis) += op.getInputShape(i).dim(axis);
   }
 
@@ -123,8 +113,8 @@ void ShapeInference::visit(ops::Conv2DOp& op) {
   fillInputShapes(op);
 
   Shape outShape;
-  auto &kernel = op.getKernel();
-  auto &kernelShape = kernel.getShape();
+  autokernel = op.getKernel();
+  autokernelShape = kernel.getShape();
 
   fillHWShapesForPaddedOperations(op, kernelShape, outShape);
 
@@ -141,9 +131,8 @@ void ShapeInference::visit(ops::ConstantOp&) {
 
 void ShapeInference::fillInputShapes(Operation& op) {
   size_t i = 0;
-  for (auto &in : op.getPrevNodes())
-  {
-    const Shape &inShape = in.op->getOutputShape(in.index);
+  for (auto& in : op.getPrevNodes()) {
+    const Shape& inShape = in.op->getOutputShape(in.index);
     op.setInputShape(i++, inShape);
   }
 }
@@ -188,8 +177,8 @@ void ShapeInference::visit(ops::PoolOp& op) {
   fillInputShapes(op);
 
   Shape outShape;
-  auto &windowShape = op.getWindowShape();
-  auto &inShape = op.getInputShape(0);
+  autowindowShape = op.getWindowShape();
+  autoinShape = op.getInputShape(0);
   const int32_t inRank = inShape.rank();
   // Assuming input tensor is 3-dimensional. Will support more general cases when needed.
   assert(inRank == 3);
@@ -202,17 +191,16 @@ void ShapeInference::visit(ops::PoolOp& op) {
 
 void ShapeInference::visit(ops::FullyConnectedOp& op) {
   fillInputShapes(op);
-  const Shape &inShape = op.getInputShape(0);
-  const Shape &wShape = op.getWeights().getShape();
+  const ShapeinShape = op.getInputShape(0);
+  const ShapewShape = op.getWeights().getShape();
   const int32_t weightsRank = wShape.rank();
   const int32_t inRank = inShape.rank();
 
   assert(weightsRank >= 2);
   assert(inRank == weightsRank);
   assert(inShape.dim(inRank - 1) == wShape.dim(weightsRank - 2));
-  (void)inRank;
-  for (int32_t i = 0; i < weightsRank - 2; ++i)
-  {
+  ( void ) inRank;
+  for (int32_t i = 0; i < weightsRank - 2; ++i) {
     assert(wShape.dim(i) == inShape.dim(i));
   }
 
@@ -231,8 +219,8 @@ void ShapeInference::visit(ops::DepthwiseConv2DOp& op) {
   fillInputShapes(op);
 
   Shape outShape;
-  auto &kernelShape = op.getKernel().getShape();
-  auto &inShape = op.getInputShape(0);
+  autokernelShape = op.getKernel().getShape();
+  autoinShape = op.getInputShape(0);
   int inRank = inShape.rank();
   int kernelRank = kernelShape.rank();
 
@@ -260,14 +248,14 @@ void ShapeInference::visit(ops::ReshapeOp& op) {
   auto inElementsNum = inShape.numElements();
   int32_t outElementsNum = 1;
   //can't use num_elements due to -1 in input shape and Shape using unsigned ints for dimensions
-  for( int32_t d = 0; d < outShape.rank(); ++d ) {
+  for (int32_t d = 0; d < outShape.rank(); ++d) {
     auto dim = outShape.dim(d);
     if( dim != Shape::autoDim) {
       outElementsNum *= dim;
     }
   }
 
-  for( int32_t d = 0; d < outShape.rank(); ++d ) {
+  for (int32_t d = 0; d < outShape.rank(); ++d) {
     auto& dim = outShape.dim(d);
     if( dim == Shape::autoDim ) {
       dim = static_cast<int32_t>(inElementsNum / outElementsNum);
@@ -304,8 +292,8 @@ void ShapeInference::visit(ops::DeConv2DOp& op) {
 
   Shape out_shape;
   Shape in_shape = op.getInputShape(0);
-  auto &kernel = op.getKernel();
-  auto &kernel_shape = kernel.getShape();
+  autokernel = op.getKernel();
+  autokernel_shape = kernel.getShape();
 
   assert(kernel_shape.rank() == 4);
   assert(in_shape.rank() == 3);
@@ -318,18 +306,19 @@ void ShapeInference::visit(ops::DeConv2DOp& op) {
 
   switch (pad_type) {
     case ops::PaddingType::Same:
-      for (int32_t d = 0;d < in_rank; ++d) {
+      for (int32_t d = 0; d < in_rank; ++d) {
         out_shape.dim(d) = in_shape.dim(d) * strides.dim(d) + 1 - strides.dim(d);
       }
       break;
     case ops::PaddingType::Valid:
-      for (int32_t d = 0;d < in_rank; ++d) {
+      for (int32_t d = 0; d < in_rank; ++d) {
         out_shape.dim(d) = in_shape.dim(d) * strides.dim(d) + kernel_shape.dim(d) - strides.dim(d);
       }
       break;
     case ops::PaddingType::Custom:
       for (int32_t d = 0; d < in_rank - 1; ++d) {
-        out_shape.dim(d) = ( in_shape.dim(d) -1 )* strides.dim(d) - 2 * op.getPadding(d) + kernel_shape.dim(d);
+        out_shape.dim(d) =
+          (in_shape.dim(d) - 1) * strides.dim(d) - 2 * op.getPadding(d) + kernel_shape.dim(d);
       }
       break;
     default: {
@@ -426,5 +415,34 @@ void ShapeInference::visit(ops::PadOp& op) {
   op.setOutputShape(0, out_shape);
 }
 
+void ShapeInference::visit(ops::ReduceFOp& op) {
+  fillInputShapes(op);
+  assert(op.getNumInputs() == 1);
+
+  const auto& input_shape = op.getInputShape(0);
+  const auto& red_dims = op.getReductionDims();
+  Shape output_shape;
+  if (op.getKeepDims()) {
+    output_shape = input_shape;
+    for (auto red_axis: red_dims) {
+      output_shape.dim(red_axis) = 1;
+    }
+  } else {
+    std::vector<int32_t> out_dims;
+    out_dims.reserve(input_shape.rank() - op.getReductionDims().size());
+    auto red_axis = red_dims.begin();
+    for (int32_t axis_id = 0; axis_id < input_shape.rank(); axis_id++) {
+      if (axis_id == (*red_axis)) {
+        red_axis++;
+      } else {
+        out_dims.emplace_back(input_shape.dim(axis_id));
+      }
+    }
+    output_shape = Shape(out_dims);
+  }
+
+  op.setOutputShape(0, output_shape);
+}
+
 } // namespace mir
 } // namespace nnc
index 829c321..2ab5c3b 100644 (file)
@@ -40,6 +40,7 @@
 #include "core/modelIR/operations/ElementwiseOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
 #include "core/modelIR/operations/PadOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
 
 #include "core/modelIR/ir_dot_builder.h"
 
@@ -76,6 +77,7 @@ public:
   void visit(ops::ElementwiseOp& op) override;
   void visit(ops::SqueezeOp& op) override;
   void visit(ops::PadOp& op) override;
+  void visit(ops::ReduceFOp& op) override;
 
   void writeDot(std::ostream &os) { dotBuilder.writeDot(os); };
 
index e30746f..5138455 100644 (file)
@@ -50,6 +50,7 @@ public:
   void visit(ops::EluOp& op) override;
   void visit(ops::SqueezeOp& op) override;
   void visit(ops::PadOp& op) override;
+  void visit(ops::ReduceFOp& op) override;
 
 protected:
   void fillInputShapes(Operation& op);
diff --git a/contrib/nnc/include/core/modelIR/operations/ReduceFOp.h b/contrib/nnc/include/core/modelIR/operations/ReduceFOp.h
new file mode 100644 (file)
index 0000000..0e722a7
--- /dev/null
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_IR_MODEL_REDUCE_F_H_
+#define _NNC_CORE_IR_MODEL_REDUCE_F_H_
+
+#include "core/modelIR/Operation.h"
+#include <vector>
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+class ReduceFOp : public Operation {
+public:
+  enum class FuncType {
+    mean, //TODO add other reducers
+  };
+
+  /**
+   * @brief Reduces with (a,b) -> a + b / n where n is the size of dimension(s) being reduced
+   * @param reduce_dims vector of ints denoting reduction dimensions. assume it is sorted
+   * @param keepDims whether to keep the original rank
+   * @param fT function to reduce the tensor with (should be associative)
+   */
+  explicit ReduceFOp(const IODescriptor& arg,
+                     const std::vector<int32_t>& reduce_dims, bool keepDims,
+                     FuncType fT) :
+    Operation(Type::reduceFOp, {arg}), _reduceDims(reduce_dims),
+    _keepDims(keepDims), _fT(fT) {};
+
+  const std::vector<int32_t>& getReductionDims() { return _reduceDims; };
+
+  bool getKeepDims() const { return _keepDims; };
+
+  FuncType getFuncType() const { return _fT; };
+private:
+  std::vector<int32_t> _reduceDims;
+  bool _keepDims;
+  FuncType _fT;
+
+};
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
+
+#endif //_NNC_CORE_IR_MODEL_REDUCE_F_H_
index b1455b6..c09a1f5 100644 (file)
@@ -40,3 +40,4 @@ HANDLE_OP(deConv2D, DeConv2DOp)
 HANDLE_OP(ELU, EluOp)
 HANDLE_OP(squeeze, SqueezeOp)
 HANDLE_OP(pad, PadOp)
+HANDLE_OP(reduceFOp, ReduceFOp)
index 906bf93..53ea6f6 100644 (file)
@@ -69,6 +69,7 @@ public:
   void visit(mir::ops::EluOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
   void visit(mir::ops::PadOp& op) override;
+  void visit(mir::ops::ReduceFOp& op) override;
 
 private:
   using AF = ArtifactFactory;
index ed973ee..ccdb301 100644 (file)
@@ -58,6 +58,7 @@ public:
   void visit(ops::EluOp& op) override;
   void visit(ops::SqueezeOp& op) override;
   void visit(ops::PadOp& op) override;
+  void visit(ops::ReduceFOp& op) override;
 
   void setInput(const std::string &name, const TensorVariant& data);
   std::vector<TensorVariant> &getResult(Operation* op);
index 0e8439c..47c7cc9 100644 (file)
@@ -24,6 +24,7 @@
 #include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/ElementwiseOp.h"
 #include "core/modelIR/operations/Deconv2DOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
 
 #include <algorithm>
 
@@ -791,5 +792,9 @@ void AclCppOpGenerator::visit(mir::ops::ResizeOp& op) {
   assert(false && "Unimplemented operation: Resize");
 }
 
+void AclCppOpGenerator::visit(mir::ops::ReduceFOp& op) {
+  assert(false && "Unimplemented operation: ReduceFOp");
+}
+
 }
 // namespace nnc
index cf0378b..eb0ed0e 100644 (file)
@@ -31,6 +31,7 @@
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/VariableOp.h"
 #include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
 #include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/ConcatOp.h"
@@ -48,7 +49,6 @@
 #include "ops/conv_2D.h"
 #include "ops/DeConv2D.h"
 #include "ops/Depthwise_conv_2D.h"
-#include "ops/Elementwise.h"
 #include "ops/FullyConnected.h"
 #include "ops/Pool.h"
 #include "ops/Reshape.h"
@@ -311,4 +311,31 @@ void NNInterpreter::visit(ops::ResizeOp& op) {
 
 }
 
+void NNInterpreter::visit(ops::ReduceFOp& op) {
+  mapByName(&op);
+  // should always be an integer in a float
+  const float reduction_area =
+    static_cast<float>(op.getInputShape(0).numElements() / op.getOutputShape(0).numElements());
+
+  auto operand = op.getPrevNodes()[0];
+  auto& input = var(operand.op->getId())[operand.index];
+
+  std::function<float(float, float)> func;
+  switch (op.getFuncType()) {
+    case ops::ReduceFOp::FuncType::mean: {
+      func = [](float running_sum, float item) { return running_sum + item; };
+      var(op.getId()) = ReduceN<float>(op.getInputShape(0),
+                                       op.getOutputShape(0), input, op.getReductionDims(), func)();
+      Tensor<float> out_t = Tensor<float>(var(op.getId())[0]); // for numerical  stability
+      var(op.getId()) = Fill<float>(op.getOutputShape(0),
+                                    [&out_t, &op, reduction_area](const Index& id) {
+                                      return out_t.at(id) / reduction_area;
+                                    })();
+    }
+      break;
+    default:
+      assert(false && "Not Implemented");
+  }
+}
+
 } // namespace nnc
index 0c7443e..d15a25e 100644 (file)
@@ -45,7 +45,6 @@ private:
   const mir::ops::PaddingType _padding;
   const mir::Shape &_out_shape;
   const mir::ops::DeConv2DOp &_op;
-
 };
 
 } // namespace nnc
index d5485f9..f4d2d96 100644 (file)
 
 #include "core/modelIR/Shape.h"
 #include "core/modelIR/Tensor.h"
+#include "core/modelIR/ShapeRange.h"
 
 #include "OperationImpl.h"
 #include "Fill.h"
 
+namespace nnc {
 
-namespace nnc
-{
-
-template <typename T> class Reduce : public OperationImpl<T>
-{
+template <typename T> class Reduce : public OperationImpl<T> {
 public:
-  Reduce(const mir::Shape &inputShape, const mir::Shape &outputShape, const mir::TensorVariant &input, int32_t axis,
-         std::function<T(const T &, const T &)> reduceFunc)
-      : _inShape(inputShape), _outputShape(outputShape), _input(input), _axis(axis),
-        _reduceFunc(reduceFunc)
-  {
+  Reduce(const mir::Shape& inputShape, const mir::Shape& outputShape,
+         const mir::TensorVariant& input, int32_t axis,
+         std::function<T(const T&, const T&)> reduceFunc)
+    : _inShape(inputShape), _outputShape(outputShape), _input(input), _axis(axis),
+      _reduceFunc(reduceFunc) {
     assert(outputShape.dim(axis) == 1);
   }
 
-  std::vector<mir::TensorVariant> operator()() override
-  {
-    return Fill<T>(_outputShape, [this](const mir::Index &id) {
+  std::vector<mir::TensorVariant> operator()() override {
+    return Fill<T>(_outputShape, [this](const mir::Index& id) {
       T element = T();
       mir::Index inputId = id;
       int32_t end = _inShape.dim(_axis);
-      for (int32_t i = 0; i < end; ++i)
-      {
+      for (int32_t i = 0; i < end; ++i) {
         inputId.at(_axis) = i;
         element = _reduceFunc(element, _input.at(inputId));
       }
@@ -56,13 +52,74 @@ public:
   }
 
 private:
-  const mir::Shape &_inShape;
-  const mir::Shape &_outputShape;
+  const mir::Shape_inShape;
+  const mir::Shape_outputShape;
   const mir::Tensor<T> _input;
   const int32_t _axis;
   const std::function<T(T, T)> _reduceFunc;
 };
 
+template <typename T> class ReduceN : public OperationImpl<T> {
+public:
+  /**
+   * @brief Reduces a tensor to output shape
+   * @param inputShape
+   * @param outputShape
+   * @param input Stores the values
+   * @param reductionDims vector of dims to reduce to 1
+   * @param reduceFunc function to reduce the tensor with (should be associative)
+   */
+  ReduceN(const mir::Shape& inputShape, const mir::Shape& outputShape,
+          const mir::TensorVariant& input,
+          std::vector<int32_t> reductionDims, std::function<T(const T&, const T&)> reduceFunc)
+    : _inShape(inputShape), _outputShape(outputShape), _input(input), _reductionDims(reductionDims),
+      _reduceFunc(reduceFunc) {
+    if (inputShape.rank() == outputShape.rank()) {
+      for (auto axis: reductionDims) {
+        assert(outputShape.dim(axis) == 1);
+      }
+      _keepDims = true;
+    }
+  }
+
+  std::vector<mir::TensorVariant> operator()() override {
+    auto res = this->allocate_tensor(_outputShape);
+    mir::Tensor<T> res_accesor(res);
+
+    mir::Index out_id;
+    out_id.resize(_outputShape.rank());
+    for (const mir::Index& input_id : mir::ShapeRange(_inShape)) {
+      int32_t out_idx_id = 0;
+      int32_t red_dim = 0;
+      // change out id to point to the correct cell
+      for (int d = 0; d != _inShape.rank(); ++d) {
+        if (d == _reductionDims[red_dim]) {
+          red_dim++;
+          if (_keepDims)
+            out_id.at(out_idx_id++) = 0;
+          else
+            continue;
+        } else {
+          out_id.at(out_idx_id++) = input_id.at(d);
+        }
+      }
+      res_accesor.at(out_id) = _reduceFunc(res_accesor.at(out_id), _input.at(input_id));
+    }
+
+    return {res};
+  }
+
+private:
+
+  const mir::Shape& _inShape;
+  const mir::Shape& _outputShape;
+  const mir::Tensor<T> _input;
+  const std::vector<int32_t> _reductionDims;
+  const std::function<T(T, T)> _reduceFunc;
+  bool _keepDims = false;
+};
+
+
 } // namespace nnc
 
 #endif //_NNC_CORE_BACKEND_INTERPRETER_REDUCE_IMPL_
index c475fbe..48aa237 100644 (file)
@@ -44,6 +44,7 @@
 #include "core/modelIR/operations/ElementwiseOp.h"
 #include "core/modelIR/operations/VariableOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
 
 using namespace std;
 
@@ -286,4 +287,9 @@ void ModelAnalyzer::visit(mir::ops::PadOp& op) {
   assert(false && "Not implemented yet");
 }
 
+void ModelAnalyzer::visit(mir::ops::ReduceFOp& op) {
+  assert(false && "Not implemented yet");
+  addOpDescr(&op, "ReduceMean");
+}
+
 } // namespace nnc
index b528011..fb1fcf3 100644 (file)
@@ -111,6 +111,7 @@ public:
   void visit(mir::ops::EluOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
   void visit(mir::ops::PadOp& op) override;
+  void visit(mir::ops::ReduceFOp& op) override;
 
   /**
    * @return vector of id's of network input tensors
index 9391d0f..5426fe5 100644 (file)
@@ -40,6 +40,7 @@
 #include "core/modelIR/operations/TanhOp.h"
 #include "core/modelIR/operations/ElementwiseOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
 
 #include "pass/PassException.h"
 #include <algorithm>
@@ -331,4 +332,11 @@ void Serializer::visit(mir::ops::ResizeOp& op) {
   throw PassException("Not implemented yet");
 }
 
+void Serializer::visit(mir::ops::ReduceFOp& op) {
+  _curOp->_paramStartOffset = _buffer.size();
+  serializeShape(Shape(op.getReductionDims())); // reuse shape serialization
+  serializeT<int32_t>(op.getKeepDims());
+  serializeShape(op.getOutputShape(0));
+}
+
 } // namespace nnc
index 6c44fde..6d01a27 100644 (file)
@@ -63,6 +63,7 @@ public:
   void visit(mir::ops::EluOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
   void visit(mir::ops::PadOp& op) override;
+  void visit(mir::ops::ReduceFOp& op) override;
 
   void serialize(std::list<OpDescr> &inferenceSequence);
 
index 0269df9..ae190d7 100644 (file)
@@ -97,6 +97,7 @@ void TfliteImporter::processUnsupportedOp(const Operator* op) {
     case BuiltinOperator_PAD:
     case BuiltinOperator_ADD:
     case BuiltinOperator_MUL:
+    case BuiltinOperator_MEAN:
     case BuiltinOperator_MAXIMUM:
     case BuiltinOperator_DIV:
     case BuiltinOperator_TRANSPOSE_CONV:
@@ -187,6 +188,10 @@ void TfliteImporter::walkOperator(const Operator* op) {
       outputs = _opCreator->convertResizeNN(inputs, params,
         op->builtin_options_as<ResizeNearestNeighborOptions>());
       break;
+    case BuiltinOperator_MEAN:
+      outputs = _opCreator->convertReducer(inputs, params,ops::ReduceFOp::FuncType::mean,
+       op->builtin_options_as<ReducerOptions>());
+      break;
     case BuiltinOperator_FULLY_CONNECTED:
       outputs = _opCreator->convertFullyConnected(inputs, params,
                                                   op->builtin_options_as<FullyConnectedOptions>());
index 566f72f..0162a74 100644 (file)
 #include "core/modelIR/operations/ElementwiseOp.h"
 #include "core/modelIR/operations/Deconv2DOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/ReshapeOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
 #include "core/modelIR/operations/PadOp.h"
 #include "core/modelIR/Tensor.h"
+#include "core/modelIR/ShapeRange.h"
 #include "pass/PassException.h"
 
 #include "core/modelIR/Tensor.h"
@@ -202,6 +204,32 @@ TFLiteOpCreator::createMax(InputOps& inputs, InputParams&,
                                       ops::ElementwiseOp::OpType::max);
 }
 
+std::vector<mir::Operation*> TFLiteOpCreator::convertReducer(InputOps inputs, InputParams params,
+                                                             ops::ReduceFOp::FuncType ft,
+                                                             const ::tflite::ReducerOptions* opts) {
+  assert(params.at(0)->getShape().rank() <= 1 && "Must be 1-dim or 0-dim tensor");
+  auto tensor = mir::Tensor<int>(*params.at(0));
+  std::vector<int32_t> axes;
+
+  // When batch is no longer being cut off, remove this:
+  int axis_correction = 0;
+  if (inputs[0]->getOutputShape(0).dim(0) != 1) {
+    axis_correction = 1;
+  }
+
+  if (params.at(0)->getShape().rank() == 0) {
+    // TODO: Dangerous black magic (Default construced Index is 0 dim, as is 0 dim Tensor)
+    axes.push_back(tensor.at(Index()) - axis_correction);
+  } else {
+    for (const auto& i: mir::ShapeRange(tensor.getShape())) {
+      axes.emplace_back(tensor.at(i) - axis_correction);
+    }
+  }
+  return createOp<ops::ReduceFOp>(
+    ActivationFunctionType_NONE, inputs[0]->getOutput(0),
+    axes, opts->keep_dims(), ft);
+}
+
 void TFLiteOpCreator::checkFullyConnected(const FullyConnectedOptions* opts,
                                           std::set<std::string>& problems_op_set) {
   checkActivationType(opts->fused_activation_function(), problems_op_set);
index 5b8a14d..5ada473 100644 (file)
@@ -29,6 +29,7 @@
 #include "core/modelIR/Shape.h"
 
 #include "core/modelIR/operations/common.h"
+#include "core/modelIR/operations/ReduceFOp.h"
 
 #include "schema_generated.h"
 #include "passes/common_frontend/shape_helper.h"
@@ -61,6 +62,9 @@ public:
   std::vector<mir::Operation*> convertAveragePool2D(InputOps, InputParams,
                                                     const ::tflite::Pool2DOptions*);
 
+  std::vector<mir::Operation*> convertReducer(InputOps, InputParams, ops::ReduceFOp::FuncType,
+                                        const ::tflite::ReducerOptions*);
+
   std::vector<mir::Operation*> createSoftmax(InputOps, InputParams, const ::tflite::SoftmaxOptions*);
 
   std::vector<mir::Operation*> convertReshape(InputOps, InputParams,
index 924b1a0..e6a3eda 100644 (file)
 #include "core/modelIR/operations/ReshapeOp.h"
 #include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
+#include "core/modelIR/Shape.h"
+
+#include <vector>
 
 #include "gtest/gtest.h"
 
@@ -77,6 +81,23 @@ TEST(ShapeInferenceTest, ResizeWithScale) {
   ASSERT_EQ(result_shape, op->getOutputShape(0));
 }
 
+TEST(ShapeInferenceTest, ReduceChangeRank) {
+  Graph g;
+  ShapeInference si;
+
+  Shape resultShape{10, 10};
+
+  auto input = g.create<ops::VariableOp>("input", Shape{10, 2, 10, 9});
+
+  auto n = g.create<ops::ReduceFOp>("reduce", input->getOutput(0), std::vector<int32_t>{1, 3},
+                                    false, ops::ReduceFOp::FuncType::mean);
+  n->setInputShape(0, Shape{10, 2, 10, 9});
+
+  g.accept(&si);
+
+  ASSERT_EQ(resultShape, n->getOutputShape(0));
+}
+
 TEST(ShapeInferenceTest, ReshapeAutoDimensionShrink) {
   Graph g;
   ShapeInference si;
index 8e47a7b..d5f413a 100644 (file)
@@ -2,13 +2,14 @@ file(GLOB_RECURSE TESTS "*.cpp")
 
 #Feature detect:
 execute_process(
-        COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/test_data/gen_test.py
-        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/test_data/
+        COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/test_data/gen_test.py ${CMAKE_CURRENT_BINARY_DIR}
+        OUTPUT_VARIABLE outp
+        WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
         RESULT_VARIABLE test_create_failed
 )
 
 if (NNC_FRONTEND_TFLITE_ENABLED AND NOT ${test_create_failed})
-  add_definitions(-DTFLITE_TEST_DIR="${CMAKE_CURRENT_SOURCE_DIR}/test_data/")
+  add_definitions(-DTFLITE_TEST_DIR="${CMAKE_CURRENT_BINARY_DIR}/")
   add_nnc_unit_test(nnc_tflite_frontend_test ${TESTS} ${OPTIONS_SRC})
   if (TARGET nnc_tflite_frontend_test)
     nncc_target_link_libraries(nnc_tflite_frontend_test tflite_import nnc_support nnc_core )
index 30032a2..43a2486 100755 (executable)
@@ -1,6 +1,14 @@
 #!/usr/bin/python3
-import tensorflow as tf
 import numpy as np
+import sys
+try:
+    import tensorflow as tf
+except:
+    print("!! Tensorflow not installed, tflite frontend test not generated", file=sys.stderr)
+    exit(999)
+
+resDir = sys.argv[1]
+if resDir[-1]!="/": resDir +="/"
 
 output_shape = [1, 28, 28, 1]
 strides = [1,1,1,1]
@@ -13,7 +21,6 @@ Y = tf.sin(X)
 out0 = tf.identity(Y, name="out")
 # Filter the input image.
 with tf.Session() as sess:
-    print('Evaluating...')
     out = sess.run(out0, feed_dict = {"input:0": np.ones((1, 28, 28, 1)).astype(np.float32)})
     # print(sess.graph_def)
 
@@ -22,4 +29,4 @@ with tf.Session() as sess:
     tflite_model = tf.contrib.lite.TocoConverter(
         frozen_graphdef, [X], [out0]).convert()
 
-    open("unsupported.tflite", "wb").write(tflite_model)
+    open(resDir+"unsupported.tflite", "wb").write(tflite_model)