#include <cassert>
+namespace
+{
+
+/**
+ * @breif Infer (positive) channel axis from channel axis specifier (which may be negative)
+ *
+ * NOTE This implementation SHOULD be aligned with CanonicalAxisIndex method in ::caffe::Blob
+ */
+class CanonicalChannelAxis
+{
+public:
+ CanonicalChannelAxis(int32_t axis) : _axis{axis}
+ {
+ // DO NOTHING
+ }
+
+public:
+ uint32_t eval(const nncc::core::ADT::tensor::Shape &ifm) const
+ {
+ if (_axis > 0)
+ {
+ return static_cast<uint32_t>(_axis);
+ }
+
+ assert(ifm.rank() >= static_cast<uint32_t>(-_axis));
+ return static_cast<uint32_t>(ifm.rank() + _axis);
+ }
+
+private:
+ int32_t _axis;
+};
+
+CanonicalChannelAxis canonical_channel_axis(int32_t axis) { return CanonicalChannelAxis{axis}; }
+
+} // namespace
+
ConvolutionSpec::ConvolutionSpec(const ::caffe::ConvolutionParameter ¶m)
: _param(param), _num_output{0}
{
uint32_t ConvolutionSpec::channel_axis(void) const
{
- const int32_t axis = _param.axis();
-
- if (axis > 0)
- {
- return static_cast<uint32_t>(axis);
- }
-
- assert(ifm_rank() >= static_cast<uint32_t>(-axis));
- return ifm_rank() + axis;
+ return canonical_channel_axis(_param.axis()).eval(ifm_shape());
}
uint32_t ConvolutionSpec::stride(uint32_t spatial_axis) const