From b1b7754ad93afbd32dc07ff5a5c85be816e9109c Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 19 Aug 2019 13:17:57 +0900 Subject: [PATCH] [loco] Eq operato for NodeShape (#6661) This will introduce Eq(==) operator for NodeShape Signed-off-by: SaeHie Park --- compiler/loco/include/loco/IR/NodeShape.h | 2 + compiler/loco/src/IR/NodeShape.cpp | 68 +++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/compiler/loco/include/loco/IR/NodeShape.h b/compiler/loco/include/loco/IR/NodeShape.h index fdc1321..b9589e3 100644 --- a/compiler/loco/include/loco/IR/NodeShape.h +++ b/compiler/loco/include/loco/IR/NodeShape.h @@ -60,6 +60,8 @@ private: std::vector _dims; }; +bool operator==(const NodeShape &lhs, const NodeShape &rhs); + } // namespace loco #endif // __LOCO_IR_NODE_SHAPE_H__ diff --git a/compiler/loco/src/IR/NodeShape.cpp b/compiler/loco/src/IR/NodeShape.cpp index d253b09..b4c84a9 100644 --- a/compiler/loco/src/IR/NodeShape.cpp +++ b/compiler/loco/src/IR/NodeShape.cpp @@ -17,6 +17,7 @@ #include "loco/IR/NodeShape.h" #include +#include // // BiasShape Support @@ -177,3 +178,70 @@ template <> TensorShape NodeShape::as(void) const } } // namespace loco + +namespace loco +{ + +bool operator==(const NodeShape &lhs, const NodeShape &rhs) +{ + if (lhs.domain() != rhs.domain()) + return false; + + switch (lhs.domain()) + { + case loco::Domain::Tensor: + { + auto lhs_t = lhs.as(); + auto rhs_t = rhs.as(); + if (lhs_t.rank() != rhs_t.rank()) + return false; + for (uint32_t axis = 0; axis < lhs_t.rank(); ++axis) + { + if (!(lhs_t.dim(axis) == rhs_t.dim(axis))) + return false; + } + return true; + } + + case loco::Domain::Feature: + { + auto lhs_f = lhs.as(); + auto rhs_f = rhs.as(); + + return (lhs_f.count() == rhs_f.count() && lhs_f.depth() == rhs_f.depth() && + lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width()); + } + + case loco::Domain::Filter: + { + auto lhs_f = lhs.as(); + auto rhs_f = rhs.as(); + + return (lhs_f.count() == rhs_f.count() && lhs_f.depth() == rhs_f.depth() && + lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width()); + } + + case loco::Domain::DepthwiseFilter: + { + auto lhs_f = lhs.as(); + auto rhs_f = rhs.as(); + + return (lhs_f.multiplier() == rhs_f.multiplier() && lhs_f.depth() == rhs_f.depth() && + lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width()); + } + + case loco::Domain::Bias: + { + auto lhs_f = lhs.as(); + auto rhs_f = rhs.as(); + + return (lhs_f.length() == rhs_f.length()); + } + + default: + throw std::runtime_error("Not supported domain for NodeShape equality"); + } + return false; +} + +} // namespace loco -- 2.7.4