From e39dfae2dc5d06971b7993e3dddc8b9e4817bf69 Mon Sep 17 00:00:00 2001 From: "Efimov Alexander/AI Tools Lab/./Samsung Electronics" Date: Wed, 22 Aug 2018 16:09:17 +0300 Subject: [PATCH] Add support of negative axis in softmax (#1117) Use negative parameter in softmax operation to count axises from last one --- .../include/nnc/core/IR/model/operations/softmax_op.h | 13 ++++++++++++- .../nnc/libs/frontend/caffe/src/caffe_op_creator.cpp | 3 ++- .../nnc/libs/frontend/tflite/src/tflite_op_creator.cpp | 4 ++-- contrib/nnc/unittests/core/operation.cpp | 17 +++++++++++++++++ 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/contrib/nnc/libs/core/include/nnc/core/IR/model/operations/softmax_op.h b/contrib/nnc/libs/core/include/nnc/core/IR/model/operations/softmax_op.h index e9d6f5b..091dcdb 100644 --- a/contrib/nnc/libs/core/include/nnc/core/IR/model/operations/softmax_op.h +++ b/contrib/nnc/libs/core/include/nnc/core/IR/model/operations/softmax_op.h @@ -21,7 +21,18 @@ class SoftmaxOp : public OpDescription public: explicit SoftmaxOp(int axis) : OpDescription(1, 1), _axis(axis) {} - int getAxis() const { return _axis; } + int getAxis() const + { + if (_axis < 0) + { + // Negative axis is used to index starting from the last element of the shape + // -1 means last element, -2 means second from end, like in python + int res = _axis + getInputShape(0).rank(); + assert(res >= 0); + return res; + } + return _axis; + } private: int _axis; diff --git a/contrib/nnc/libs/frontend/caffe/src/caffe_op_creator.cpp b/contrib/nnc/libs/frontend/caffe/src/caffe_op_creator.cpp index edf18df..2d41440 100644 --- a/contrib/nnc/libs/frontend/caffe/src/caffe_op_creator.cpp +++ b/contrib/nnc/libs/frontend/caffe/src/caffe_op_creator.cpp @@ -234,7 +234,8 @@ __attribute__ ((unused)) static ops::PoolOp::PoolingType getPoolingType(const Po template __attribute__ ((unused)) static int getAxisValue(const OptsType& opts) { - int axis = 2; + // -1 represents last one dimension + int axis = -1; if (opts.has_axis()) { axis = opts.axis(); diff --git a/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp b/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp index 2f7e84e..d5cab37 100644 --- a/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp +++ b/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp @@ -76,8 +76,8 @@ std::vector OpCreator::createAvgPool(InputOps inputs, InputParams pa std::vector OpCreator::createSoftmax(InputOps inputs, InputParams params, const SoftmaxOptions *opts) { - // TODO: here assuming that softmax is applied to a 1-d tensor - return createOp(inputs, ActivationFunctionType_NONE, 1); + // -1 represents last one dimension + return createOp(inputs, ActivationFunctionType_NONE, -1); } std::vector OpCreator::createReshape(InputOps inputs, InputParams params, diff --git a/contrib/nnc/unittests/core/operation.cpp b/contrib/nnc/unittests/core/operation.cpp index 35f34e5..3944fb4 100644 --- a/contrib/nnc/unittests/core/operation.cpp +++ b/contrib/nnc/unittests/core/operation.cpp @@ -1,4 +1,5 @@ #include "nnc/core/IR/model/operations/operation.h" +#include "nnc/core/IR/model/operations/softmax_op.h" #include @@ -16,3 +17,19 @@ TEST(OpDescription, InputOutputShapeTest) { ASSERT_EQ(inShape, op.getInputShape(0)); ASSERT_EQ(outShape, op.getOutputShape(0)); } + +TEST(OpDescription, SoftmaxAxisTest) { + Shape inShape{1,2,3}; + + ops::SoftmaxOp op_1(1); + op_1.setInputShape(0, inShape); + ASSERT_EQ(op_1.getAxis(), 1); + + ops::SoftmaxOp op_n1(-1); + op_n1.setInputShape(0, inShape); + ASSERT_EQ(op_n1.getAxis(), 2); + + ops::SoftmaxOp op_n3(-3); + op_n3.setInputShape(0, inShape); + ASSERT_EQ(op_n3.getAxis(), 0); +} \ No newline at end of file -- 2.7.4