Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / CircleOptimizer.cpp
index 5e1613a..b011581 100644 (file)
@@ -39,6 +39,7 @@
 #include "luci/Pass/FuseMeanWithMeanPass.h"
 #include "luci/Pass/FusePreActivationBatchNormPass.h"
 #include "luci/Pass/FusePReluPass.h"
+#include "luci/Pass/FuseGeluPass.h"
 #include "luci/Pass/FuseTransposeWithMeanPass.h"
 #include "luci/Pass/MakeBatchNormGammaPositivePass.h"
 #include "luci/Pass/RemoveDuplicateConstPass.h"
@@ -70,6 +71,7 @@
 #include "luci/Pass/SubstituteTransposeToReshapePass.h"
 #include "luci/Pass/TransformMinMaxToRelu6Pass.h"
 #include "luci/Pass/TransformMinReluToRelu6Pass.h"
+#include "luci/Pass/DecomposeHardSwishPass.h"
 #include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h"
 // TODO add more passes
 
@@ -137,7 +139,8 @@ bool OptimizeOptionsImpl::query(Algorithm algo)
 }
 
 // TODO Make a struct for args
-void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc)
+void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc,
+                          bool fuse_gelu)
 {
   logo::Phase phase;
 
@@ -160,6 +163,12 @@ void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_out
   if (fuse_fc)
     phase.emplace_back(std::make_unique<luci::FuseAddWithFullyConnectedPass>());
 
+  // Fuse decomposed ops to Gelu Op
+  // Why here? ConverNCHWToNHWCPass inserts additional Ops, so it is better to fuse
+  // Gelu in advance.
+  if (fuse_gelu)
+    phase.emplace_back(std::make_unique<luci::FuseGeluPass>());
+
   phase.emplace_back(
     std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
 
@@ -216,8 +225,9 @@ void CircleOptimizer::optimize(loco::Graph *g) const
       _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_output_shape) != "true";
 
     bool fuse_fc = _options->query(Options::Algorithm::FuseAddWithFullyConnected);
+    bool fuse_gelu = _options->query(Options::Algorithm::FuseGelu);
 
-    convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc);
+    convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc, fuse_gelu);
   }
 
   /* TRANSFORM DECLARATION BEGIN */
@@ -283,6 +293,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
   {
     phase.emplace_back(std::make_unique<FusePReluPass>());
   }
+  if (_options->query(Options::Algorithm::FuseGelu))
+  {
+    phase.emplace_back(std::make_unique<FuseGeluPass>());
+  }
   if (_options->query(Options::Algorithm::FuseTransposeWithMean))
   {
     phase.emplace_back(std::make_unique<FuseTransposeWithMeanPass>());
@@ -319,14 +333,6 @@ void CircleOptimizer::optimize(loco::Graph *g) const
   {
     phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
   }
-  if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp))
-  {
-    phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>());
-  }
-  if (_options->query(Options::Algorithm::ForwardTransposeOp))
-  {
-    phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>());
-  }
   if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
   {
     phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
@@ -428,10 +434,26 @@ void CircleOptimizer::optimize(loco::Graph *g) const
   {
     phase.emplace_back(std::make_unique<luci::TransformMinReluToRelu6Pass>());
   }
+  if (_options->query(Options::Algorithm::DecomposeHardSwishPass))
+  {
+    phase.emplace_back(std::make_unique<luci::DecomposeHardSwishPass>());
+  }
   if (_options->query(Options::Algorithm::UnrollUnidirSeqLSTM))
   {
     phase.emplace_back(std::make_unique<luci::UnrollUnidirectionalSequenceLSTMPass>());
   }
+  // Forward Reshape/Transpose is done after
+  // 1. SubstituteXXXToReshape
+  // 2. RemoveRedundantReshape/Transpose
+  // See https://github.com/Samsung/ONE/pull/10596 for more details
+  if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp))
+  {
+    phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>());
+  }
+  if (_options->query(Options::Algorithm::ForwardTransposeOp))
+  {
+    phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>());
+  }
 
   /* TRANSFORM DECLARATION END */