Add support for Gather operation to soft backend.
Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
inferOutputShapes();
}
+ int32_t getAxis() const { return _axis; }
+
private:
void inferOutputShapes();
#include "cpp_elementwise.generated.h"
#include "cpp_pad.generated.h"
#include "cpp_transpose.generated.h"
+#include "cpp_gather.generated.h"
namespace nnc
{
out.write(cpp_pad, sizeof(cpp_pad));
out.write(cpp_conv_transpose, sizeof(cpp_conv_transpose));
out.write(cpp_transpose, sizeof(cpp_transpose));
+ out.write(cpp_gather, sizeof(cpp_gather));
out.write(cpp_operations, sizeof(cpp_operations));
out.write(cpp_scale, sizeof(cpp_scale));
out.write(cpp_dropout, sizeof(cpp_dropout));
#include "CommonData.def"
+#include "core/modelIR/operations/BatchNormOp.h"
+#include "core/modelIR/operations/BiasAddOp.h"
+#include "core/modelIR/operations/CappedReluOp.h"
#include "core/modelIR/operations/ConcatOp.h"
#include "core/modelIR/operations/ConstantOp.h"
#include "core/modelIR/operations/Conv2DOp.h"
#include "core/modelIR/operations/Deconv2DOp.h"
#include "core/modelIR/operations/DepthwiseConv2DOp.h"
-#include "core/modelIR/operations/SoftmaxOp.h"
-#include "core/modelIR/operations/PoolOp.h"
+#include "core/modelIR/operations/DropoutOp.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/EluOp.h"
#include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GatherOp.h"
#include "core/modelIR/operations/GemmOp.h"
-#include "core/modelIR/operations/CappedReluOp.h"
-#include "core/modelIR/operations/BiasAddOp.h"
+#include "core/modelIR/operations/PadOp.h"
+#include "core/modelIR/operations/PoolOp.h"
+#include "core/modelIR/operations/ReduceFOp.h"
#include "core/modelIR/operations/ReluOp.h"
-#include "core/modelIR/operations/ResizeOp.h"
-#include "core/modelIR/operations/EluOp.h"
#include "core/modelIR/operations/ReshapeOp.h"
-#include "core/modelIR/operations/BatchNormOp.h"
+#include "core/modelIR/operations/ResizeOp.h"
#include "core/modelIR/operations/ScaleOp.h"
-#include "core/modelIR/operations/DropoutOp.h"
-#include "core/modelIR/operations/TanhOp.h"
-#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/SoftmaxOp.h"
#include "core/modelIR/operations/SqueezeOp.h"
-#include "core/modelIR/operations/PadOp.h"
-#include "core/modelIR/operations/ReduceFOp.h"
+#include "core/modelIR/operations/TanhOp.h"
#include "core/modelIR/operations/TransposeOp.h"
#include "pass/PassException.h"
void Serializer::visit(mir::ops::GatherOp& op) {
_curOp->_paramStartOffset = _buffer.size();
- assert(false && "Not yet implemented");
+ // serialize parameters
+ serializeT<int32_t>(op.getAxis());
+ // serialize output shape
+ serializeShape(op.getOutputShape(0));
}
} // namespace nnc
return Offset(shape, index[0], index[1], index[2], index[3]);
}
+struct GatherParams {
+ int16 axis;
+};
+
struct TransposeParams {
int8 perm_count;
int32 perm[4];
--- /dev/null
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+template <typename T, typename CoordsT = int32>
+inline void Gather(const GatherParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& coords_shape, const CoordsT* coords_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ int axis = op_params.axis;
+ if (axis < 0) {
+ axis += input_shape.DimensionsCount();
+ }
+ TFLITE_DCHECK_GE(axis, 0);
+ TFLITE_DCHECK_LT(axis, input_shape.DimensionsCount());
+ const int axis_size = input_shape.Dims(axis);
+ const int coords_count = coords_shape.FlatSize();
+
+ int outer_size = 1;
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= input_shape.Dims(i);
+ }
+
+ int inner_size = 1;
+ for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) {
+ inner_size *= input_shape.Dims(i);
+ }
+
+ for (int outer = 0; outer < outer_size; ++outer) {
+ for (int i = 0; i < coords_count; ++i) {
+ TFLITE_DCHECK_GE(coords_data[i], 0);
+ TFLITE_DCHECK_LT(coords_data[i], axis_size);
+ std::memcpy(
+ output_data + (outer * coords_count + i) * inner_size,
+ input_data + (outer * axis_size + coords_data[i]) * inner_size,
+ sizeof(T) * inner_size);
+ }
+ }
+}
shapeToRuntimeShape(in.getShape()), in.getData(),
shapeToRuntimeShape(out.getShape()), out.getData());
}
+
+void gather(Tensor &out, const char *params, const Tensor &data, const Tensor &indices) {
+ GatherParams gather_params;
+ gather_params.axis = deserializeT<int32_t>(params);
+
+ Shape out_s = deserializeShape(params);
+ out.reShape(out_s);
+
+ // reinterpret_cast is used here because indices in ModelIR are integral, but getData returns
+ // pointer to float.
+ Gather(gather_params,
+ shapeToRuntimeShape(data.getShape()), data.getData(),
+ shapeToRuntimeShape(indices.getShape()), reinterpret_cast<const int32*>(indices.getData()),
+ shapeToRuntimeShape(out.getShape()), out.getData());
+}
#include "code_snippets/eigen.def"
+#include "code_snippets/cpp_header_types.def"
#include "code_snippets/cpp_common_funcs.def"
#include "code_snippets/cpp_add_bias.def"
#include "code_snippets/cpp_conv.def"
#include "code_snippets/cpp_conv_transpose.def"
#include "code_snippets/cpp_depthwise_conv.def"
+#include "code_snippets/cpp_elementwise.def"
+#include "code_snippets/cpp_elu.def"
#include "code_snippets/cpp_fully_connected.def"
+#include "code_snippets/cpp_gather.def"
+#include "code_snippets/cpp_pad.def"
#include "code_snippets/cpp_pool.def"
#include "code_snippets/cpp_reduce.def"
#include "code_snippets/cpp_relu.def"
#include "code_snippets/cpp_softmax.def"
-#include "code_snippets/cpp_elu.def"
-#include "code_snippets/cpp_elementwise.def"
#include "code_snippets/cpp_tanh.def"
-#include "code_snippets/cpp_pad.def"
#include "code_snippets/cpp_transpose.def"
-#include "CommonData.def"
-#include "code_snippets/cpp_header_types.def"
#include "code_snippets/cpp_operations.def"
#include "code_snippets/cpp_scale.def"