};
/**
+ * @brief Permutation-based Filter-to-Tensor converter
+ */
+template <> class PermutingDecoder<Domain::Filter> final : public FilterDecoder
+{
+public:
+ PermutingDecoder() = default;
+
+public:
+ bool valid(void) const;
+
+public:
+ TensorShape shape(const FilterShape &tensor_shape) const override;
+ FilterIndex value(const TensorIndex &index) const override;
+
+public:
+ const Permutation<Domain::Filter> *perm(void) const { return &_perm; }
+ Permutation<Domain::Filter> *perm(void) { return &_perm; }
+ void perm(const Permutation<Domain::Filter> &p) { _perm = p; }
+
+private:
+ Permutation<Domain::Filter> _perm;
+};
+
+/**
* @brief Mapping between DepthwiseFilter/Tensor Axis
*/
template <> class Permutation<Domain::DepthwiseFilter>
bool PermutingEncoder<Domain::Filter>::valid(void) const { return ::valid(_perm); }
+//
+// Permuting Decoder
+//
+TensorShape PermutingDecoder<Domain::Filter>::shape(const FilterShape &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ TensorShape out;
+
+ out.rank(4);
+ out.dim(_perm[FilterAxis::Count]) = in.count();
+ out.dim(_perm[FilterAxis::Depth]) = in.depth();
+ out.dim(_perm[FilterAxis::Height]) = in.height();
+ out.dim(_perm[FilterAxis::Width]) = in.width();
+
+ return out;
+}
+
+FilterIndex PermutingDecoder<Domain::Filter>::value(const TensorIndex &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ FilterIndex out;
+
+ out.nth() = in.at(_perm[FilterAxis::Count]);
+ out.channel() = in.at(_perm[FilterAxis::Depth]);
+ out.row() = in.at(_perm[FilterAxis::Height]);
+ out.column() = in.at(_perm[FilterAxis::Width]);
+
+ return out;
+}
+
+bool PermutingDecoder<Domain::Filter>::valid(void) const { return ::valid(_perm); }
+
} // namespace loco
/**
EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1);
EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2);
}
+
+TEST(PermutingDecoderTest, filter)
+{
+ PermutingDecoder<Domain::Filter> dec;
+
+ // Decoder is invalid at the beginning
+ ASSERT_FALSE(dec.valid());
+
+ // Set "invalid" mapping
+ dec.perm()->axis(FilterAxis::Count) = 0;
+ dec.perm()->axis(FilterAxis::Depth) = 6;
+ dec.perm()->axis(FilterAxis::Height) = 1;
+ dec.perm()->axis(FilterAxis::Width) = 2;
+
+ // Decoder is still invalid
+ ASSERT_FALSE(dec.valid());
+
+ // Set another "invalid" mapping
+ dec.perm()->axis(FilterAxis::Depth) = 1;
+
+ // Decoder is still invalid
+ ASSERT_FALSE(dec.valid());
+
+ // Set "valid" mapping
+ dec.perm()->axis(FilterAxis::Depth) = 3;
+
+ // Decoder is now valid
+ ASSERT_TRUE(dec.valid());
+
+ // Let's test with a small filter
+ FilterShape filter_shape;
+
+ filter_shape.count() = 10;
+ filter_shape.depth() = 3;
+ filter_shape.height() = 6;
+ filter_shape.width() = 8;
+
+ // Get the tensor shape corresponding to a given image
+ auto tensor_shape = dec.shape(filter_shape);
+
+ ASSERT_EQ(tensor_shape.rank(), 4);
+ ASSERT_EQ(tensor_shape.dim(0), 10); // COUNT
+ ASSERT_EQ(tensor_shape.dim(1), 6); // HEIGHT
+ ASSERT_EQ(tensor_shape.dim(2), 8); // WIDTH
+ ASSERT_EQ(tensor_shape.dim(3), 3); // DEPTH
+
+ // Let's find a source filter index!
+ TensorIndex tensor_index;
+
+ tensor_index.resize(4);
+
+ tensor_index.at(0) = 0; // BATCH(COUNT)
+ tensor_index.at(3) = 1; // CHANNEL(DEPTH)
+ tensor_index.at(1) = 2; // ROW(HEIGHT)
+ tensor_index.at(2) = 3; // COLUMN(WIDTH)
+
+ auto filter_index = dec.value(tensor_index);
+
+ ASSERT_EQ(filter_index.nth(), 0);
+ ASSERT_EQ(filter_index.channel(), 1);
+ ASSERT_EQ(filter_index.row(), 2);
+ ASSERT_EQ(filter_index.column(), 3);
+}