From 37003c2f9a1fdc1ffe58f36175cc3022f2e85da0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 31 Aug 2018 09:56:50 +0900 Subject: [PATCH] [enco] Extract CanonicalChannelAxis (#1259) This commit extracts CanonicalChannelAxis class from ConvolutionSpec to make it reusable from outside. Signed-off-by: Jonghyun Park --- .../enco/frontend/caffe/src/ConvolutionSpec.cpp | 46 +++++++++++++++++----- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp b/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp index 345a499..59c6a18 100644 --- a/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp +++ b/contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp @@ -2,6 +2,42 @@ #include +namespace +{ + +/** + * @breif Infer (positive) channel axis from channel axis specifier (which may be negative) + * + * NOTE This implementation SHOULD be aligned with CanonicalAxisIndex method in ::caffe::Blob + */ +class CanonicalChannelAxis +{ +public: + CanonicalChannelAxis(int32_t axis) : _axis{axis} + { + // DO NOTHING + } + +public: + uint32_t eval(const nncc::core::ADT::tensor::Shape &ifm) const + { + if (_axis > 0) + { + return static_cast(_axis); + } + + assert(ifm.rank() >= static_cast(-_axis)); + return static_cast(ifm.rank() + _axis); + } + +private: + int32_t _axis; +}; + +CanonicalChannelAxis canonical_channel_axis(int32_t axis) { return CanonicalChannelAxis{axis}; } + +} // namespace + ConvolutionSpec::ConvolutionSpec(const ::caffe::ConvolutionParameter ¶m) : _param(param), _num_output{0} { @@ -22,15 +58,7 @@ ConvolutionSpec::ConvolutionSpec(const ::caffe::ConvolutionParameter ¶m) uint32_t ConvolutionSpec::channel_axis(void) const { - const int32_t axis = _param.axis(); - - if (axis > 0) - { - return static_cast(axis); - } - - assert(ifm_rank() >= static_cast(-axis)); - return ifm_rank() + axis; + return canonical_channel_axis(_param.axis()).eval(ifm_shape()); } uint32_t ConvolutionSpec::stride(uint32_t spatial_axis) const -- 2.7.4