[exo] Alternative helper for commutative args (#9140)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Mon, 25 Nov 2019 02:00:41 +0000 (11:00 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 25 Nov 2019 02:00:41 +0000 (11:00 +0900)
* [exo] Alternative helper for commutative args

This commit introduces alternative helper to extract and set commutative
argument of the node.

Current:
`bool ok = commutative_args_of(node).is(arg1, arg2);`

Suggested alternative:
`bool ok = set(&arg1, &arg2).with_commutative_args_of(node);`

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
* Fill!

compiler/exo/src/Pass/FuseInstanceNormPass.cpp

index 0d3428c..7b2fa4c 100644 (file)
@@ -118,6 +118,66 @@ bool CommutativeArgsGetter<COMM_NODE>::is(ARG_TYPE_1 *&arg_1, ARG_TYPE_2 *&arg_2
   return false;
 }
 
+/**
+ * Alternative approach
+ *
+ *         bool ok = fill(&arg1, &arg2).with_commutative_args_of(node);
+ */
+
+template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final
+{
+public:
+  NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2)
+  {
+    // DO NOTHING
+  }
+
+  template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node);
+
+private:
+  ARG_TYPE_1 **_arg_1;
+  ARG_TYPE_2 **_arg_2;
+};
+
+template <class ARG_TYPE_1, class ARG_TYPE_2>
+inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
+{
+  return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2};
+}
+
+template <class ARG_TYPE_1, class ARG_TYPE_2>
+template <class COMM_NODE>
+bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node)
+{
+  // Case 1) X == ARG_TYPE_1 / Y == ARG_TYPE_2
+  {
+    auto x = dynamic_cast<ARG_TYPE_1 *>(node->x());
+    auto y = dynamic_cast<ARG_TYPE_2 *>(node->y());
+
+    if (x && y)
+    {
+      *_arg_1 = x;
+      *_arg_2 = y;
+      return true;
+    }
+  }
+
+  // Case 2) X == ARG_TYPE_2 / Y == ARG_TYPE_1
+  {
+    auto x = dynamic_cast<ARG_TYPE_2 *>(node->x());
+    auto y = dynamic_cast<ARG_TYPE_1 *>(node->y());
+
+    if (x && y)
+    {
+      *_arg_1 = y;
+      *_arg_2 = x;
+      return true;
+    }
+  }
+
+  return false;
+}
+
 } // namespace
 
 // Helper to check detail