From: 박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 Date: Fri, 31 Aug 2018 00:56:50 +0000 (+0900) Subject: [enco] Extract CanonicalChannelAxis (#1259) X-Git-Tag: nncc_backup~1998 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=37003c2f9a1fdc1ffe58f36175cc3022f2e85da0;p=platform%2Fcore%2Fml%2Fnnfw.git [enco] Extract CanonicalChannelAxis (#1259) This commit extracts CanonicalChannelAxis class from ConvolutionSpec to make it reusable from outside. Signed-off-by: Jonghyun Park --- diff --git a/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp b/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp index 345a499..59c6a18 100644 --- a/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp +++ b/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp @@ -2,6 +2,42 @@ #include +namespace +{ + +/** + * @breif Infer (positive) channel axis from channel axis specifier (which may be negative) + * + * NOTE This implementation SHOULD be aligned with CanonicalAxisIndex method in ::caffe::Blob + */ +class CanonicalChannelAxis +{ +public: + CanonicalChannelAxis(int32_t axis) : _axis{axis} + { + // DO NOTHING + } + +public: + uint32_t eval(const nncc::core::ADT::tensor::Shape &ifm) const + { + if (_axis > 0) + { + return static_cast(_axis); + } + + assert(ifm.rank() >= static_cast(-_axis)); + return static_cast(ifm.rank() + _axis); + } + +private: + int32_t _axis; +}; + +CanonicalChannelAxis canonical_channel_axis(int32_t axis) { return CanonicalChannelAxis{axis}; } + +} // namespace + ConvolutionSpec::ConvolutionSpec(const ::caffe::ConvolutionParameter ¶m) : _param(param), _num_output{0} { @@ -22,15 +58,7 @@ ConvolutionSpec::ConvolutionSpec(const ::caffe::ConvolutionParameter ¶m) uint32_t ConvolutionSpec::channel_axis(void) const { - const int32_t axis = _param.axis(); - - if (axis > 0) - { - return static_cast(axis); - } - - assert(ifm_rank() >= static_cast(-axis)); - return ifm_rank() + axis; + return canonical_channel_axis(_param.axis()).eval(ifm_shape()); } uint32_t ConvolutionSpec::stride(uint32_t spatial_axis) const