[loco] Eq operato for NodeShape (#6661)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 19 Aug 2019 04:17:57 +0000 (13:17 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 19 Aug 2019 04:17:57 +0000 (13:17 +0900)
This will introduce Eq(==) operator for NodeShape

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/loco/include/loco/IR/NodeShape.h
compiler/loco/src/IR/NodeShape.cpp

index fdc1321..b9589e3 100644 (file)
@@ -60,6 +60,8 @@ private:
   std::vector<Dimension> _dims;
 };
 
+bool operator==(const NodeShape &lhs, const NodeShape &rhs);
+
 } // namespace loco
 
 #endif // __LOCO_IR_NODE_SHAPE_H__
index d253b09..b4c84a9 100644 (file)
@@ -17,6 +17,7 @@
 #include "loco/IR/NodeShape.h"
 
 #include <cassert>
+#include <stdexcept>
 
 //
 // BiasShape Support
@@ -177,3 +178,70 @@ template <> TensorShape NodeShape::as(void) const
 }
 
 } // namespace loco
+
+namespace loco
+{
+
+bool operator==(const NodeShape &lhs, const NodeShape &rhs)
+{
+  if (lhs.domain() != rhs.domain())
+    return false;
+
+  switch (lhs.domain())
+  {
+    case loco::Domain::Tensor:
+    {
+      auto lhs_t = lhs.as<TensorShape>();
+      auto rhs_t = rhs.as<TensorShape>();
+      if (lhs_t.rank() != rhs_t.rank())
+        return false;
+      for (uint32_t axis = 0; axis < lhs_t.rank(); ++axis)
+      {
+        if (!(lhs_t.dim(axis) == rhs_t.dim(axis)))
+          return false;
+      }
+      return true;
+    }
+
+    case loco::Domain::Feature:
+    {
+      auto lhs_f = lhs.as<FeatureShape>();
+      auto rhs_f = rhs.as<FeatureShape>();
+
+      return (lhs_f.count() == rhs_f.count() && lhs_f.depth() == rhs_f.depth() &&
+              lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width());
+    }
+
+    case loco::Domain::Filter:
+    {
+      auto lhs_f = lhs.as<FilterShape>();
+      auto rhs_f = rhs.as<FilterShape>();
+
+      return (lhs_f.count() == rhs_f.count() && lhs_f.depth() == rhs_f.depth() &&
+              lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width());
+    }
+
+    case loco::Domain::DepthwiseFilter:
+    {
+      auto lhs_f = lhs.as<DepthwiseFilterShape>();
+      auto rhs_f = rhs.as<DepthwiseFilterShape>();
+
+      return (lhs_f.multiplier() == rhs_f.multiplier() && lhs_f.depth() == rhs_f.depth() &&
+              lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width());
+    }
+
+    case loco::Domain::Bias:
+    {
+      auto lhs_f = lhs.as<BiasShape>();
+      auto rhs_f = rhs.as<BiasShape>();
+
+      return (lhs_f.length() == rhs_f.length());
+    }
+
+    default:
+      throw std::runtime_error("Not supported domain for NodeShape equality");
+  }
+  return false;
+}
+
+} // namespace loco