[nnc] Support for Gather operation in soft backend (#2663)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Thu, 13 Dec 2018 16:20:54 +0000 (19:20 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Thu, 13 Dec 2018 16:20:54 +0000 (19:20 +0300)
Add support for Gather operation to soft backend.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
contrib/nnc/include/core/modelIR/operations/GatherOp.h
contrib/nnc/passes/soft_backend/CPPGenerator.cpp
contrib/nnc/passes/soft_backend/SBSerializer.cpp
contrib/nnc/passes/soft_backend/code_snippets/cpp_common_funcs.def
contrib/nnc/passes/soft_backend/code_snippets/cpp_gather.def [new file with mode: 0644]
contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def
contrib/nnc/unittests/soft_backend/CPPOperations.cpp

index 3c9bbc4..5e2b32c 100644 (file)
@@ -35,6 +35,8 @@ public:
     inferOutputShapes();
   }
 
+  int32_t getAxis() const { return _axis; }
+
 private:
   void inferOutputShapes();
 
index 1e2928c..ec900bc 100644 (file)
@@ -47,6 +47,7 @@ using namespace std;
 #include "cpp_elementwise.generated.h"
 #include "cpp_pad.generated.h"
 #include "cpp_transpose.generated.h"
+#include "cpp_gather.generated.h"
 
 namespace nnc
 {
@@ -291,6 +292,7 @@ void CPPCodeGenerator::materializeCode(ostream &out, const ModelAnalyzer &ma, co
   out.write(cpp_pad, sizeof(cpp_pad));
   out.write(cpp_conv_transpose, sizeof(cpp_conv_transpose));
   out.write(cpp_transpose, sizeof(cpp_transpose));
+  out.write(cpp_gather, sizeof(cpp_gather));
   out.write(cpp_operations, sizeof(cpp_operations));
   out.write(cpp_scale, sizeof(cpp_scale));
   out.write(cpp_dropout, sizeof(cpp_dropout));
index e287105..c4449a8 100644 (file)
 
 #include "CommonData.def"
 
+#include "core/modelIR/operations/BatchNormOp.h"
+#include "core/modelIR/operations/BiasAddOp.h"
+#include "core/modelIR/operations/CappedReluOp.h"
 #include "core/modelIR/operations/ConcatOp.h"
 #include "core/modelIR/operations/ConstantOp.h"
 #include "core/modelIR/operations/Conv2DOp.h"
 #include "core/modelIR/operations/Deconv2DOp.h"
 #include "core/modelIR/operations/DepthwiseConv2DOp.h"
-#include "core/modelIR/operations/SoftmaxOp.h"
-#include "core/modelIR/operations/PoolOp.h"
+#include "core/modelIR/operations/DropoutOp.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GatherOp.h"
 #include "core/modelIR/operations/GemmOp.h"
-#include "core/modelIR/operations/CappedReluOp.h"
-#include "core/modelIR/operations/BiasAddOp.h"
+#include "core/modelIR/operations/PadOp.h"
+#include "core/modelIR/operations/PoolOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
 #include "core/modelIR/operations/ReluOp.h"
-#include "core/modelIR/operations/ResizeOp.h"
-#include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/ReshapeOp.h"
-#include "core/modelIR/operations/BatchNormOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
-#include "core/modelIR/operations/DropoutOp.h"
-#include "core/modelIR/operations/TanhOp.h"
-#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
-#include "core/modelIR/operations/PadOp.h"
-#include "core/modelIR/operations/ReduceFOp.h"
+#include "core/modelIR/operations/TanhOp.h"
 #include "core/modelIR/operations/TransposeOp.h"
 
 #include "pass/PassException.h"
@@ -370,7 +371,10 @@ void Serializer::visit(mir::ops::TransposeOp& op) {
 
 void Serializer::visit(mir::ops::GatherOp& op) {
   _curOp->_paramStartOffset = _buffer.size();
-  assert(false && "Not yet implemented");
+  // serialize parameters
+  serializeT<int32_t>(op.getAxis());
+  // serialize output shape
+  serializeShape(op.getOutputShape(0));
 }
 
 } // namespace nnc
index d7ea8ed..a538249 100644 (file)
@@ -577,6 +577,10 @@ inline int Offset(const RuntimeShape& shape, int* index) {
   return Offset(shape, index[0], index[1], index[2], index[3]);
 }
 
+struct GatherParams {
+  int16 axis;
+};
+
 struct TransposeParams {
   int8 perm_count;
   int32 perm[4];
diff --git a/contrib/nnc/passes/soft_backend/code_snippets/cpp_gather.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_gather.def
new file mode 100644 (file)
index 0000000..2a8c86f
--- /dev/null
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+template <typename T, typename CoordsT = int32>
+inline void Gather(const GatherParams& op_params,
+                   const RuntimeShape& input_shape, const T* input_data,
+                   const RuntimeShape& coords_shape, const CoordsT* coords_data,
+                   const RuntimeShape& output_shape, T* output_data) {
+  int axis = op_params.axis;
+  if (axis < 0) {
+    axis += input_shape.DimensionsCount();
+  }
+  TFLITE_DCHECK_GE(axis, 0);
+  TFLITE_DCHECK_LT(axis, input_shape.DimensionsCount());
+  const int axis_size = input_shape.Dims(axis);
+  const int coords_count = coords_shape.FlatSize();
+
+  int outer_size = 1;
+  for (int i = 0; i < axis; ++i) {
+    outer_size *= input_shape.Dims(i);
+  }
+
+  int inner_size = 1;
+  for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) {
+    inner_size *= input_shape.Dims(i);
+  }
+
+  for (int outer = 0; outer < outer_size; ++outer) {
+    for (int i = 0; i < coords_count; ++i) {
+      TFLITE_DCHECK_GE(coords_data[i], 0);
+      TFLITE_DCHECK_LT(coords_data[i], axis_size);
+      std::memcpy(
+          output_data + (outer * coords_count + i) * inner_size,
+          input_data + (outer * axis_size + coords_data[i]) * inner_size,
+          sizeof(T) * inner_size);
+    }
+  }
+}
index c381001..8b7075e 100644 (file)
@@ -587,3 +587,18 @@ void transpose(Tensor &out, const char *params, const Tensor &in) {
             shapeToRuntimeShape(in.getShape()), in.getData(),
             shapeToRuntimeShape(out.getShape()), out.getData());
 }
+
+void gather(Tensor &out, const char *params, const Tensor &data, const Tensor &indices) {
+  GatherParams gather_params;
+  gather_params.axis = deserializeT<int32_t>(params);
+
+  Shape out_s = deserializeShape(params);
+  out.reShape(out_s);
+
+  // reinterpret_cast is used here because indices in ModelIR are integral, but getData returns
+  // pointer to float.
+  Gather(gather_params,
+         shapeToRuntimeShape(data.getShape()), data.getData(),
+         shapeToRuntimeShape(indices.getShape()), reinterpret_cast<const int32*>(indices.getData()),
+         shapeToRuntimeShape(out.getShape()), out.getData());
+}
index e7dd519..ef1dbde 100644 (file)
@@ -24,6 +24,7 @@
 
 #include "code_snippets/eigen.def"
 
+#include "code_snippets/cpp_header_types.def"
 #include "code_snippets/cpp_common_funcs.def"
 
 #include "code_snippets/cpp_add_bias.def"
 #include "code_snippets/cpp_conv.def"
 #include "code_snippets/cpp_conv_transpose.def"
 #include "code_snippets/cpp_depthwise_conv.def"
+#include "code_snippets/cpp_elementwise.def"
+#include "code_snippets/cpp_elu.def"
 #include "code_snippets/cpp_fully_connected.def"
+#include "code_snippets/cpp_gather.def"
+#include "code_snippets/cpp_pad.def"
 #include "code_snippets/cpp_pool.def"
 #include "code_snippets/cpp_reduce.def"
 #include "code_snippets/cpp_relu.def"
 #include "code_snippets/cpp_softmax.def"
-#include "code_snippets/cpp_elu.def"
-#include "code_snippets/cpp_elementwise.def"
 #include "code_snippets/cpp_tanh.def"
-#include "code_snippets/cpp_pad.def"
 #include "code_snippets/cpp_transpose.def"
 
-#include "CommonData.def"
-#include "code_snippets/cpp_header_types.def"
 #include "code_snippets/cpp_operations.def"
 #include "code_snippets/cpp_scale.def"