[nnc] Introduce Gather operation (#2626)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Wed, 12 Dec 2018 10:00:58 +0000 (13:00 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Wed, 12 Dec 2018 10:00:58 +0000 (13:00 +0300)
Add GatherOp class to modelIR.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
15 files changed:
contrib/nnc/core/CMakeLists.txt
contrib/nnc/core/modelIR/IrDotDumper.cpp
contrib/nnc/core/modelIR/Operation.cpp
contrib/nnc/core/modelIR/operations/GatherOp.cpp [new file with mode: 0644]
contrib/nnc/include/core/modelIR/IrDotDumper.h
contrib/nnc/include/core/modelIR/operations/GatherOp.h [new file with mode: 0644]
contrib/nnc/include/core/modelIR/operations/operations.lst.h
contrib/nnc/include/passes/interpreter/Interpreter.h
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.h
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.h
contrib/nnc/passes/soft_backend/SBSerializer.cpp
contrib/nnc/passes/soft_backend/SBSerializer.h

index cc4f450..52aba7f 100644 (file)
@@ -3,6 +3,7 @@ set(SOURCES "modelIR/operations/ConcatOp.cpp"
             "modelIR/operations/DeConv2DOp.cpp"
             "modelIR/operations/DepthwiseConv2DOp.cpp"
             "modelIR/operations/FullyConnectedOp.cpp"
+            "modelIR/operations/GatherOp.cpp"
             "modelIR/operations/PadOp.cpp"
             "modelIR/operations/PoolOp.cpp"
             "modelIR/operations/SqueezeOp.cpp"
index b43961c..c0ad967 100644 (file)
 #include <iostream>
 
 #include "core/modelIR/IrDotDumper.h"
+#include "core/modelIR/operations/BatchNormOp.h"
+#include "core/modelIR/operations/BiasAddOp.h"
+#include "core/modelIR/operations/CappedReluOp.h"
+#include "core/modelIR/operations/ConcatOp.h"
+#include "core/modelIR/operations/ConstantOp.h"
+#include "core/modelIR/operations/Conv2DOp.h"
+#include "core/modelIR/operations/Deconv2DOp.h"
+#include "core/modelIR/operations/DepthwiseConv2DOp.h"
+#include "core/modelIR/operations/DropoutOp.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/EluOp.h"
+#include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GatherOp.h"
+#include "core/modelIR/operations/PadOp.h"
+#include "core/modelIR/operations/PoolOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
+#include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/ReshapeOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
+#include "core/modelIR/operations/ScaleOp.h"
+#include "core/modelIR/operations/SoftmaxOp.h"
+#include "core/modelIR/operations/SqueezeOp.h"
+#include "core/modelIR/operations/TanhOp.h"
+#include "core/modelIR/operations/TransposeOp.h"
+#include "core/modelIR/operations/VariableOp.h"
 
 namespace nnc {
 namespace mir {
@@ -264,6 +289,10 @@ void IrDotDumper::visit(ops::TransposeOp& op) {
   dotBuilder.updateWithOp(&op, node_info);
 }
 
+void IrDotDumper::visit(ops::GatherOp& op) {
+  auto node_info = DotIrNodeInfo().withType("GatherOp", op.getName());
+}
+
 } // namespace mir
 } // namespace nnc
 
index 4952d32..2005ea6 100644 (file)
  */
 
 #include "core/modelIR/Operation.h"
-#include "core/modelIR/operations/FullyConnectedOp.h"
-#include "core/modelIR/operations/SoftmaxOp.h"
+#include "core/modelIR/operations/BatchNormOp.h"
+#include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/CappedReluOp.h"
-#include "core/modelIR/operations/DepthwiseConv2DOp.h"
+#include "core/modelIR/operations/ConcatOp.h"
 #include "core/modelIR/operations/ConstantOp.h"
 #include "core/modelIR/operations/Conv2DOp.h"
 #include "core/modelIR/operations/Deconv2DOp.h"
+#include "core/modelIR/operations/DepthwiseConv2DOp.h"
+#include "core/modelIR/operations/DropoutOp.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/EluOp.h"
+#include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GatherOp.h"
+#include "core/modelIR/operations/PadOp.h"
 #include "core/modelIR/operations/PoolOp.h"
-#include "core/modelIR/operations/VariableOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
 #include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/ReshapeOp.h"
 #include "core/modelIR/operations/ResizeOp.h"
-#include "core/modelIR/operations/EluOp.h"
-#include "core/modelIR/operations/ConcatOp.h"
-#include "core/modelIR/operations/BiasAddOp.h"
-#include "core/modelIR/operations/BatchNormOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
-#include "core/modelIR/operations/DropoutOp.h"
-#include "core/modelIR/operations/TanhOp.h"
-#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
-#include "core/modelIR/operations/ReshapeOp.h"
-#include "core/modelIR/operations/PadOp.h"
-#include "core/modelIR/operations/ReduceFOp.h"
+#include "core/modelIR/operations/TanhOp.h"
 #include "core/modelIR/operations/TransposeOp.h"
+#include "core/modelIR/operations/VariableOp.h"
 
 namespace nnc {
 namespace mir {
diff --git a/contrib/nnc/core/modelIR/operations/GatherOp.cpp b/contrib/nnc/core/modelIR/operations/GatherOp.cpp
new file mode 100644 (file)
index 0000000..e0d1bae
--- /dev/null
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/operations/GatherOp.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+void GatherOp::inferOutputShapes() {
+  const auto& data_shape = getInputShape(0);
+  const auto& indices_shape = getInputShape(1);
+
+  auto data_rank = data_shape.rank();
+  auto indices_rank = indices_shape.rank();
+  auto output_rank = data_rank + indices_rank - 1;
+
+  assert(_axis >= -data_rank && _axis < data_rank);
+  int32_t axis = _axis < 0 ? _axis + data_rank : _axis;
+
+  Shape output_shape;
+  output_shape.resize(output_rank);
+
+  // Output shape is data.shape[:axis] + indices.shape + data.shape[axis + 1:].
+  int32_t output_index = 0;
+  for (int32_t i = 0; i < axis; ++i)
+    output_shape.dim(output_index++) = data_shape.dim(i);
+  for (int32_t i = 0; i < indices_rank; ++i)
+    output_shape.dim(output_index++) = indices_shape.dim(i);
+  for (int32_t i = axis + 1; i < data_rank; ++i)
+    output_shape.dim(output_index++) = data_shape.dim(i);
+
+  setOutputShape(0, output_shape);
+}
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
index cc54ab3..f5a907a 100644 (file)
 #define _NNC_BACKEND_INTERPRETER_CORE_DOTDUMPER_
 
 #include "core/modelIR/Visitor.h"
-#include "core/modelIR/operations/FullyConnectedOp.h"
-#include "core/modelIR/operations/SoftmaxOp.h"
-#include "core/modelIR/operations/CappedReluOp.h"
-#include "core/modelIR/operations/ConstantOp.h"
-#include "core/modelIR/operations/DepthwiseConv2DOp.h"
-#include "core/modelIR/operations/Conv2DOp.h"
-#include "core/modelIR/operations/Deconv2DOp.h"
-#include "core/modelIR/operations/PoolOp.h"
-#include "core/modelIR/operations/VariableOp.h"
-#include "core/modelIR/operations/ReluOp.h"
-#include "core/modelIR/operations/EluOp.h"
-#include "core/modelIR/operations/ConcatOp.h"
-#include "core/modelIR/operations/BiasAddOp.h"
-#include "core/modelIR/operations/ReshapeOp.h"
-#include "core/modelIR/operations/ResizeOp.h"
-#include "core/modelIR/operations/BatchNormOp.h"
-#include "core/modelIR/operations/ScaleOp.h"
-#include "core/modelIR/operations/DropoutOp.h"
-#include "core/modelIR/operations/TanhOp.h"
-#include "core/modelIR/operations/ElementwiseOp.h"
-#include "core/modelIR/operations/SqueezeOp.h"
-#include "core/modelIR/operations/PadOp.h"
-#include "core/modelIR/operations/ReduceFOp.h"
-#include "core/modelIR/operations/TransposeOp.h"
 
 #include "core/modelIR/ir_dot_builder.h"
 
@@ -56,30 +32,32 @@ namespace mir
  */
 class IrDotDumper : public IVisitor {
 public:
+  void visit(ops::BatchNormOp& op) override;
+  void visit(ops::BiasAddOp& op) override;
+  void visit(ops::CappedReluOp& op) override;
   void visit(ops::ConcatOp& op) override;
   void visit(ops::ConstantOp& op) override;
-  void visit(ops::ReluOp& op) override;
   void visit(ops::Conv2DOp& op) override;
+  void visit(ops::DeConv2DOp& op) override;
   void visit(ops::DepthwiseConv2DOp& op) override;
-  void visit(ops::SoftmaxOp& op) override;
-  void visit(ops::PoolOp& op) override;
+  void visit(ops::DropoutOp& op) override;
+  void visit(ops::ElementwiseOp& op) override;
+  void visit(ops::EluOp& op) override;
   void visit(ops::FullyConnectedOp& op) override;
-  void visit(ops::CappedReluOp& op) override;
-  void visit(ops::BiasAddOp& op) override;
-  void visit(ops::VariableOp& op) override;
+  void visit(ops::GatherOp& op) override;
+  void visit(ops::PadOp& op) override;
+  void visit(ops::PoolOp& op) override;
+  void visit(ops::ReduceFOp& op) override;
+  void visit(ops::ReluOp& op) override;
   void visit(ops::ReshapeOp& op) override;
   void visit(ops::ResizeOp& op) override;
   void visit(ops::ScaleOp& op) override;
-  void visit(ops::BatchNormOp& op) override;
-  void visit(ops::DropoutOp& op) override;
-  void visit(ops::DeConv2DOp& op) override;
-  void visit(ops::EluOp& op) override;
-  void visit(ops::TanhOp& op) override;
-  void visit(ops::ElementwiseOp& op) override;
+  void visit(ops::SoftmaxOp& op) override;
   void visit(ops::SqueezeOp& op) override;
-  void visit(ops::PadOp& op) override;
-  void visit(ops::ReduceFOp& op) override;
+  void visit(ops::TanhOp& op) override;
   void visit(ops::TransposeOp& op) override;
+  void visit(ops::VariableOp& op) override;
+
 
   void writeDot(std::ostream &os) { dotBuilder.writeDot(os); };
 
diff --git a/contrib/nnc/include/core/modelIR/operations/GatherOp.h b/contrib/nnc/include/core/modelIR/operations/GatherOp.h
new file mode 100644 (file)
index 0000000..3c9bbc4
--- /dev/null
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_IR_MODEL_GATHER_H_
+#define _NNC_CORE_IR_MODEL_GATHER_H_
+
+#include "core/modelIR/Operation.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+/**
+ * @brief Gather operation as defined by ONNX spec.
+ * https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gather
+ * https://www.tensorflow.org/api_docs/python/tf/gather
+ */
+class GatherOp : public Operation {
+public:
+  GatherOp(const IODescriptor& data, const IODescriptor& indices, int32_t axis)
+      : Operation(Type::gather, {data, indices}), _axis(axis) {
+    inferOutputShapes();
+  }
+
+private:
+  void inferOutputShapes();
+
+  int32_t _axis;
+};
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
+
+#endif //_NNC_CORE_IR_MODEL_GATHER_H_
index 040070b..c4ae1d2 100644 (file)
@@ -21,6 +21,7 @@
 HANDLE_OP(concat, ConcatOp)
 HANDLE_OP(conv2D, Conv2DOp)
 HANDLE_OP(depthwiseConv, DepthwiseConv2DOp)
+HANDLE_OP(gather, GatherOp)
 HANDLE_OP(softmax, SoftmaxOp)
 HANDLE_OP(pool, PoolOp)
 HANDLE_OP(fullyConnected, FullyConnectedOp)
index ce4d938..f1ddba3 100644 (file)
@@ -36,30 +36,31 @@ class NNInterpreter : public IVisitor {
 public:
   explicit NNInterpreter() = default;
 
+  void visit(ops::BatchNormOp& op) override;
+  void visit(ops::BiasAddOp& op) override;
+  void visit(ops::CappedReluOp& op) override;
   void visit(ops::ConcatOp& op) override;
   void visit(ops::ConstantOp& op) override;
   void visit(ops::Conv2DOp& op) override;
+  void visit(ops::DeConv2DOp& op) override;
   void visit(ops::DepthwiseConv2DOp& op) override;
-  void visit(ops::ReluOp& op) override;
-  void visit(ops::SoftmaxOp& op) override;
-  void visit(ops::PoolOp& op) override;
-  void visit(ops::FullyConnectedOp& op) override;
-  void visit(ops::CappedReluOp& op) override;
-  void visit(ops::BiasAddOp& op) override;
-  void visit(ops::VariableOp& op) override;
-  void visit(ops::ReshapeOp& op) override;
-  void visit(ops::ResizeOp& op) override;
-  void visit(ops::ScaleOp& op) override;
-  void visit(ops::BatchNormOp& op) override;
   void visit(ops::DropoutOp& op) override;
-  void visit(ops::TanhOp& op) override;
   void visit(ops::ElementwiseOp& op) override;
-  void visit(ops::DeConv2DOp& op) override;
   void visit(ops::EluOp& op) override;
-  void visit(ops::SqueezeOp& op) override;
+  void visit(ops::FullyConnectedOp& op) override;
+  void visit(ops::GatherOp& op) override;
   void visit(ops::PadOp& op) override;
+  void visit(ops::PoolOp& op) override;
   void visit(ops::ReduceFOp& op) override;
+  void visit(ops::ReluOp& op) override;
+  void visit(ops::ReshapeOp& op) override;
+  void visit(ops::ResizeOp& op) override;
+  void visit(ops::ScaleOp& op) override;
+  void visit(ops::SoftmaxOp& op) override;
+  void visit(ops::SqueezeOp& op) override;
+  void visit(ops::TanhOp& op) override;
   void visit(ops::TransposeOp& op) override;
+  void visit(ops::VariableOp& op) override;
 
   void setInput(const std::string &name, const TensorVariant& data);
   std::vector<TensorVariant> &getResult(Operation* op);
index 1b32ab2..5516041 100644 (file)
@@ -6,26 +6,26 @@
 #include "core/modelIR/Tensor.h"
 
 #include "core/modelIR/Operation.h"
-#include "core/modelIR/operations/VariableOp.h"
-#include "core/modelIR/operations/SoftmaxOp.h"
-#include "core/modelIR/operations/Conv2DOp.h"
-#include "core/modelIR/operations/ConstantOp.h"
-#include "core/modelIR/operations/ScaleOp.h"
 #include "core/modelIR/operations/BatchNormOp.h"
-#include "core/modelIR/operations/DropoutOp.h"
-#include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/CappedReluOp.h"
-#include "core/modelIR/operations/TanhOp.h"
-#include "core/modelIR/operations/ReshapeOp.h"
-#include "core/modelIR/operations/ResizeOp.h"
+#include "core/modelIR/operations/ConcatOp.h"
+#include "core/modelIR/operations/ConstantOp.h"
+#include "core/modelIR/operations/Conv2DOp.h"
+#include "core/modelIR/operations/Deconv2DOp.h"
 #include "core/modelIR/operations/DepthwiseConv2DOp.h"
+#include "core/modelIR/operations/DropoutOp.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
-#include "core/modelIR/operations/ConcatOp.h"
 #include "core/modelIR/operations/PoolOp.h"
-#include "core/modelIR/operations/BiasAddOp.h"
-#include "core/modelIR/operations/ElementwiseOp.h"
-#include "core/modelIR/operations/Deconv2DOp.h"
 #include "core/modelIR/operations/ReduceFOp.h"
+#include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/ReshapeOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
+#include "core/modelIR/operations/ScaleOp.h"
+#include "core/modelIR/operations/SoftmaxOp.h"
+#include "core/modelIR/operations/TanhOp.h"
+#include "core/modelIR/operations/VariableOp.h"
 
 #include <algorithm>
 
@@ -883,5 +883,9 @@ void AclCppOpGenerator::visit(mir::ops::TransposeOp& op) {
   assert(false && "Unimplemented operation: TransposeOp");
 }
 
+void AclCppOpGenerator::visit(mir::ops::GatherOp& op) {
+  assert(false && "Unimplemented operation: GatherOp");
+}
+
 }
 // namespace nnc
index 077049b..b6108fc 100644 (file)
@@ -47,30 +47,31 @@ public:
    * @brief Implementations of the IVisitor visitors.
    * @param op
    */
+  void visit(mir::ops::BatchNormOp& op) override;
+  void visit(mir::ops::BiasAddOp& op) override;
+  void visit(mir::ops::CappedReluOp& op) override;
   void visit(mir::ops::ConcatOp& op) override;
   void visit(mir::ops::ConstantOp& op) override;
   void visit(mir::ops::Conv2DOp& op) override;
+  void visit(mir::ops::DeConv2DOp& op) override;
   void visit(mir::ops::DepthwiseConv2DOp& op) override;
-  void visit(mir::ops::SoftmaxOp& op) override;
-  void visit(mir::ops::PoolOp& op) override;
+  void visit(mir::ops::DropoutOp& op) override;
+  void visit(mir::ops::ElementwiseOp& op) override;
+  void visit(mir::ops::EluOp& op) override;
   void visit(mir::ops::FullyConnectedOp& op) override;
-  void visit(mir::ops::CappedReluOp& op) override;
-  void visit(mir::ops::BiasAddOp& op) override;
-  void visit(mir::ops::VariableOp& op) override;
+  void visit(mir::ops::GatherOp& op) override;
+  void visit(mir::ops::PadOp& op) override;
+  void visit(mir::ops::PoolOp& op) override;
+  void visit(mir::ops::ReduceFOp& op) override;
   void visit(mir::ops::ReluOp& op) override;
   void visit(mir::ops::ReshapeOp& op) override;
   void visit(mir::ops::ResizeOp& op) override;
   void visit(mir::ops::ScaleOp& op) override;
-  void visit(mir::ops::BatchNormOp& op) override;
-  void visit(mir::ops::DropoutOp& op) override;
-  void visit(mir::ops::TanhOp& op) override;
-  void visit(mir::ops::ElementwiseOp& op) override;
-  void visit(mir::ops::DeConv2DOp& op) override;
-  void visit(mir::ops::EluOp& op) override;
+  void visit(mir::ops::SoftmaxOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
-  void visit(mir::ops::PadOp& op) override;
-  void visit(mir::ops::ReduceFOp& op) override;
+  void visit(mir::ops::TanhOp& op) override;
   void visit(mir::ops::TransposeOp& op) override;
+  void visit(mir::ops::VariableOp& op) override;
 
 private:
   using AF = ArtifactFactory;
index a908b39..76cd84a 100644 (file)
@@ -347,4 +347,8 @@ void NNInterpreter::visit(ops::TransposeOp& op) {
   var(op.getId()) = Transpose(input, op)();
 }
 
+void NNInterpreter::visit(ops::GatherOp& op) {
+  assert(false && "Not yet imlemented");
+}
+
 } // namespace nnc
index 8cbbcb5..3fe75d3 100644 (file)
 #include "core/modelIR/ShapeRange.h"
 
 #include "core/modelIR/Graph.h"
+#include "core/modelIR/operations/BatchNormOp.h"
+#include "core/modelIR/operations/BiasAddOp.h"
+#include "core/modelIR/operations/CappedReluOp.h"
 #include "core/modelIR/operations/ConcatOp.h"
 #include "core/modelIR/operations/ConstantOp.h"
 #include "core/modelIR/operations/Conv2DOp.h"
 #include "core/modelIR/operations/Deconv2DOp.h"
 #include "core/modelIR/operations/DepthwiseConv2DOp.h"
-#include "core/modelIR/operations/SoftmaxOp.h"
-#include "core/modelIR/operations/PoolOp.h"
+#include "core/modelIR/operations/DropoutOp.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
-#include "core/modelIR/operations/CappedReluOp.h"
-#include "core/modelIR/operations/BiasAddOp.h"
+#include "core/modelIR/operations/GatherOp.h"
+#include "core/modelIR/operations/PadOp.h"
+#include "core/modelIR/operations/PoolOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
 #include "core/modelIR/operations/ReluOp.h"
-#include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/ReshapeOp.h"
-#include "core/modelIR/operations/BatchNormOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
-#include "core/modelIR/operations/DropoutOp.h"
-#include "core/modelIR/operations/TanhOp.h"
-#include "core/modelIR/operations/ElementwiseOp.h"
-#include "core/modelIR/operations/VariableOp.h"
+#include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
-#include "core/modelIR/operations/PadOp.h"
-#include "core/modelIR/operations/ReduceFOp.h"
+#include "core/modelIR/operations/TanhOp.h"
 #include "core/modelIR/operations/TransposeOp.h"
+#include "core/modelIR/operations/VariableOp.h"
 
 using namespace std;
 
@@ -303,4 +304,8 @@ void ModelAnalyzer::visit(mir::ops::TransposeOp& op) {
   addOpDescr(&op, "transpose");
 }
 
+void ModelAnalyzer::visit(mir::ops::GatherOp& op) {
+  addOpDescr(&op, "gather");
+}
+
 } // namespace nnc
index 0b16711..a9ce551 100644 (file)
@@ -89,30 +89,31 @@ public:
  */
   void analyze(const mir::Graph* g);
 
+  void visit(mir::ops::BatchNormOp& op) override;
+  void visit(mir::ops::BiasAddOp& op) override;
+  void visit(mir::ops::CappedReluOp& op) override;
   void visit(mir::ops::ConcatOp& op) override;
   void visit(mir::ops::ConstantOp& op) override;
   void visit(mir::ops::Conv2DOp& op) override;
+  void visit(mir::ops::DeConv2DOp& op) override;
   void visit(mir::ops::DepthwiseConv2DOp& op) override;
-  void visit(mir::ops::SoftmaxOp& op) override;
-  void visit(mir::ops::PoolOp& op) override;
+  void visit(mir::ops::DropoutOp& op) override;
+  void visit(mir::ops::ElementwiseOp& op) override;
+  void visit(mir::ops::EluOp& op) override;
   void visit(mir::ops::FullyConnectedOp& op) override;
-  void visit(mir::ops::CappedReluOp& op) override;
-  void visit(mir::ops::BiasAddOp& op) override;
-  void visit(mir::ops::VariableOp& op) override;
+  void visit(mir::ops::GatherOp& op) override;
+  void visit(mir::ops::PadOp& op) override;
+  void visit(mir::ops::PoolOp& op) override;
+  void visit(mir::ops::ReduceFOp& op) override;
   void visit(mir::ops::ReluOp& op) override;
   void visit(mir::ops::ReshapeOp& op) override;
   void visit(mir::ops::ResizeOp& op) override;
   void visit(mir::ops::ScaleOp& op) override;
-  void visit(mir::ops::BatchNormOp& op) override;
-  void visit(mir::ops::DropoutOp& op) override;
-  void visit(mir::ops::TanhOp& op) override;
-  void visit(mir::ops::ElementwiseOp& op) override;
-  void visit(mir::ops::DeConv2DOp& op) override;
-  void visit(mir::ops::EluOp& op) override;
+  void visit(mir::ops::SoftmaxOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
-  void visit(mir::ops::PadOp& op) override;
-  void visit(mir::ops::ReduceFOp& op) override;
+  void visit(mir::ops::TanhOp& op) override;
   void visit(mir::ops::TransposeOp& op) override;
+  void visit(mir::ops::VariableOp& op) override;
 
   /**
    * @return vector of id's of network input tensors
index fe7b9e4..c2241ff 100644 (file)
@@ -363,4 +363,9 @@ void Serializer::visit(mir::ops::TransposeOp& op) {
   serializeShape(op.getOutputShape(0));
 }
 
+void Serializer::visit(mir::ops::GatherOp& op) {
+  _curOp->_paramStartOffset = _buffer.size();
+  assert(false && "Not yet implemented");
+}
+
 } // namespace nnc
index dcac3cb..748db4e 100644 (file)
@@ -41,30 +41,31 @@ namespace nnc
 class Serializer: public mir::IVisitor {
 public:
 
+  void visit(mir::ops::BatchNormOp& op) override;
+  void visit(mir::ops::BiasAddOp& op) override;
+  void visit(mir::ops::CappedReluOp& op) override;
   void visit(mir::ops::ConcatOp& op) override;
   void visit(mir::ops::ConstantOp& op) override;
   void visit(mir::ops::Conv2DOp& op) override;
+  void visit(mir::ops::DeConv2DOp& op) override;
   void visit(mir::ops::DepthwiseConv2DOp& op) override;
-  void visit(mir::ops::SoftmaxOp& op) override;
-  void visit(mir::ops::PoolOp& op) override;
+  void visit(mir::ops::DropoutOp& op) override;
+  void visit(mir::ops::ElementwiseOp& op) override;
+  void visit(mir::ops::EluOp& op) override;
   void visit(mir::ops::FullyConnectedOp& op) override;
-  void visit(mir::ops::CappedReluOp& op) override;
-  void visit(mir::ops::BiasAddOp& op) override;
-  void visit(mir::ops::VariableOp& op) override;
+  void visit(mir::ops::GatherOp& op) override;
+  void visit(mir::ops::PadOp& op) override;
+  void visit(mir::ops::PoolOp& op) override;
+  void visit(mir::ops::ReduceFOp& op) override;
   void visit(mir::ops::ReluOp& op) override;
   void visit(mir::ops::ReshapeOp& op) override;
   void visit(mir::ops::ResizeOp& op) override;
   void visit(mir::ops::ScaleOp& op) override;
-  void visit(mir::ops::BatchNormOp& op) override;
-  void visit(mir::ops::DropoutOp& op) override;
-  void visit(mir::ops::TanhOp& op) override;
-  void visit(mir::ops::ElementwiseOp& op) override;
-  void visit(mir::ops::DeConv2DOp& op) override;
-  void visit(mir::ops::EluOp& op) override;
+  void visit(mir::ops::SoftmaxOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
-  void visit(mir::ops::PadOp& op) override;
-  void visit(mir::ops::ReduceFOp& op) override;
+  void visit(mir::ops::TanhOp& op) override;
   void visit(mir::ops::TransposeOp& op) override;
+  void visit(mir::ops::VariableOp& op) override;
 
   void serialize(std::list<OpDescr> &inferenceSequence);