[enco] Extract CanonicalChannelAxis (#1259)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 31 Aug 2018 00:56:50 +0000 (09:56 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 31 Aug 2018 00:56:50 +0000 (09:56 +0900)
This commit extracts CanonicalChannelAxis class from ConvolutionSpec to
make it reusable from outside.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp

index 345a499..59c6a18 100644 (file)
@@ -2,6 +2,42 @@
 
 #include <cassert>
 
+namespace
+{
+
+/**
+ * @breif Infer (positive) channel axis from channel axis specifier (which may be negative)
+ *
+ * NOTE This implementation SHOULD be aligned with CanonicalAxisIndex method in ::caffe::Blob
+ */
+class CanonicalChannelAxis
+{
+public:
+  CanonicalChannelAxis(int32_t axis) : _axis{axis}
+  {
+    // DO NOTHING
+  }
+
+public:
+  uint32_t eval(const nncc::core::ADT::tensor::Shape &ifm) const
+  {
+    if (_axis > 0)
+    {
+      return static_cast<uint32_t>(_axis);
+    }
+
+    assert(ifm.rank() >= static_cast<uint32_t>(-_axis));
+    return static_cast<uint32_t>(ifm.rank() + _axis);
+  }
+
+private:
+  int32_t _axis;
+};
+
+CanonicalChannelAxis canonical_channel_axis(int32_t axis) { return CanonicalChannelAxis{axis}; }
+
+} // namespace
+
 ConvolutionSpec::ConvolutionSpec(const ::caffe::ConvolutionParameter &param)
     : _param(param), _num_output{0}
 {
@@ -22,15 +58,7 @@ ConvolutionSpec::ConvolutionSpec(const ::caffe::ConvolutionParameter &param)
 
 uint32_t ConvolutionSpec::channel_axis(void) const
 {
-  const int32_t axis = _param.axis();
-
-  if (axis > 0)
-  {
-    return static_cast<uint32_t>(axis);
-  }
-
-  assert(ifm_rank() >= static_cast<uint32_t>(-axis));
-  return ifm_rank() + axis;
+  return canonical_channel_axis(_param.axis()).eval(ifm_shape());
 }
 
 uint32_t ConvolutionSpec::stride(uint32_t spatial_axis) const