[enco] Introduce ShapeQuery class (#1404)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 7 Sep 2018 04:31:36 +0000 (13:31 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 7 Sep 2018 04:31:36 +0000 (13:31 +0900)
* [enco] Introduce ShapeQuery class

This commit generalizes CanonicalChannelAxis implementation as
ShapeQuery to reuse this logic for other layers.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Fix typos in comment

contrib/enco/frontend/caffe/src/ConvolutionSpec.cpp
contrib/enco/frontend/caffe/src/ShapeQuery.cpp [new file with mode: 0644]
contrib/enco/frontend/caffe/src/ShapeQuery.h [new file with mode: 0644]

index c2b12be..0088ce3 100644 (file)
@@ -1,4 +1,5 @@
 #include "ConvolutionSpec.h"
+#include "ShapeQuery.h"
 
 #include <cassert>
 
@@ -21,13 +22,7 @@ public:
 public:
   uint32_t eval(const nncc::core::ADT::tensor::Shape &ifm) const
   {
-    if (_axis > 0)
-    {
-      return static_cast<uint32_t>(_axis);
-    }
-
-    assert(ifm.rank() >= static_cast<uint32_t>(-_axis));
-    return static_cast<uint32_t>(ifm.rank() + _axis);
+    return query_on(ifm).axis(axis_specifier(_axis));
   }
 
 private:
diff --git a/contrib/enco/frontend/caffe/src/ShapeQuery.cpp b/contrib/enco/frontend/caffe/src/ShapeQuery.cpp
new file mode 100644 (file)
index 0000000..bef0f7b
--- /dev/null
@@ -0,0 +1,24 @@
+#include "ShapeQuery.h"
+
+#include <cassert>
+
+//
+// AxisSpecifier
+//
+AxisSpecifier axis_specifier(int32_t value) { return AxisSpecifier{value}; }
+
+//
+// ShapeQuery
+//
+uint32_t ShapeQuery::axis(const AxisSpecifier &specifier) const
+{
+  if (specifier.value() > 0)
+  {
+    return static_cast<uint32_t>(specifier.value());
+  }
+
+  assert(_shape->rank() >= static_cast<uint32_t>(-specifier.value()));
+  return static_cast<uint32_t>(_shape->rank() + specifier.value());
+}
+
+ShapeQuery query_on(const nncc::core::ADT::tensor::Shape &shape) { return ShapeQuery{&shape}; }
diff --git a/contrib/enco/frontend/caffe/src/ShapeQuery.h b/contrib/enco/frontend/caffe/src/ShapeQuery.h
new file mode 100644 (file)
index 0000000..2184863
--- /dev/null
@@ -0,0 +1,59 @@
+#ifndef __SHAPE_QUERY_H__
+#define __SHAPE_QUERY_H__
+
+#include <nncc/core/ADT/tensor/Shape.h>
+
+/**
+ * @brief A wrapper class for an integer number that specifies axis
+ *
+ * Several Caffe layers includes 'axis' parameter (which may be negative) which specifies
+ * some axis required for operation.
+ *
+ * Here are several examples:
+ * - Convolution layer uses 'axis' parameter to specify "channel" axis
+ *   (http://caffe.berkeleyvision.org/tutorial/layers/convolution.html)
+ * - Concat layer uses 'axis' parameter to specify axis to be concatenated
+ *   (http://caffe.berkeleyvision.org/tutorial/layers/concat.html)
+ *
+ * AxisSpecifier class is introduced to distinguish this 'axis' parameter from other integers
+ * (to prevent possible mistake).
+ */
+class AxisSpecifier
+{
+public:
+  explicit AxisSpecifier(int32_t value) : _value{value}
+  {
+    // DO NOTHING
+  }
+
+public:
+  int32_t value(void) const { return _value; }
+
+private:
+  int32_t _value = 1;
+};
+
+AxisSpecifier axis_specifier(int32_t value);
+
+/**
+ * @brief A wrapper class that allows additional queries over tensor shape.
+ */
+class ShapeQuery
+{
+public:
+  explicit ShapeQuery(const nncc::core::ADT::tensor::Shape *shape) : _shape{shape}
+  {
+    // DO NOTHING
+  }
+
+public:
+  // @brief Return the dimension number (axis) specified by a given axis specifier
+  uint32_t axis(const AxisSpecifier &) const;
+
+private:
+  const nncc::core::ADT::tensor::Shape *_shape;
+};
+
+ShapeQuery query_on(const nncc::core::ADT::tensor::Shape &);
+
+#endif // __SHAPE_QUERY_H__