[exo] rewriting FuseBiasAddPass to use Mixin<Bias> (#8507)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Mon, 28 Oct 2019 06:50:30 +0000 (15:50 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 28 Oct 2019 06:50:30 +0000 (15:50 +0900)
* [exo] FuseBiasAddPass

This code is modified to use Mixin<Bias> instead of TFLConv2D. (This commit is a part of modification.)

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* Former -> FormerT

* Fix get_former()

compiler/exo/src/Pass/FuseBiasAddPass.cpp

index db4aa1f..1668075 100644 (file)
 
 #include <set>
 
+/*
+  Note: Terms for variables in this implementation is as follows:
+
+      ex) subgraph handled:    TFLConv2D -------- TFLAdd
+                        (or TFLDepthwiseConv2D)  (or TFLSub)
+                                    |                 |
+                                   \|/               \|/
+            variable name :     former            latter
+                Type      :     FormerT           LatterT
+                    (shortened name from Mixin)  (template type)
+*/
 namespace
 {
 
+using FormerT = locoex::TFLNodeMixin<locoex::TFLNodeTrait::Bias>;
+
 locoex::TFLConst *get_const(loco::Node *x, loco::Node *y)
 {
   if (auto const_node = dynamic_cast<locoex::TFLConst *>(x))
@@ -41,6 +54,17 @@ locoex::TFLConst *get_const(loco::Node *x, loco::Node *y)
   return nullptr;
 }
 
+FormerT *get_former(loco::Node *x, loco::Node *y)
+{
+  if (auto node = dynamic_cast<FormerT *>(x))
+    return node;
+  else if (auto node = dynamic_cast<FormerT *>(y))
+    return node;
+
+  return nullptr;
+}
+
+// TODO replace this with get_former
 locoex::TFLConv2D *get_conv2d(loco::Node *x, loco::Node *y)
 {
   if (auto conv2d_node = dynamic_cast<locoex::TFLConv2D *>(x))
@@ -110,6 +134,7 @@ template <typename TFLType> float calc(float, float);
 template <> float calc<locoex::TFLAdd>(float x, float y) { return x + y; }
 template <> float calc<locoex::TFLSub>(float x, float y) { return x - y; }
 
+// TODO rewrite this by using FormerT and LatterT (Remove Conv2D dependency)
 // TFLType is either TFLAdd or TFLSub
 template <typename TFLType> class Fuser
 {
@@ -211,6 +236,7 @@ template <typename TFLType> void Fuser<TFLType>::fuse(void)
   _fusable_node->y(nullptr);
 }
 
+// TODO rewrite this by using FormerT and LatterT (Remove Conv2D dependency)
 struct Collector final : public locoex::TFLNodeMutableVisitor<void>
 {
   void setCandidate(locoex::TFLNode *node, loco::Node *x, loco::Node *y)
@@ -277,15 +303,19 @@ struct Collector final : public locoex::TFLNodeMutableVisitor<void>
 
 struct Performer final : public locoex::TFLNodeMutableVisitor<void>
 {
-  void visit(locoex::TFLAdd *node) final
+  void visit(locoex::TFLAdd *latter) final
   {
-    Fuser<locoex::TFLAdd> fuser(node);
+    assert(get_former(latter->x(), latter->y()));
+
+    Fuser<locoex::TFLAdd> fuser(latter);
     fuser.fuse();
   }
 
-  void visit(locoex::TFLSub *node) final
+  void visit(locoex::TFLSub *latter) final
   {
-    Fuser<locoex::TFLSub> fuser(node);
+    assert(get_former(latter->x(), latter->y()));
+
+    Fuser<locoex::TFLSub> fuser(latter);
     fuser.fuse();
   }