[nnc] Add squeeze operation support (#2138)
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@partner.samsung.com>
Fri, 9 Nov 2018 14:12:49 +0000 (17:12 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Fri, 9 Nov 2018 14:12:49 +0000 (17:12 +0300)
This commit adds squeeze operation and ShapeInference support for it

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
19 files changed:
contrib/nnc/core/modelIR/IrDotDumper.cpp
contrib/nnc/core/modelIR/ShapeInference.cpp
contrib/nnc/include/core/modelIR/IrDotDumper.h
contrib/nnc/include/core/modelIR/ShapeInference.h
contrib/nnc/include/core/modelIR/operations/SqueezeOp.h [new file with mode: 0644]
contrib/nnc/include/core/modelIR/operations/operations.lst.h
contrib/nnc/include/passes/acl_soft_backend/AclCppOpGenerator.h
contrib/nnc/include/passes/interpreter/Interpreter.h
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/interpreter/ops/Reshape.h
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.h
contrib/nnc/passes/soft_backend/SBSerializer.cpp
contrib/nnc/passes/soft_backend/SBSerializer.h
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.h
contrib/nnc/unittests/core/ShapeInference.cpp

index 96e7419..8f4ff30 100644 (file)
 #include <iostream>
 #include "core/modelIR/IrDotDumper.h"
 
-#include "core/modelIR/Shape.h"
-#include "core/modelIR/ir_node.h"
-
-#include "core/modelIR/ir_dot_node_info.h"
-
 namespace nnc
 {
 namespace mir
@@ -226,5 +221,17 @@ void mir::IrDotDumper::visit(INode *node, ops::ElementwiseOp &op) {
   dotBuilder.updateWithNode(node, nodeInfo);
 }
 
+void IrDotDumper::visit(INode* node, ops::SqueezeOp& op) {
+  auto node_info = DotIrNodeInfo().withType("SqueezeOp", node->getName())
+    .withInShapes(getInputShapes(op))
+    .withOutShapes(getOutputShapes(op));
+
+  for (auto dim : op.getDimsToSqueeze()) {
+    node_info.withMisc("SqueezeDim", dim);
+  }
+
+  dotBuilder.updateWithNode(node, node_info);
+}
+
 } // namespace mir
 } // namespace nnc
index 5b720a1..605cb32 100644 (file)
@@ -15,6 +15,8 @@
  */
 
 #include <cmath>
+#include <algorithm>
+
 #include "core/modelIR/ShapeInference.h"
 
 #include "core/modelIR/operations/FullyConnectedOp.h"
@@ -35,6 +37,7 @@
 #include "core/modelIR/operations/DropoutOp.h"
 #include "core/modelIR/operations/TanhOp.h"
 #include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/SqueezeOp.h"
 
 namespace nnc
 {
@@ -336,5 +339,53 @@ void ShapeInference::visit(INode::Ref node, ops::ElementwiseOp &op) {
   op.setOutputShape(0, op.getInputShape(0));
 }
 
+void ShapeInference::visit(INode* node, ops::SqueezeOp& op) {
+  fillInputShapes(node, op);
+  assert(op.getNumInputs() == 1);
+
+  const auto& input_shape = op.getInputShape(0);
+  int32_t input_rank = input_shape.rank();
+  Shape output_shape;
+  int32_t output_rank = 0;
+
+  std::vector<int32_t> dims_to_squeeze;
+
+  if (op.getNumSqueezeDims() == 0) {
+    for (int32_t i = 0; i < input_rank; ++i) {
+      if (input_shape.dim(i) == 1) {
+        dims_to_squeeze.push_back(i);
+      }
+    }
+  } else {
+    dims_to_squeeze = op.getDimsToSqueeze();
+    std::sort(dims_to_squeeze.begin(), dims_to_squeeze.end());
+    dims_to_squeeze.erase(
+      std::unique(dims_to_squeeze.begin(), dims_to_squeeze.end()),
+      dims_to_squeeze.end()
+    );
+  }
+
+  if (dims_to_squeeze.size() == static_cast<size_t>(input_rank)) {
+    //Input shape have 1s in all dimensions, output shape is (1,)
+    op.setOutputShape(0, Shape{1});
+    return;
+  }
+
+  size_t squeezing_idx = 0;
+  output_shape.resize(input_rank - dims_to_squeeze.size());
+  for (int32_t i = 0; i < input_rank; ++i) {
+    if (squeezing_idx < dims_to_squeeze.size() && i == dims_to_squeeze[squeezing_idx]) {
+      if (input_shape.dim(i) != 1)
+        throw std::invalid_argument("All squeezed dimensions should have size 1");
+
+      squeezing_idx++;
+    } else {
+      output_shape.dim(output_rank++) = input_shape.dim(i);
+    }
+  }
+
+  op.setOutputShape(0, output_shape);
+}
+
 } // namespace mir
 } // namespace nnc
index e9006b6..7f93ca2 100644 (file)
@@ -37,6 +37,7 @@
 #include "core/modelIR/operations/DropoutOp.h"
 #include "core/modelIR/operations/TanhOp.h"
 #include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/SqueezeOp.h"
 
 #include "core/modelIR/ir_dot_builder.h"
 
@@ -70,6 +71,7 @@ public:
   void visit(INode *node, ops::EluOp &op) override;
   void visit(INode *node, ops::TanhOp &op) override;
   void visit(INode *node, ops::ElementwiseOp &op) override;
+  void visit(INode* node, ops::SqueezeOp& op) override;
 
   void writeDot(std::ostream &os) { dotBuilder.writeDot(os); };
 
index a1cc09b..f1176a4 100644 (file)
@@ -46,6 +46,7 @@ class ShapeInference : public IVisitor {
   void visit(INode::Ref node, ops::ElementwiseOp &op) override;
   void visit(INode::Ref node, ops::DeConv2DOp &op) override;
   void visit(INode::Ref node, ops::EluOp &op) override;
+  void visit(INode* node, ops::SqueezeOp& op) override;
 
 protected:
   void fillInputShapes(INode::Ref node, OpDescription &op);
diff --git a/contrib/nnc/include/core/modelIR/operations/SqueezeOp.h b/contrib/nnc/include/core/modelIR/operations/SqueezeOp.h
new file mode 100644 (file)
index 0000000..7e187b6
--- /dev/null
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_IR_MODEL_SQUEEZE_OP_H_
+#define _NNC_CORE_IR_MODEL_SQUEEZE_OP_H_
+
+#include "core/modelIR/operations/operation.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+class SqueezeOp : public OpDescription {
+public:
+  explicit SqueezeOp(const std::vector<int32_t>& dims_to_squeeze) :
+    OpDescription(1, 1),
+    _dims_to_squeeze(dims_to_squeeze) {}
+
+  int32_t getNumSqueezeDims() {
+    return static_cast<int32_t>(_dims_to_squeeze.size());
+  }
+
+  const std::vector<int32_t>& getDimsToSqueeze() {
+    return _dims_to_squeeze;
+  }
+
+private:
+  std::vector<int32_t> _dims_to_squeeze;
+};
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
+
+#endif // _NNC_CORE_IR_MODEL_SQUEEZE_OP_H_
index 7c57f80..216f246 100644 (file)
@@ -66,6 +66,7 @@ public:
   void visit(mir::INode* node, mir::ops::ElementwiseOp& op) override;
   void visit(mir::INode* node, mir::ops::DeConv2DOp& op) override;
   void visit(mir::INode* node, mir::ops::EluOp& op) override;
+  void visit(mir::INode* node, mir::ops::SqueezeOp& op) override;
 
 private:
   using AF = ArtifactFactory;
index 3d6c68e..36b255e 100644 (file)
@@ -55,6 +55,7 @@ public:
   void visit(INode::Ref node, ops::ElementwiseOp &op) override;
   void visit(INode::Ref node, ops::DeConv2DOp &op) override;
   void visit(INode::Ref node, ops::EluOp &op) override;
+  void visit(INode* node, ops::SqueezeOp& op) override;
 
   void setInput(const std::string &name, const TensorVariant& data);
   std::vector<TensorVariant> &getResult(INode::Ref node);
index 5c34833..2d5ca59 100644 (file)
@@ -652,5 +652,9 @@ void AclCppOpGenerator::serializeTensor(const TensorVariant& tensor) {
     _parOut.write(tensor.at(idx), sizeof(float));
 }
 
+void AclCppOpGenerator::visit(INode* node, ops::SqueezeOp& op) {
+  assert(false && "Unimplemented operation: Squeeze");
+}
+
 }
 // namespace nnc
index 0b40b7e..430c803 100644 (file)
@@ -37,6 +37,7 @@
 #include "core/modelIR/operations/DropoutOp.h"
 #include "core/modelIR/operations/TanhOp.h"
 #include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/SqueezeOp.h"
 
 #include "ops/Bias.h"
 #include "ops/Concat.h"
@@ -118,7 +119,7 @@ void NNInterpreter::visit(INode::Ref node, ops::ReshapeOp &op)
   mapByName(node);
   auto operand = node->getPrevNodes()[0];
   auto input = var(operand.node->getId())[operand.index];
-  var(node->getId()) = Reshape<float>(input, op)();
+  var(node->getId()) = Reshape<float>(input, op.getOutputShape(0))();
 }
 
 void NNInterpreter::visit(INode::Ref node, ops::ReluOp &op)
@@ -275,4 +276,12 @@ void NNInterpreter::visit(INode::Ref node, ops::EluOp &op) {
   })();
 }
 
+void NNInterpreter::visit(INode* node, ops::SqueezeOp& op) {
+  mapByName(node);
+  auto operand = node->getPrevNodes()[0];
+  auto& input = var(operand.node->getId())[operand.index];
+  //Squeeze is just a special case of reshape
+  var(node->getId()) = Reshape<float>(input, op.getOutputShape(0))();
+}
+
 } // namespace nnc
index 574aad3..5bd0b8e 100644 (file)
 #include "OperationImpl.h"
 #include "Fill.h"
 
-namespace nnc
-{
+namespace nnc {
 
-template <typename T> class Reshape : public OperationImpl<T>
-{
+template <typename T>
+class Reshape : public OperationImpl<T> {
 public:
-  Reshape(const mir::TensorVariant &input, const mir::ops::ReshapeOp &op) : _input(input), _op(op)
-  {
-    assert(_op.getInputShape(0).numElements() == _op.getOutputShape(0).numElements());
-  }
-
-  std::vector<mir::TensorVariant> operator()() override
-  {
-    const mir::Shape &outShape = _op.getOutputShape(0);
-    const mir::Shape &inShape = _op.getInputShape(0);
+  Reshape(const mir::TensorVariant& input, const mir::Shape& output_shape)
+    : _input(input),
+    _output_shape(output_shape) {
 
-    mir::ShapeRange inRange(inShape);
-    mir::ShapeRange outRange(outShape);
+    assert(input.getShape().numElements() == _output_shape.numElements());
+  }
 
+  std::vector<mir::TensorVariant> operator()() override {
+    mir::ShapeRange inRange(_input.getShape());
     auto inIter = inRange.begin();
 
-    auto out = OperationImpl<T>::allocate_tensor(outShape);
+    auto out = OperationImpl<T>::allocate_tensor(_output_shape);
 
     mir::Tensor<float> outAccessor(out);
     // Shapes element count compared in Reshape ctor
-    return Fill<T>(outShape, [this, &inIter](const mir::Index &) -> float { return _input.at(*inIter++); })();
+    return Fill<T>(_output_shape, [this, &inIter](const mir::Index&) -> T { return _input.at(*inIter++); })();
   }
 
 private:
   mir::Tensor<T> _input;
-  const mir::ops::ReshapeOp &_op;
+  const mir::Shape& _output_shape;
 };
 
 } // namespace nnc
index 6fe7c52..0f6bc34 100644 (file)
@@ -218,4 +218,8 @@ void ModelAnalyzer::visit(mir::INode *node, mir::ops::EluOp &op) {
   addOpDescr(node, "EluOp");
 }
 
+void ModelAnalyzer::visit(INode* node, ops::SqueezeOp& op) {
+  addOpDescr(node, "reshape");
+}
+
 } // namespace nnc
index e6eb983..00d71a6 100644 (file)
@@ -99,6 +99,7 @@ public:
   void visit(mir::INode *node, mir::ops::ElementwiseOp &op) override;
   void visit(mir::INode *node, mir::ops::DeConv2DOp &op) override;
   void visit(mir::INode *node, mir::ops::EluOp &op) override;
+  void visit(mir::INode* node, mir::ops::SqueezeOp& op) override;
 
   /**
    * @return vector of id's of network input tensors
index 943cde5..e2788b9 100644 (file)
@@ -37,6 +37,7 @@
 #include "core/modelIR/operations/DropoutOp.h"
 #include "core/modelIR/operations/TanhOp.h"
 #include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/SqueezeOp.h"
 #include "core/modelIR/ir_node.h"
 
 #include "pass/PassException.h"
@@ -327,4 +328,9 @@ void Serializer::visit(mir::INode *node, mir::ops::DeConv2DOp &op) {
   serializeShape(op.getOutputShape(0));
 }
 
+void Serializer::visit(INode* node, ops::SqueezeOp& op) {
+  _curOp->_paramStartOffset = _buffer.size();
+  serializeShape(op.getOutputShape(0));
+}
+
 } // namespace nnc
index 3062ea3..7773b6f 100644 (file)
@@ -60,6 +60,7 @@ public:
   void visit(mir::INode *node, mir::ops::ElementwiseOp &op) override;
   void visit(mir::INode *node, mir::ops::DeConv2DOp &op) override;
   void visit(mir::INode *node, mir::ops::EluOp &op) override;
+  void visit(mir::INode* node, mir::ops::SqueezeOp& op) override;
 
   void serialize(std::list<OpDescr> &inferenceSequence);
 
index f82a06b..05af8d6 100644 (file)
@@ -92,6 +92,7 @@ void TfliteImporter::processUnsupportedOp(const Operator* op) {
       break;
     case BuiltinOperator_SOFTMAX:
     case BuiltinOperator_RESHAPE:
+    case BuiltinOperator_SQUEEZE:
       // No checks
       break;
     default:
@@ -181,6 +182,9 @@ void TfliteImporter::walkOperator(const Operator* op) {
     case BuiltinOperator_SOFTMAX:
       outputs = _opCreator->createSoftmax(inputs, params, op->builtin_options_as<SoftmaxOptions>());
       break;
+    case BuiltinOperator_SQUEEZE:
+      outputs = _opCreator->createSqueeze(inputs, params, op->builtin_options_as<SqueezeOptions>());
+      break;
     default:
       assert(false && "All unsupported types should have been found before this pass.");
   }
index 99b78f6..7f90dc8 100644 (file)
@@ -27,6 +27,7 @@
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/ReshapeOp.h"
+#include "core/modelIR/operations/SqueezeOp.h"
 #include "pass/PassException.h"
 
 using namespace nnc::mir;
@@ -183,4 +184,12 @@ void TFLiteOpCreator::connectInputs(INode::Ref op, std::vector<INode::Ref>& inpu
     op->connectInputTo(i, inputs[i]->getOutput(0));
 }
 
+std::vector<INode*> TFLiteOpCreator::createSqueeze(InputOps inputs, InputParams params,
+                                                   const ::tflite::SqueezeOptions* opts) {
+
+  std::vector<int32_t> squeeze_dims{opts->squeeze_dims()->begin(), opts->squeeze_dims()->end()};
+
+  return createOp<ops::SqueezeOp>(inputs, ActivationFunctionType_NONE, squeeze_dims);
+}
+
 } // namespace nnc
index a399a17..502ec53 100644 (file)
@@ -68,6 +68,8 @@ public:
   std::vector<INode::Ref> convertFullyConnected(InputOps, InputParams,
                                                 const ::tflite::FullyConnectedOptions*);
 
+  std::vector<INode*> createSqueeze(InputOps& inputs, InputParams& params, const ::tflite::SqueezeOptions* opts);
+
   void checkPool2D(const ::tflite::Pool2DOptions*, std::set<std::string>&);
 
   void checkConcatenation(const ::tflite::ConcatenationOptions*, std::set<std::string>&);
index 3f06e2a..6497990 100644 (file)
@@ -17,6 +17,7 @@
 #include "core/modelIR/graph.h"
 #include "core/modelIR/ShapeInference.h"
 #include "core/modelIR/operations/ReshapeOp.h"
+#include "core/modelIR/operations/SqueezeOp.h"
 #include "core/modelIR/Shape.h"
 
 #include "gtest/gtest.h"
@@ -68,3 +69,59 @@ TEST(ShapeInferenceTest, ReshapeAutoDimensionVaryRank) {
   si.visit(n, *static_cast<ops::ReshapeOp*>(n->getOperation()));
   ASSERT_EQ(resultShapeExpand, n->getOperation()->getOutputShape(0));
 }
+
+TEST(ShapeInferenceTest, SqueezeTestAllDims) {
+  Graph g;
+  ShapeInference si;
+
+  Shape input_shape{1, 2, 1, 4};
+  Shape expected_shape{2, 4};
+
+  auto input = g.create<ops::VariableOp>("input");
+  input->getOperation()->setOutputShape(0, input_shape);
+
+  auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", std::vector<int32_t>{});
+  sq1->connectInputTo(0, input->getOutput(0));
+
+  g.accept(&si);
+
+  ASSERT_EQ(sq1->getOperation()->getOutputShape(0), expected_shape);
+}
+
+TEST(ShapeInferenceTest, SqueezeTestSpecificDims) {
+  Graph g;
+  ShapeInference si;
+
+  Shape input_shape{1, 2, 1, 4};
+  Shape expected_shape{1, 2, 4};
+
+  auto input = g.create<ops::VariableOp>("input");
+  input->getOperation()->setOutputShape(0, input_shape);
+
+
+  auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", std::vector<int32_t>{2});
+  sq1->connectInputTo(0, input->getOutput(0));
+
+  g.accept(&si);
+
+  ASSERT_EQ(sq1->getOperation()->getOutputShape(0), expected_shape);
+}
+
+TEST(ShapeInferenceTest, SqueezeTestScalarResult) {
+  Graph g;
+  ShapeInference si;
+
+  Shape input_shape{1, 1, 1, 1};
+  Shape expected_shape{1};
+
+  auto input = g.create<ops::VariableOp>("input");
+  input->getOperation()->setOutputShape(0, input_shape);
+
+
+  auto sq1 = g.create<ops::SqueezeOp>("squeeze_1", std::vector<int32_t>{});
+  sq1->connectInputTo(0, input->getOutput(0));
+
+  g.accept(&si);
+
+  ASSERT_EQ(sq1->getOperation()->getOutputShape(0), expected_shape);
+}