Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / pass / PermutationEliminationPass.h
index 1c84300..614e44c 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,9 +17,8 @@
 #ifndef __ONERT_GRAPH_PASS_PERMUTATION_ELIMINATION_PASS_H__
 #define __ONERT_GRAPH_PASS_PERMUTATION_ELIMINATION_PASS_H__
 
-#include "LoweredOperandPass.h"
-#include "ir/Operand.h"
-#include "ir/OperandIndexSequence.h"
+#include "ir/OperationVisitor.h"
+#include "LoweredOperationPass.h"
 
 namespace onert
 {
@@ -28,55 +27,35 @@ namespace ir
 namespace pass
 {
 
-class PermutationEliminationPass : public LoweredOperandPass
+/**
+ * @brief An optimization pass that removes Permute operations if possible
+ *
+ * There may be some Permute operations that are inserted by PermutationInsertionPass or other
+ * passes. This pass checks all Permute operations and eliminates them if Permute in/out tensors
+ * are compatible and layouts match.
+ *
+ * Permute input tensor is kept and the output is removed for all the cases, except model outputs.
+ * As all output tensors have to be controlflow backend, so the output is kept.
+ *
+ * @note This is an optimization pass which means that everything should work fine even if this pass
+ *       was skipped.
+ */
+class PermutationEliminationPass : public LoweredOperationPass, public OperationVisitor
 {
 public:
-  using LoweredOperandPass::LoweredOperandPass;
+  using LoweredOperationPass::LoweredOperationPass;
 
 public:
-  std::string id() override { return "PermutationEliminationPass"; }
+  std::string id() final { return "PermutationEliminationPass"; }
 
-  void callback(const OperandIndex &index, Operand &object) override;
+public:
+  void callback(const OperationIndex &i, Operation &n) final;
 
 private:
-  /**
-   * @brief Remove Permute operation that permutates input
-   *
-   * Note: This function aslo removes model's input and
-   * sets output of permutation as model's new input
-   *
-   * @param inp_index is the target operand index for the elimination
-   * @param object is the target operand object for the elimination
-   *
-   * @return
-   */
-  void eliminateInput(const OperandIndex &inp_index, Operand &object);
-
-  /**
-   * @brief Remove Permute operation that permutates output of a model
-   *
-   * Note: This function aslo removes model's output and
-   * sets input of permutation as model's new output
-   *
-   * @param out_index is the target operand index for the elimination
-   * @param object is the target operand object for the elimination
-   *
-   * @return
-   */
-  void eliminateOutput(const OperandIndex &out_index, Operand &object);
+  void visit(const operation::Permute &) final;
 
-  /**
-   * @brief Determine if passed operands are permute layer's input and output, that must be
-   * eliminated
-   *
-   * @param inp_index indexes of the input operand to operation
-   * @param out_index indexes of the output operand to operation
-   * @param is_for_model_input checking for model's input or output
-   *
-   * @return if it is permutation layer
-   */
-  bool isPermuteLayerToEliminate(const OperandIndexSequence &inp_indexes,
-                                 const OperandIndexSequence &out_indexes, bool is_for_model_input);
+private:
+  ir::OperationIndex _op_ind;
 };
 
 } // namespace pass