Unknown,
Tensor,
Feature,
+ Filter,
/* ... */
};
#include "loco/IR/FeatureAxis.h"
#include "loco/IR/FeatureCodec.h"
+#include "loco/IR/FilterAxis.h"
#include <map>
Permutation<Domain::Feature> _perm;
};
+/**
+ * @brief Mapping between Filter/Tensor Axis
+ */
+template <> class Permutation<Domain::Filter>
+{
+public:
+ Permutation() = default;
+
+public:
+ /**
+ * @brief Return whether a given filter axis has a corresponding tensor axis
+ *
+ * This method does not validate the corresponding value.
+ */
+ bool mapped(const FilterAxis &axis_f) const;
+
+ /**
+ * @brief Get the tensor axis corresponding to a given filter axis
+ *
+ * This method works correctly only for mapped filter axes.
+ */
+ const TensorAxis &axis(const FilterAxis &axis_f) const;
+
+ /**
+ * @brief Set the tensor axis corresponding to a given filter axis
+ */
+ TensorAxis &axis(const FilterAxis &axis_f);
+
+ TensorAxis operator[](const FilterAxis &axis_f) const { return axis(axis_f); }
+ TensorAxis &operator[](const FilterAxis &axis_f) { return axis(axis_f); }
+
+private:
+ std::map<FilterAxis, TensorAxis> _map;
+};
+
} // namespace loco
#endif // __LOCO_IR_PERMUTING_CODEC_H__
bool PermutingDecoder<Domain::Feature>::valid(void) const { return ::valid(_perm); }
} // namespace loco
+
+/**
+ * Filter Domain
+ */
+namespace
+{
+
+using loco::FilterAxis;
+
+inline bool valid(const FilterAxis &axis)
+{
+ switch (axis)
+ {
+ case FilterAxis::Count:
+ return true;
+ case FilterAxis::Depth:
+ return true;
+ case FilterAxis::Height:
+ return true;
+ case FilterAxis::Width:
+ return true;
+ default:
+ break;
+ }
+
+ return false;
+}
+
+inline bool valid(const loco::Permutation<loco::Domain::Filter> &perm)
+{
+ auto check = [&perm](FilterAxis axis_f) {
+ if (!perm.mapped(axis_f))
+ return false;
+ return perm.axis(axis_f) < 4;
+ };
+
+ if (!check(FilterAxis::Count))
+ return false;
+ if (!check(FilterAxis::Depth))
+ return false;
+ if (!check(FilterAxis::Height))
+ return false;
+ if (!check(FilterAxis::Width))
+ return false;
+
+ // Check whether tensor axes are all distinct
+ std::set<loco::TensorAxis> values;
+
+ values.insert(perm[FilterAxis::Count]);
+ values.insert(perm[FilterAxis::Depth]);
+ values.insert(perm[FilterAxis::Height]);
+ values.insert(perm[FilterAxis::Width]);
+
+ return values.size() == 4;
+}
+
+} // namespace
+
+namespace loco
+{
+
+//
+// Permutation
+//
+bool Permutation<Domain::Filter>::mapped(const FilterAxis &axis_f) const
+{
+ assert(valid(axis_f) && "invalid filter axis");
+ return _map.find(axis_f) != _map.end();
+}
+
+const uint32_t &Permutation<Domain::Filter>::axis(const FilterAxis &axis_f) const
+{
+ assert(valid(axis_f) && "invalid filter axis");
+ assert(mapped(axis_f) && "unmapped filter axis");
+ return _map.at(axis_f);
+}
+
+uint32_t &Permutation<Domain::Filter>::axis(const FilterAxis &axis_f)
+{
+ assert(valid(axis_f) && "invalid filter axis");
+ return _map[axis_f];
+}
+
+} // namespace loco
ASSERT_EQ(perm[FeatureAxis::Width], 8);
}
+TEST(PemutationTest, filter)
+{
+ Permutation<Domain::Filter> perm;
+
+ // All values are invalid at the beginning
+ ASSERT_FALSE(perm.mapped(FilterAxis::Count));
+ ASSERT_FALSE(perm.mapped(FilterAxis::Depth));
+ ASSERT_FALSE(perm.mapped(FilterAxis::Height));
+ ASSERT_FALSE(perm.mapped(FilterAxis::Width));
+
+ // Update mapping
+ perm[FilterAxis::Count] = 5;
+ perm[FilterAxis::Depth] = 6;
+ perm[FilterAxis::Height] = 7;
+ perm[FilterAxis::Width] = 8;
+
+ // Now perm has a mapping for all the axes
+ ASSERT_TRUE(perm.mapped(FilterAxis::Count));
+ ASSERT_TRUE(perm.mapped(FilterAxis::Depth));
+ ASSERT_TRUE(perm.mapped(FilterAxis::Height));
+ ASSERT_TRUE(perm.mapped(FilterAxis::Width));
+
+ // Check the value
+ ASSERT_EQ(perm[FilterAxis::Count], 5);
+ ASSERT_EQ(perm[FilterAxis::Depth], 6);
+ ASSERT_EQ(perm[FilterAxis::Height], 7);
+ ASSERT_EQ(perm[FilterAxis::Width], 8);
+}
+
TEST(PermutingEncoderTest, feature)
{
PermutingEncoder<Domain::Feature> enc;