From 2bf40b2f4bc2b69a354af7e3b941cc8357eecc49 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 25 Nov 2019 11:00:41 +0900 Subject: [PATCH] [exo] Alternative helper for commutative args (#9140) * [exo] Alternative helper for commutative args This commit introduces alternative helper to extract and set commutative argument of the node. Current: `bool ok = commutative_args_of(node).is(arg1, arg2);` Suggested alternative: `bool ok = set(&arg1, &arg2).with_commutative_args_of(node);` Signed-off-by: Cheongyo Bahk * Fill! --- compiler/exo/src/Pass/FuseInstanceNormPass.cpp | 60 ++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/compiler/exo/src/Pass/FuseInstanceNormPass.cpp b/compiler/exo/src/Pass/FuseInstanceNormPass.cpp index 0d3428c..7b2fa4c 100644 --- a/compiler/exo/src/Pass/FuseInstanceNormPass.cpp +++ b/compiler/exo/src/Pass/FuseInstanceNormPass.cpp @@ -118,6 +118,66 @@ bool CommutativeArgsGetter::is(ARG_TYPE_1 *&arg_1, ARG_TYPE_2 *&arg_2 return false; } +/** + * Alternative approach + * + * bool ok = fill(&arg1, &arg2).with_commutative_args_of(node); + */ + +template class NodeFiller final +{ +public: + NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2) + { + // DO NOTHING + } + + template bool with_commutative_args_of(const COMM_NODE *node); + +private: + ARG_TYPE_1 **_arg_1; + ARG_TYPE_2 **_arg_2; +}; + +template +inline NodeFiller fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) +{ + return NodeFiller{arg_1, arg_2}; +} + +template +template +bool NodeFiller::with_commutative_args_of(const COMM_NODE *node) +{ + // Case 1) X == ARG_TYPE_1 / Y == ARG_TYPE_2 + { + auto x = dynamic_cast(node->x()); + auto y = dynamic_cast(node->y()); + + if (x && y) + { + *_arg_1 = x; + *_arg_2 = y; + return true; + } + } + + // Case 2) X == ARG_TYPE_2 / Y == ARG_TYPE_1 + { + auto x = dynamic_cast(node->x()); + auto y = dynamic_cast(node->y()); + + if (x && y) + { + *_arg_1 = y; + *_arg_2 = x; + return true; + } + } + + return false; +} + } // namespace // Helper to check detail -- 2.7.4