[loco] Implement FilterAxis Permutation (#3675)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 4 Jun 2019 09:30:51 +0000 (18:30 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 4 Jun 2019 09:30:51 +0000 (18:30 +0900)
* [loco] Implement FilterAxis Permutation

This commit implements Permutation over FilterAxis similarly as that
over FeatureAxis.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Return const reference

* Fix a typo

contrib/loco/include/loco/IR/Domain.h
contrib/loco/include/loco/IR/PermutingCodec.h
contrib/loco/src/IR/PermutingCodec.cpp
contrib/loco/src/IR/PermutingCodec.test.cpp

index c7ee4fe..d4ca518 100644 (file)
@@ -41,6 +41,7 @@ enum class Domain
   Unknown,
   Tensor,
   Feature,
+  Filter,
   /* ... */
 };
 
index c9cf0ae..8250604 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "loco/IR/FeatureAxis.h"
 #include "loco/IR/FeatureCodec.h"
+#include "loco/IR/FilterAxis.h"
 
 #include <map>
 
@@ -110,6 +111,41 @@ private:
   Permutation<Domain::Feature> _perm;
 };
 
+/**
+ * @brief Mapping between Filter/Tensor Axis
+ */
+template <> class Permutation<Domain::Filter>
+{
+public:
+  Permutation() = default;
+
+public:
+  /**
+   * @brief Return whether a given filter axis has a corresponding tensor axis
+   *
+   * This method does not validate the corresponding value.
+   */
+  bool mapped(const FilterAxis &axis_f) const;
+
+  /**
+   * @brief Get the tensor axis corresponding to a given filter axis
+   *
+   * This method works correctly only for mapped filter axes.
+   */
+  const TensorAxis &axis(const FilterAxis &axis_f) const;
+
+  /**
+   * @brief Set the tensor axis corresponding to a given filter axis
+   */
+  TensorAxis &axis(const FilterAxis &axis_f);
+
+  TensorAxis operator[](const FilterAxis &axis_f) const { return axis(axis_f); }
+  TensorAxis &operator[](const FilterAxis &axis_f) { return axis(axis_f); }
+
+private:
+  std::map<FilterAxis, TensorAxis> _map;
+};
+
 } // namespace loco
 
 #endif // __LOCO_IR_PERMUTING_CODEC_H__
index 9afa199..fab91ea 100644 (file)
@@ -173,3 +173,87 @@ FeatureIndex PermutingDecoder<Domain::Feature>::value(const TensorIndex &in) con
 bool PermutingDecoder<Domain::Feature>::valid(void) const { return ::valid(_perm); }
 
 } // namespace loco
+
+/**
+ * Filter Domain
+ */
+namespace
+{
+
+using loco::FilterAxis;
+
+inline bool valid(const FilterAxis &axis)
+{
+  switch (axis)
+  {
+  case FilterAxis::Count:
+    return true;
+  case FilterAxis::Depth:
+    return true;
+  case FilterAxis::Height:
+    return true;
+  case FilterAxis::Width:
+    return true;
+  default:
+    break;
+  }
+
+  return false;
+}
+
+inline bool valid(const loco::Permutation<loco::Domain::Filter> &perm)
+{
+  auto check = [&perm](FilterAxis axis_f) {
+    if (!perm.mapped(axis_f))
+      return false;
+    return perm.axis(axis_f) < 4;
+  };
+
+  if (!check(FilterAxis::Count))
+    return false;
+  if (!check(FilterAxis::Depth))
+    return false;
+  if (!check(FilterAxis::Height))
+    return false;
+  if (!check(FilterAxis::Width))
+    return false;
+
+  // Check whether tensor axes are all distinct
+  std::set<loco::TensorAxis> values;
+
+  values.insert(perm[FilterAxis::Count]);
+  values.insert(perm[FilterAxis::Depth]);
+  values.insert(perm[FilterAxis::Height]);
+  values.insert(perm[FilterAxis::Width]);
+
+  return values.size() == 4;
+}
+
+} // namespace
+
+namespace loco
+{
+
+//
+// Permutation
+//
+bool Permutation<Domain::Filter>::mapped(const FilterAxis &axis_f) const
+{
+  assert(valid(axis_f) && "invalid filter axis");
+  return _map.find(axis_f) != _map.end();
+}
+
+const uint32_t &Permutation<Domain::Filter>::axis(const FilterAxis &axis_f) const
+{
+  assert(valid(axis_f) && "invalid filter axis");
+  assert(mapped(axis_f) && "unmapped filter axis");
+  return _map.at(axis_f);
+}
+
+uint32_t &Permutation<Domain::Filter>::axis(const FilterAxis &axis_f)
+{
+  assert(valid(axis_f) && "invalid filter axis");
+  return _map[axis_f];
+}
+
+} // namespace loco
index 13a8295..3dffdd7 100644 (file)
@@ -49,6 +49,35 @@ TEST(PemutationTest, feature)
   ASSERT_EQ(perm[FeatureAxis::Width], 8);
 }
 
+TEST(PemutationTest, filter)
+{
+  Permutation<Domain::Filter> perm;
+
+  // All values are invalid at the beginning
+  ASSERT_FALSE(perm.mapped(FilterAxis::Count));
+  ASSERT_FALSE(perm.mapped(FilterAxis::Depth));
+  ASSERT_FALSE(perm.mapped(FilterAxis::Height));
+  ASSERT_FALSE(perm.mapped(FilterAxis::Width));
+
+  // Update mapping
+  perm[FilterAxis::Count] = 5;
+  perm[FilterAxis::Depth] = 6;
+  perm[FilterAxis::Height] = 7;
+  perm[FilterAxis::Width] = 8;
+
+  // Now perm has a mapping for all the axes
+  ASSERT_TRUE(perm.mapped(FilterAxis::Count));
+  ASSERT_TRUE(perm.mapped(FilterAxis::Depth));
+  ASSERT_TRUE(perm.mapped(FilterAxis::Height));
+  ASSERT_TRUE(perm.mapped(FilterAxis::Width));
+
+  // Check the value
+  ASSERT_EQ(perm[FilterAxis::Count], 5);
+  ASSERT_EQ(perm[FilterAxis::Depth], 6);
+  ASSERT_EQ(perm[FilterAxis::Height], 7);
+  ASSERT_EQ(perm[FilterAxis::Width], 8);
+}
+
 TEST(PermutingEncoderTest, feature)
 {
   PermutingEncoder<Domain::Feature> enc;