From 79b0bb1f4c5f9e83e1459ee8b780c4ee6cb492ca Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 7 Sep 2018 13:31:36 +0900 Subject: [PATCH] [enco] Introduce ShapeQuery class (#1404) * [enco] Introduce ShapeQuery class This commit generalizes CanonicalChannelAxis implementation as ShapeQuery to reuse this logic for other layers. Signed-off-by: Jonghyun Park * Fix typos in comment --- .../enco/frontend/caffe/src/ConvolutionSpec.cpp | 9 +--- contrib/enco/frontend/caffe/src/ShapeQuery.cpp | 24 +++++++++ contrib/enco/frontend/caffe/src/ShapeQuery.h | 59 ++++++++++++++++++++++ 3 files changed, 85 insertions(+), 7 deletions(-) create mode 100644 contrib/enco/frontend/caffe/src/ShapeQuery.cpp create mode 100644 contrib/enco/frontend/caffe/src/ShapeQuery.h diff --git a/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp b/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp index c2b12be..0088ce3 100644 --- a/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp +++ b/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp @@ -1,4 +1,5 @@ #include "ConvolutionSpec.h" +#include "ShapeQuery.h" #include @@ -21,13 +22,7 @@ public: public: uint32_t eval(const nncc::core::ADT::tensor::Shape &ifm) const { - if (_axis > 0) - { - return static_cast(_axis); - } - - assert(ifm.rank() >= static_cast(-_axis)); - return static_cast(ifm.rank() + _axis); + return query_on(ifm).axis(axis_specifier(_axis)); } private: diff --git a/contrib/enco/frontend/caffe/src/ShapeQuery.cpp b/contrib/enco/frontend/caffe/src/ShapeQuery.cpp new file mode 100644 index 0000000..bef0f7b --- /dev/null +++ b/contrib/enco/frontend/caffe/src/ShapeQuery.cpp @@ -0,0 +1,24 @@ +#include "ShapeQuery.h" + +#include + +// +// AxisSpecifier +// +AxisSpecifier axis_specifier(int32_t value) { return AxisSpecifier{value}; } + +// +// ShapeQuery +// +uint32_t ShapeQuery::axis(const AxisSpecifier &specifier) const +{ + if (specifier.value() > 0) + { + return static_cast(specifier.value()); + } + + assert(_shape->rank() >= static_cast(-specifier.value())); + return static_cast(_shape->rank() + specifier.value()); +} + +ShapeQuery query_on(const nncc::core::ADT::tensor::Shape &shape) { return ShapeQuery{&shape}; } diff --git a/contrib/enco/frontend/caffe/src/ShapeQuery.h b/contrib/enco/frontend/caffe/src/ShapeQuery.h new file mode 100644 index 0000000..2184863 --- /dev/null +++ b/contrib/enco/frontend/caffe/src/ShapeQuery.h @@ -0,0 +1,59 @@ +#ifndef __SHAPE_QUERY_H__ +#define __SHAPE_QUERY_H__ + +#include + +/** + * @brief A wrapper class for an integer number that specifies axis + * + * Several Caffe layers includes 'axis' parameter (which may be negative) which specifies + * some axis required for operation. + * + * Here are several examples: + * - Convolution layer uses 'axis' parameter to specify "channel" axis + * (http://caffe.berkeleyvision.org/tutorial/layers/convolution.html) + * - Concat layer uses 'axis' parameter to specify axis to be concatenated + * (http://caffe.berkeleyvision.org/tutorial/layers/concat.html) + * + * AxisSpecifier class is introduced to distinguish this 'axis' parameter from other integers + * (to prevent possible mistake). + */ +class AxisSpecifier +{ +public: + explicit AxisSpecifier(int32_t value) : _value{value} + { + // DO NOTHING + } + +public: + int32_t value(void) const { return _value; } + +private: + int32_t _value = 1; +}; + +AxisSpecifier axis_specifier(int32_t value); + +/** + * @brief A wrapper class that allows additional queries over tensor shape. + */ +class ShapeQuery +{ +public: + explicit ShapeQuery(const nncc::core::ADT::tensor::Shape *shape) : _shape{shape} + { + // DO NOTHING + } + +public: + // @brief Return the dimension number (axis) specified by a given axis specifier + uint32_t axis(const AxisSpecifier &) const; + +private: + const nncc::core::ADT::tensor::Shape *_shape; +}; + +ShapeQuery query_on(const nncc::core::ADT::tensor::Shape &); + +#endif // __SHAPE_QUERY_H__ -- 2.7.4