#include "ConvolutionSpec.h"
+#include "ShapeQuery.h"
#include <cassert>
public:
uint32_t eval(const nncc::core::ADT::tensor::Shape &ifm) const
{
- if (_axis > 0)
- {
- return static_cast<uint32_t>(_axis);
- }
-
- assert(ifm.rank() >= static_cast<uint32_t>(-_axis));
- return static_cast<uint32_t>(ifm.rank() + _axis);
+ return query_on(ifm).axis(axis_specifier(_axis));
}
private:
--- /dev/null
+#include "ShapeQuery.h"
+
+#include <cassert>
+
+//
+// AxisSpecifier
+//
+AxisSpecifier axis_specifier(int32_t value) { return AxisSpecifier{value}; }
+
+//
+// ShapeQuery
+//
+uint32_t ShapeQuery::axis(const AxisSpecifier &specifier) const
+{
+ if (specifier.value() > 0)
+ {
+ return static_cast<uint32_t>(specifier.value());
+ }
+
+ assert(_shape->rank() >= static_cast<uint32_t>(-specifier.value()));
+ return static_cast<uint32_t>(_shape->rank() + specifier.value());
+}
+
+ShapeQuery query_on(const nncc::core::ADT::tensor::Shape &shape) { return ShapeQuery{&shape}; }
--- /dev/null
+#ifndef __SHAPE_QUERY_H__
+#define __SHAPE_QUERY_H__
+
+#include <nncc/core/ADT/tensor/Shape.h>
+
+/**
+ * @brief A wrapper class for an integer number that specifies axis
+ *
+ * Several Caffe layers includes 'axis' parameter (which may be negative) which specifies
+ * some axis required for operation.
+ *
+ * Here are several examples:
+ * - Convolution layer uses 'axis' parameter to specify "channel" axis
+ * (http://caffe.berkeleyvision.org/tutorial/layers/convolution.html)
+ * - Concat layer uses 'axis' parameter to specify axis to be concatenated
+ * (http://caffe.berkeleyvision.org/tutorial/layers/concat.html)
+ *
+ * AxisSpecifier class is introduced to distinguish this 'axis' parameter from other integers
+ * (to prevent possible mistake).
+ */
+class AxisSpecifier
+{
+public:
+ explicit AxisSpecifier(int32_t value) : _value{value}
+ {
+ // DO NOTHING
+ }
+
+public:
+ int32_t value(void) const { return _value; }
+
+private:
+ int32_t _value = 1;
+};
+
+AxisSpecifier axis_specifier(int32_t value);
+
+/**
+ * @brief A wrapper class that allows additional queries over tensor shape.
+ */
+class ShapeQuery
+{
+public:
+ explicit ShapeQuery(const nncc::core::ADT::tensor::Shape *shape) : _shape{shape}
+ {
+ // DO NOTHING
+ }
+
+public:
+ // @brief Return the dimension number (axis) specified by a given axis specifier
+ uint32_t axis(const AxisSpecifier &) const;
+
+private:
+ const nncc::core::ADT::tensor::Shape *_shape;
+};
+
+ShapeQuery query_on(const nncc::core::ADT::tensor::Shape &);
+
+#endif // __SHAPE_QUERY_H__