[exo] adds loop to fuse the nodes for FuseReluPass (#8266)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Thu, 17 Oct 2019 07:23:06 +0000 (16:23 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 17 Oct 2019 07:23:06 +0000 (16:23 +0900)
This completes FuseReluPass.cpp by adding the loop that fuses the selected nodes.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo/src/Pass/FuseReluPass.cpp

index b07e962..71f74c9 100644 (file)
@@ -65,6 +65,37 @@ struct Collector final : public locoex::TFLNodeMutableVisitor<void>
   std::set<locoex::TFLNode *> candidates;
 };
 
+// Records `f` as the fused activation function of `node`.
+//
+// `node` is expected to be a TFL node that carries the FusedActFunc trait;
+// this is checked at run time with a dynamic_cast.
+//
+// NOTE(review): the failure path is only `assert(false)`. When the file is
+// built with NDEBUG the assert expands to nothing, so a node without the
+// trait is then silently left untouched. Presumably the Collector only
+// selects fusable producers -- confirm against the candidate selection.
+void set_activation_fusion(loco::Node *node, locoex::FusedActFunc f)
+{
+  using namespace locoex;
+
+  // Cross-cast to the mixin that exposes fusedActivationFunction(); the
+  // result is null when `node` does not provide it.
+  if (auto fusable_node = dynamic_cast<TFLNodeMixin<TFLNodeTrait::FusedActFunc> *>(node))
+    fusable_node->fusedActivationFunction(f);
+  else
+    assert(false);
+}
+
+// Visitor that performs the actual fusion for one TFLRelu / TFLRelu6 node.
+//
+// For each visited ReLU-like node it
+//   1. sets the matching fused activation on the node feeding its
+//      `features` input,
+//   2. hands the ReLU node to loco::replace(...).with(input) -- presumably
+//      this redirects every user of the ReLU to the input node; confirm
+//      against the loco API,
+//   3. resets the ReLU's own `features` input to nullptr.
+//
+// Any other TFL node type reaching this visitor is treated as a programming
+// error (see the fallback overload at the bottom).
+struct Performer final : public locoex::TFLNodeMutableVisitor<void>
+{
+  // Fuses `the_relu` into its input as FusedActFunc::RELU.
+  void visit(locoex::TFLRelu *the_relu) final
+  {
+    set_activation_fusion(the_relu->features(), locoex::FusedActFunc::RELU);
+
+    loco::replace(the_relu).with(the_relu->features());
+    // Clear the input link after the replacement; presumably this leaves the
+    // ReLU node disconnected so it can be cleaned up later -- TODO confirm.
+    the_relu->features(nullptr);
+  }
+
+  // Fuses `the_relu6` into its input as FusedActFunc::RELU6.
+  void visit(locoex::TFLRelu6 *the_relu6) final
+  {
+    set_activation_fusion(the_relu6->features(), locoex::FusedActFunc::RELU6);
+
+    loco::replace(the_relu6).with(the_relu6->features());
+    // Same input reset as in the TFLRelu overload above.
+    the_relu6->features(nullptr);
+  }
+
+  // Fallback for every other node type: only asserts, so it does nothing
+  // when asserts are disabled (NDEBUG).
+  void visit(locoex::TFLNode *) final { assert(false && "should not be called"); }
+};
+
 } // namespace
 
 namespace exo
@@ -83,9 +114,14 @@ bool FuseReluPass::run(loco::Graph *g)
     }
   }
 
-  // TODO write code for fusing
+  Performer performer;
+
+  for (auto node : collector.candidates)
+  {
+    node->accept(&performer);
+  }
 
-  return false;
+  return collector.candidates.size() > 0;
 }
 
 } // namespace exo