[nnc] Added slice (#2680)
authorАндрей Шедько/AI Tools Lab /SRR/Engineer/삼성전자 <a.shedko@samsung.com>
Mon, 17 Dec 2018 12:37:55 +0000 (15:37 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Mon, 17 Dec 2018 12:37:55 +0000 (15:37 +0300)
Added slice to MIR, Interpreter, C++ SoftBackend
Added Importer Support for slice to tflite importer
Fixed SoftBackend Impl

Signed-off-by: Andrei Shedko <a.shedko@partner.samsung.com>
25 files changed:
contrib/nnc/core/CMakeLists.txt
contrib/nnc/core/modelIR/IrDotDumper.cpp
contrib/nnc/core/modelIR/Operation.cpp
contrib/nnc/core/modelIR/operations/SliceOp.cpp [new file with mode: 0644]
contrib/nnc/include/core/modelIR/IrDotDumper.h
contrib/nnc/include/core/modelIR/operations/SliceOp.h [new file with mode: 0644]
contrib/nnc/include/core/modelIR/operations/operations.lst.h
contrib/nnc/include/passes/interpreter/Interpreter.h
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.h
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/interpreter/ops/common.cpp
contrib/nnc/passes/interpreter/ops/common.h
contrib/nnc/passes/soft_backend/CPPGenerator.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.h
contrib/nnc/passes/soft_backend/SBSerializer.cpp
contrib/nnc/passes/soft_backend/SBSerializer.h
contrib/nnc/passes/soft_backend/code_snippets/cpp_common_funcs.def
contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def
contrib/nnc/passes/soft_backend/code_snippets/cpp_slice.def [new file with mode: 0644]
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.h
contrib/nnc/unittests/soft_backend/CPPOperations.cpp

index 50dc209..5936137 100644 (file)
@@ -9,6 +9,7 @@ set(SOURCES "modelIR/operations/ConcatOp.cpp"
             "modelIR/operations/PadOp.cpp"
             "modelIR/operations/PoolOp.cpp"
             "modelIR/operations/SqueezeOp.cpp"
+            "modelIR/operations/SliceOp.cpp"
             "modelIR/operations/TransposeOp.cpp"
             "modelIR/Graph.cpp"
             "modelIR/Index.cpp"
index 7e059a2..b430c72 100644 (file)
@@ -39,6 +39,7 @@
 #include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
 #include "core/modelIR/operations/SigmoidOp.h"
+#include "core/modelIR/operations/SliceOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/SqrtOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
@@ -207,6 +208,16 @@ void IrDotDumper::visit(ops::ScaleOp& op) {
   dotBuilder.updateWithOp(&op, nodeInfo);
 }
 
+void IrDotDumper::visit(ops::SliceOp& op) {
+  auto node_info = DotIrNodeInfo().withType("SliceOp", op.getName())
+    .withInShapes(getInputShapes(op))
+    .withShape("Starts", op.getStarts())
+    .withShape("Sizes", op.getSizes())
+    .withOutShapes(getOutputShapes(op));
+
+  dotBuilder.updateWithOp(&op, node_info);
+}
+
 void IrDotDumper::visit(ops::DropoutOp& op) {
   auto nodeInfo = DotIrNodeInfo().withType("DropoutOp", op.getName())
                                  .withInShapes(getInputShapes(op))
index 187e08c..2eb6eda 100644 (file)
@@ -37,6 +37,7 @@
 #include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
 #include "core/modelIR/operations/SigmoidOp.h"
+#include "core/modelIR/operations/SliceOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
 #include "core/modelIR/operations/SqrtOp.h"
diff --git a/contrib/nnc/core/modelIR/operations/SliceOp.cpp b/contrib/nnc/core/modelIR/operations/SliceOp.cpp
new file mode 100644 (file)
index 0000000..2c19ebb
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#include "core/modelIR/operations/SliceOp.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+// Only supports 4d inputs
+void SliceOp::inferOutputShapes() {
+  const Shape& input_shape = getInputShape(0);
+  assert(input_shape.rank() == 4 && "4d input only");
+  Shape output_shape(input_shape.rank());
+  for (int i = 0; i < input_shape.rank(); i++) {
+    if (_sizes.dim(i) == -1) {
+      output_shape.dim(i) = input_shape.dim(i) - _starts.dim(i);
+    } else {
+      output_shape.dim(i) = _sizes.dim(i);
+    }
+  }
+  setOutputShape(0, output_shape);
+}
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
\ No newline at end of file
index 5770e54..0b3128e 100644 (file)
@@ -53,6 +53,7 @@ public:
   void visit(ops::ResizeOp& op) override;
   void visit(ops::ScaleOp& op) override;
   void visit(ops::SigmoidOp& op) override;
+  void visit(ops::SliceOp& op) override;
   void visit(ops::SoftmaxOp& op) override;
   void visit(ops::SqrtOp& op) override;
   void visit(ops::SqueezeOp& op) override;
diff --git a/contrib/nnc/include/core/modelIR/operations/SliceOp.h b/contrib/nnc/include/core/modelIR/operations/SliceOp.h
new file mode 100644 (file)
index 0000000..4787517
--- /dev/null
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_IR_MODEL_SLICE_H_
+#define _NNC_CORE_IR_MODEL_SLICE_H_
+
+#include "core/modelIR/Operation.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+class SliceOp : public Operation {
+public:
+  SliceOp(const IODescriptor& arg, const Shape& starts, const Shape& sizes) :
+    Operation(Type::slice, {arg}),
+    _starts(starts),
+    _sizes(sizes) {
+    inferOutputShapes();
+  }
+
+  const Shape& getStarts() { return _starts; }
+
+  const Shape& getSizes() { return _sizes; }
+
+private:
+  void inferOutputShapes();
+
+  Shape _starts;
+  Shape _sizes;
+};
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
+
+#endif //_NNC_CORE_IR_MODEL_SLICE_H_
+
index 2fd8e9d..c3fe3d3 100644 (file)
@@ -35,6 +35,7 @@ HANDLE_OP(reshape, ReshapeOp)
 HANDLE_OP(resizeIm, ResizeOp)
 HANDLE_OP(scale, ScaleOp)
 HANDLE_OP(sigmoid, SigmoidOp)
+HANDLE_OP(slice, SliceOp)
 HANDLE_OP(batchNorm, BatchNormOp)
 HANDLE_OP(dropout, DropoutOp)
 HANDLE_OP(tanh, TanhOp)
index bceb810..e57d519 100644 (file)
@@ -58,6 +58,7 @@ public:
   void visit(ops::ResizeOp& op) override;
   void visit(ops::ScaleOp& op) override;
   void visit(ops::SigmoidOp& op) override;
+  void visit(ops::SliceOp& op) override;
   void visit(ops::SoftmaxOp& op) override;
   void visit(ops::SqrtOp& op) override;
   void visit(ops::SqueezeOp& op) override;
index 21b21f0..acb8cba 100644 (file)
@@ -511,6 +511,10 @@ void AclCppOpGenerator::visit(ops::ScaleOp& op) {
   runLayer(layer2);
 }
 
+void AclCppOpGenerator::visit(mir::ops::SliceOp& op) {
+  assert(false && "Unimplemented operation: SliceOp");
+}
+
 void AclCppOpGenerator::visit(ops::BatchNormOp& op) {
   // Not supported in our framework, but present in ACL API.
   throw AclCppException("Not supported in inference yet.");
index e97e1ba..f48eb3d 100644 (file)
@@ -69,6 +69,7 @@ public:
   void visit(mir::ops::ResizeOp& op) override;
   void visit(mir::ops::ScaleOp& op) override;
   void visit(mir::ops::SigmoidOp& op) override;
+  void visit(mir::ops::SliceOp& op) override;
   void visit(mir::ops::SoftmaxOp& op) override;
   void visit(mir::ops::SqrtOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
index c7f2e76..e001f32 100644 (file)
@@ -37,6 +37,7 @@
 #include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
 #include "core/modelIR/operations/SigmoidOp.h"
+#include "core/modelIR/operations/SliceOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/SqrtOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
 #include "ops/Scale.h"
 #include "ops/Softmax.h"
 #include "ops/Transpose.h"
+#include "ops/Dropout.h"
+#include "ops/BatchNorm.h"
+#include "ops/Pad.h"
+#include "ops/common.h"
 
 #include <vector>
 #include <cmath>
@@ -216,6 +221,17 @@ void NNInterpreter::visit(ops::ScaleOp& op) {
    var(op.getId()) = Scale(input, op)();
 }
 
+
+void NNInterpreter::visit(ops::SliceOp& op) {
+  mapByName(&op);
+  auto operand = op.getPrevNodes()[0];
+  auto input = Tensor<float>(var(operand.op->getId())[operand.index]);
+  var(op.getId()) = Fill<float>(op.getOutputShape(0), [&input, &op](const Index& id) {
+    Index idx = nnc::shift(id, op.getStarts());
+    return input.at(idx);
+  })();
+}
+
 void NNInterpreter::visit(ops::DropoutOp& op) {
   mapByName(&op);
   auto operand = op.getPrevNodes()[0];
index 76d0af0..1c615ca 100644 (file)
@@ -34,4 +34,13 @@ void translate(Index &translatedIndex, const Index &sourceIndex, const Index &ke
   }
 }
 
+Index shift(const Index& in_index, const Shape& shift_from) {
+  Index index = in_index;
+  assert(index.rank() == shift_from.rank());
+  for (int32_t d = 0; d < in_index.rank(); ++d) {
+    index.at(d) = index.at(d) + shift_from.dim(d);
+  }
+  return index;
+}
+
 } // namespace nnc
index 8da7dd5..a414847 100644 (file)
@@ -33,4 +33,12 @@ namespace nnc
 void translate(mir::Index &translatedIndex, const mir::Index &sourceIndex, const mir::Index &kernelIndex,
                const mir::Shape &strides, const mir::Index &paddings);
 
+/**
+ * Shift in_index by `shift`
+ * @param[in] in_index argument
+ * @param[in] shift
+ * @return the
+ */
+mir::Index shift(const mir::Index& in_index, const mir::Shape& shift);
+
 } // namespace nnc
index b0513a2..2c4b851 100644 (file)
@@ -42,6 +42,7 @@ using namespace std;
 #include "cpp_reduce.generated.h"
 #include "cpp_softmax.generated.h"
 #include "cpp_scale.generated.h"
+#include "cpp_slice.generated.h"
 #include "cpp_dropout.generated.h"
 #include "cpp_batchnorm.generated.h"
 #include "cpp_elu.generated.h"
@@ -289,6 +290,7 @@ void CPPCodeGenerator::materializeCode(ostream &out, const ModelAnalyzer &ma, co
   out.write(cpp_relu, sizeof(cpp_relu));
   out.write(cpp_reduce, sizeof(cpp_reduce));
   out.write(cpp_softmax, sizeof(cpp_softmax));
+  out.write(cpp_slice, sizeof(cpp_slice));
   out.write(cpp_elementwise, sizeof(cpp_elementwise));
   out.write(cpp_elu, sizeof(cpp_elu));
   out.write(cpp_tanh, sizeof(cpp_tanh));
@@ -297,7 +299,9 @@ void CPPCodeGenerator::materializeCode(ostream &out, const ModelAnalyzer &ma, co
   out.write(cpp_conv_transpose, sizeof(cpp_conv_transpose));
   out.write(cpp_transpose, sizeof(cpp_transpose));
   out.write(cpp_gather, sizeof(cpp_gather));
+  // Operations calls into all of the above
   out.write(cpp_operations, sizeof(cpp_operations));
+  // Below call into operations
   out.write(cpp_scale, sizeof(cpp_scale));
   out.write(cpp_dropout, sizeof(cpp_dropout));
   out.write(cpp_batchnorm, sizeof(cpp_batchnorm));
index 206e18b..a2ef5a4 100644 (file)
@@ -45,6 +45,7 @@
 #include "core/modelIR/operations/ReshapeOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
 #include "core/modelIR/operations/SigmoidOp.h"
+#include "core/modelIR/operations/SliceOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/SqrtOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
@@ -252,6 +253,10 @@ void ModelAnalyzer::visit(ops::ScaleOp& op) {
   addOpDescr(&op, "scale");
 }
 
+void ModelAnalyzer::visit(mir::ops::SliceOp& op) {
+  addOpDescr(&op, "slice");
+}
+
 void ModelAnalyzer::visit(ops::BatchNormOp& op) {
   addOpDescr(&op, "batchNorm");
 }
index f6965fc..e01806e 100644 (file)
@@ -111,6 +111,7 @@ public:
   void visit(mir::ops::ResizeOp& op) override;
   void visit(mir::ops::ScaleOp& op) override;
   void visit(mir::ops::SigmoidOp& op) override;
+  void visit(mir::ops::SliceOp& op) override;
   void visit(mir::ops::SoftmaxOp& op) override;
   void visit(mir::ops::SqrtOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
index 9314f21..7120f66 100644 (file)
@@ -41,6 +41,7 @@
 #include "core/modelIR/operations/ReshapeOp.h"
 #include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
+#include "core/modelIR/operations/SliceOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
 #include "core/modelIR/operations/SqrtOp.h"
@@ -273,6 +274,13 @@ void Serializer::visit(ops::ScaleOp& op) {
   serializeTensor(op.getWeights());
 }
 
+void Serializer::visit(mir::ops::SliceOp& op) {
+  _curOp->_paramStartOffset = _buffer.size();
+  serializeShape(op.getStarts());
+  serializeShape(op.getSizes());
+  serializeShape(op.getOutputShape(0));
+}
+
 void Serializer::visit(ops::DropoutOp& op) {
   _curOp->_paramStartOffset = _buffer.size();
   serializeT<float>(op.getRate());
index 92c22c4..789be93 100644 (file)
@@ -63,6 +63,7 @@ public:
   void visit(mir::ops::ResizeOp& op) override;
   void visit(mir::ops::ScaleOp& op) override;
   void visit(mir::ops::SigmoidOp& op) override;
+  void visit(mir::ops::SliceOp& op) override;
   void visit(mir::ops::SoftmaxOp& op) override;
   void visit(mir::ops::SqrtOp& op) override;
   void visit(mir::ops::SqueezeOp& op) override;
index 1659122..e1b6caa 100644 (file)
@@ -561,6 +561,13 @@ void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
   }
 }
 
+struct SliceParams {
+  int8 begin_count;
+  int32 begin[4];
+  int8 size_count;
+  int32 size[4];
+};
+
 // Get common shape dim, DCHECKing that they all agree.
 inline int MatchingDim(const RuntimeShape& shape1, int index1,
                        const RuntimeShape& shape2, int index2) {
index d99913f..1547fe1 100644 (file)
@@ -21,7 +21,6 @@
 #include <fcntl.h>
 #include <unistd.h>
 #include <cstring>
-#include <iostream>
 
 using namespace std;
 
@@ -443,6 +442,26 @@ void biasAdd(Tensor &out, const char *params, const Tensor &in)
   AddBiasAndEvalActivationFunction(bias.data, bias.dims, out.getData(), shapeToDims(out.getShape()));
 }
 
+void slice(Tensor& out, const char* params, const Tensor& in) {
+  Shape starts = deserializeShape(params);
+  Shape sizes = deserializeShape(params);
+  Shape out_s = deserializeShape(params);
+
+  out.reShape(out_s);
+  SliceParams slice_params;
+  slice_params.begin_count = starts.getDims();
+  slice_params.size_count = sizes.getDims();
+  for (int i = 0; i < 4; i++) {
+    slice_params.begin[i] = starts[i];
+    slice_params.size[i] = sizes[i];
+  }
+  Slice(
+    slice_params,
+    shapeToRuntimeShape(in.getShape()), in.getData(),
+    shapeToRuntimeShape(out_s), out.getData()
+  );
+}
+
 void relu(Tensor &out, const char *params, const Tensor &in)
 {
   const float *input = in.getData();
diff --git a/contrib/nnc/passes/soft_backend/code_snippets/cpp_slice.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_slice.def
new file mode 100644 (file)
index 0000000..3b07053
--- /dev/null
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+template <typename T>
+inline void Slice(const SliceParams& op_params,
+                  const RuntimeShape& input_shape, const T* input_data,
+                  const RuntimeShape& output_shape, T* output_data) {
+  const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+  // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
+  TFLITE_DCHECK_LE(op_params.begin_count, 4);
+  TFLITE_DCHECK_LE(op_params.size_count, 4);
+  const int begin_count = op_params.begin_count;
+  const int size_count = op_params.size_count;
+  // We front-pad the begin and size vectors.
+  const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
+  const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
+                     ? ext_shape.Dims(0)
+                     : start_b + op_params.size[0];
+  const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
+  const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
+                     ? ext_shape.Dims(1)
+                     : start_h + op_params.size[size_count - 3];
+  const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
+  const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
+                     ? ext_shape.Dims(2)
+                     : start_w + op_params.size[size_count - 2];
+  const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
+  const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
+                     ? ext_shape.Dims(3)
+                     : start_d + op_params.size[size_count - 1];
+
+  T* out_ptr = output_data;
+  for (int in_b = start_b; in_b < stop_b; ++in_b) {
+    for (int in_h = start_h; in_h < stop_h; ++in_h) {
+      for (int in_w = start_w; in_w < stop_w; ++in_w) {
+        const int len = stop_d - start_d;
+        memcpy(out_ptr,
+               input_data + Offset(ext_shape, in_b, in_h, in_w, start_d),
+               len * sizeof(T));
+        out_ptr += len;
+      }
+    }
+  }
+}
index caba51b..de667b5 100644 (file)
@@ -91,6 +91,7 @@ void TfliteImporter::processUnsupportedOp(const Operator* op) {
                                       _problemsOpSet);
       break;
     case BuiltinOperator_SOFTMAX:
+    case BuiltinOperator_SLICE:
     case BuiltinOperator_RESHAPE:
     case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
     case BuiltinOperator_SQUEEZE:
@@ -202,6 +203,9 @@ void TfliteImporter::walkOperator(const Operator* op) {
     case BuiltinOperator_SOFTMAX:
       outputs = _opCreator->createSoftmax(inputs, params, op->builtin_options_as<SoftmaxOptions>());
       break;
+    case BuiltinOperator_SLICE:
+      outputs = _opCreator->createSlice(inputs, params, op->builtin_options_as_SliceOptions());
+      break;
     case BuiltinOperator_SQUEEZE:
       outputs = _opCreator->createSqueeze(inputs, params, op->builtin_options_as<SqueezeOptions>());
       break;
index 5e56930..e7bf27b 100644 (file)
@@ -32,6 +32,7 @@
 #include "core/modelIR/operations/ReshapeOp.h"
 #include "core/modelIR/operations/ResizeOp.h"
 #include "core/modelIR/operations/SigmoidOp.h"
+#include "core/modelIR/operations/SliceOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/SqrtOp.h"
 #include "core/modelIR/operations/SqueezeOp.h"
@@ -188,6 +189,25 @@ std::vector<mir::Operation*> TFLiteOpCreator::createSoftmax(InputOps inputs, Inp
   return createOp<ops::SoftmaxOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0), axis);
 }
 
+Shape shapeFromTensor(mir::Tensor<int32_t>&& t) {
+  Shape temporary_shape(4);
+  int j = 0;
+  for (auto i : mir::ShapeRange(t.getShape())) {
+    temporary_shape.dim(j++) = t.at(i);
+  }
+  return temporary_shape;
+}
+
+std::vector<mir::Operation*> TFLiteOpCreator::createSlice(InputOps inputs, InputParams params,
+                                                          const ::tflite::SliceOptions*) {
+  auto starts = shapeFromTensor(mir::Tensor<int32_t>(params[0]));
+  auto sizes = shapeFromTensor(mir::Tensor<int32_t>(params[1]));
+  assert(starts.rank() == inputs[0]->getOutputShape(0).rank() &&
+         starts.rank() == sizes.rank());
+  return createOp<ops::SliceOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0),
+                                starts, sizes);
+}
+
 std::vector<mir::Operation*> TFLiteOpCreator::convertReshape(InputOps inputs, InputParams params,
                                                              const ReshapeOptions* opts) {
   // TODO: we should also support "-1" values in new_shape, which means that correct
index 99aff24..0d802c2 100644 (file)
@@ -66,6 +66,8 @@ public:
 
   std::vector<mir::Operation*> createSoftmax(InputOps, InputParams, const ::tflite::SoftmaxOptions*);
 
+  std::vector<mir::Operation*> createSlice(InputOps, InputParams, const ::tflite::SliceOptions*);
+
   std::vector<mir::Operation*> convertReshape(InputOps, InputParams,
                                               const ::tflite::ReshapeOptions*);
 
index 234f048..5be131f 100644 (file)
@@ -44,6 +44,7 @@
 #include "code_snippets/cpp_relu.def"
 #include "code_snippets/cpp_softmax.def"
 #include "code_snippets/cpp_sqrt.def"
+#include "code_snippets/cpp_slice.def"
 #include "code_snippets/cpp_tanh.def"
 #include "code_snippets/cpp_transpose.def"
 
@@ -72,6 +73,7 @@
 #include "core/modelIR/operations/ReshapeOp.h"
 #include "core/modelIR/operations/ScaleOp.h"
 #include "core/modelIR/operations/SigmoidOp.h"
+#include "core/modelIR/operations/SliceOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/SqrtOp.h"
 #include "core/modelIR/operations/TanhOp.h"
@@ -812,6 +814,30 @@ TEST(cpp_operations_test, softmax) {
   }
 }
 
+TEST(cpp_operations_test, slice4d) {
+  vector<int> shape_data{5, 30, 40, 12};
+  vector<int> starts[] = {{0, 0, 0, 0},
+                          {1, 1, 1, 1},
+                          {1, 0, 1, 0},
+                          {0, 1, 1, 0}};
+  vector<int> sizes[] = {
+    {-1, -1, -1, -1},
+    {4,  -1, 10, -1},
+  };
+  for (auto st : starts) {
+    for (auto sz : sizes) {
+      Tensor a_input_tensor;
+      vector<unique_ptr<mir::TensorVariant>> input_n_tensor(1);
+      fillTensors(input_n_tensor[0], a_input_tensor, shape_data, 1.0f);
+      auto op_gen = [&st, &sz](mir::Graph& g, const std::vector<mir::IODescriptor>& inputs) {
+        return g.create<mir::ops::SliceOp>("y", inputs[0], mir::Shape(st),
+                                           mir::Shape(sz));
+      };
+      createAndRunTestGraph(op_gen, slice, input_n_tensor, a_input_tensor);
+    }
+  }
+}
+
 TEST(cpp_operations_test, reshape) {
   // test prerequisites
   vector<int> input_shape_data{2, 3, 4, 5};