};
/**
+ * @brief Permutation-based DepthwiseFilter-to-Tensor converter
+ */
+template <> class PermutingDecoder<Domain::DepthwiseFilter> final : public DepthwiseFilterDecoder
+{
+public:
+ PermutingDecoder() = default;
+
+public:
+ PermutingDecoder(const Permutation<Domain::DepthwiseFilter> &perm) : _perm{perm}
+ {
+ // DO NOTHING
+ }
+
+public:
+ bool valid(void) const;
+
+public:
+ TensorShape shape(const DepthwiseFilterShape &shape) const override;
+ DepthwiseFilterIndex value(const TensorIndex &index) const override;
+
+public:
+ const Permutation<Domain::DepthwiseFilter> *perm(void) const { return &_perm; }
+ Permutation<Domain::DepthwiseFilter> *perm(void) { return &_perm; }
+ void perm(const Permutation<Domain::DepthwiseFilter> &p) { _perm = p; }
+
+private:
+ Permutation<Domain::DepthwiseFilter> _perm;
+};
+
+/**
* @brief Mapping between Matrix/Tensor Axis
*/
template <> class Permutation<Domain::Matrix>
bool PermutingEncoder<Domain::DepthwiseFilter>::valid(void) const { return ::valid(_perm); }
+//
+// Permuting Decoder
+//
+TensorShape PermutingDecoder<Domain::DepthwiseFilter>::shape(const DepthwiseFilterShape &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ TensorShape out;
+ out.rank(4);
+
+ out.dim(_perm[DepthwiseFilterAxis::Depth]) = in.depth();
+ out.dim(_perm[DepthwiseFilterAxis::Multiplier]) = in.multiplier();
+ out.dim(_perm[DepthwiseFilterAxis::Height]) = in.height();
+ out.dim(_perm[DepthwiseFilterAxis::Width]) = in.width();
+
+ return out;
+}
+
+DepthwiseFilterIndex PermutingDecoder<Domain::DepthwiseFilter>::value(const TensorIndex &in) const
+{
+ assert(valid() && "invalid permutation");
+ assert(in.rank() == 4);
+
+ DepthwiseFilterIndex out;
+
+ out.channel() = in.at(_perm[DepthwiseFilterAxis::Depth]);
+ out.nth() = in.at(_perm[DepthwiseFilterAxis::Multiplier]);
+ out.row() = in.at(_perm[DepthwiseFilterAxis::Height]);
+ out.column() = in.at(_perm[DepthwiseFilterAxis::Width]);
+
+ return out;
+}
+
+bool PermutingDecoder<Domain::DepthwiseFilter>::valid(void) const { return ::valid(_perm); }
+
} // namespace loco
/**
ASSERT_EQ(filter_index.row(), 2);
ASSERT_EQ(filter_index.column(), 3);
}
+
+TEST(PermutingDecoderTest, depthwise_filter)
+{
+ PermutingDecoder<Domain::DepthwiseFilter> dec;
+
+ // Decoder is invalid at the beginning
+ ASSERT_FALSE(dec.valid());
+
+ // Set "invalid" mapping
+ dec.perm()->axis(DepthwiseFilterAxis::Depth) = 0;
+ dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 6;
+ dec.perm()->axis(DepthwiseFilterAxis::Height) = 1;
+ dec.perm()->axis(DepthwiseFilterAxis::Width) = 2;
+
+ // Decoder is still invalid
+ ASSERT_FALSE(dec.valid());
+
+ // Set another "invalid" mapping
+ dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 1;
+
+ // Decoder is still invalid
+ ASSERT_FALSE(dec.valid());
+
+ // Set "valid" mapping
+ dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 3;
+
+ // Decoder is now valid
+ ASSERT_TRUE(dec.valid());
+
+ DepthwiseFilterShape dw_filter_shape;
+
+ dw_filter_shape.depth() = 8;
+ dw_filter_shape.multiplier() = 1;
+ dw_filter_shape.height() = 7;
+ dw_filter_shape.width() = 4;
+
+ // Get the corresponding depthwise filter shape
+ auto tensor_shape = dec.shape(dw_filter_shape);
+
+ ASSERT_EQ(tensor_shape.dim(0).value(), 8);
+ ASSERT_EQ(tensor_shape.dim(1).value(), 7);
+ ASSERT_EQ(tensor_shape.dim(2).value(), 4);
+ ASSERT_EQ(tensor_shape.dim(3).value(), 1);
+
+ // Let's find a source tensor index!
+ TensorIndex tensor_index;
+ tensor_index.resize(4);
+
+ tensor_index.at(0) = 4;
+ tensor_index.at(1) = 2;
+ tensor_index.at(2) = 1;
+ tensor_index.at(3) = 0;
+
+ auto dw_filter_index = dec.value(tensor_index);
+
+ ASSERT_EQ(dw_filter_index.channel(), 4);
+ ASSERT_EQ(dw_filter_index.nth(), 0);
+ ASSERT_EQ(dw_filter_index.row(), 2);
+ ASSERT_EQ(dw_filter_index.column(), 1);
+}