Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / pass / PermutationEliminationPass.cc
index 9e0291e..2deccd4 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  */
 
 #include "PermutationEliminationPass.h"
+#include "backend/controlflow/Config.h"
 
-#include "ir/Operand.h"
-#include "ir/operand/LowerInfo.h"
-#include "ir/Graph.h"
-#include "backend/IConfig.h"
 #include "util/logging.h"
 
 namespace onert
@@ -28,166 +25,129 @@ namespace ir
 {
 namespace pass
 {
-void PermutationEliminationPass::callback(const OperandIndex &inp_index, Operand &object)
-{
-  if (_graph.getInputs().contains(inp_index))
-  {
-    eliminateInput(inp_index, object);
-  }
-  else if (_graph.getOutputs().contains(inp_index))
-  {
-    eliminateOutput(inp_index, object);
-  }
-}
 
-void PermutationEliminationPass::eliminateInput(const OperandIndex &inp_index, Operand &object)
+void PermutationEliminationPass::callback(const OperationIndex &ind, Operation &node)
 {
-  auto &model_inputs = _graph.getInputs();
-
-  // get uses of the model's given input
-  auto uses = object.getUses();
+  _op_ind = ind;
+  node.accept(*this);
+};
 
-  // input must be used just by permutation
-  if (uses.size() != 1)
-  {
-    return;
-  }
+void PermutationEliminationPass::visit(const operation::Permute &node)
+{
+  auto in_operand = node.getInputs().at(0);
+  auto out_operand = node.getOutputs().at(0);
 
-  for (auto input_use : uses)
+  // Check if two tensors are both portable
+  // TODO Make this general, this is just a workaround to check two tensors are portable
   {
-    auto &perm_operation = _graph.operations().at(input_use);
-    auto perm_inputs = perm_operation.getInputs();
+    auto in_def_factor = _lowered_graph.getLowerInfo(in_operand)->def_factors().getOnlyElement();
+    auto out_def_factor = _lowered_graph.getLowerInfo(out_operand)->def_factors().getOnlyElement();
 
-    auto perm_outputs = perm_operation.getOutputs();
+    auto in_backend_id = in_def_factor.backend()->config()->id();
+    auto out_backend_id = out_def_factor.backend()->config()->id();
 
-    if (!isPermuteLayerToEliminate(perm_inputs, perm_outputs, true))
-    {
+    // TODO Fix this workaround that removes only Permute between cpu and controlflow backend.
+    //      This should be general.
+    if (!((in_backend_id == backend::controlflow::Config::ID && out_backend_id == "cpu") ||
+          (in_backend_id == "cpu" && out_backend_id == backend::controlflow::Config::ID)))
       return;
-    }
-
-    assert(perm_inputs.at(0) == inp_index);
-
-    VERBOSE(PermutationEliminationPass::EliminateInput) << "remove NHWC_TO_NCHW permutation\n";
-
-    // set model's new input, which was output of permutation
-    model_inputs.replace(inp_index, perm_outputs.at(0));
-
-    // remove model's input, which is also input of permutation
-    _graph.removeOperand(inp_index);
-
-    // remove permutation operation
-    assert(_lowered_graph.op_seqs().containsOperation(input_use));
-    auto op_seq_idx = _lowered_graph.op_seqs().getOperation(input_use);
-    _lowered_graph.op_seqs().remove(op_seq_idx);
-    _graph.operations().remove(input_use);
-
-    VERBOSE(PermutationEliminationPass::EliminateInput)
-        << inp_index.value() << " is model's input and is removed. New input is "
-        << perm_outputs.at(0).value() << "\n"
-        << input_use.value() << " is removed permutation operation\n";
-  }
-}
-
-void PermutationEliminationPass::eliminateOutput(const OperandIndex &out_index, Operand &object)
-{
-  auto &model_outputs = _graph.getOutputs();
-
-  // get defs of the model's given output
-  auto defs = object.getDef();
-
-  // output must use just permutation
-  if (defs.size() != 1)
-  {
-    return;
   }
 
-  for (auto output_def : defs)
+  if (_graph.getOutputs().contains(out_operand))
   {
-    auto &perm_operation = _graph.operations().at(output_def);
-    auto perm_outputs = perm_operation.getOutputs();
-
-    auto perm_inputs = perm_operation.getInputs();
-    if (!isPermuteLayerToEliminate(perm_inputs, perm_outputs, false))
+    // Exceptional case : When the output operand is a model output
+    // In this case we keep the output and remove the input
+
+    auto &out_operand_obj = _graph.operands().at(out_operand);
+    assert(out_operand_obj.getDef() == _op_ind);
+    out_operand_obj.unsetDef();
+    _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
+      if (!op_seq.getOutputs().contains(in_operand))
+        return;
+
+      // Update OpSequence/Operation edges and Operand edges
+      op_seq.replaceOutputs(in_operand, out_operand);
+      for (auto op : op_seq.operations())
+      {
+        auto &operation_obj = _graph.operations().at(op);
+        if (operation_obj.getOutputs().contains(in_operand))
+        {
+          operation_obj.replaceOutputs(in_operand, out_operand);
+          out_operand_obj.setDef(op);
+        }
+      }
+    });
+
+    // Remove Permute operation, enclosing OpSequence and the operand
     {
-      return;
-    }
-
-    assert(perm_outputs.at(0) == out_index);
+      _graph.removeOperand(in_operand);
 
-    VERBOSE(PermutationEliminationPass::EliminateOutput) << "remove NCHW_TO_NHWC permutation\n";
-
-    // Update operations' output that is used by permute operand
-    for (auto perm_input_index : perm_inputs)
-    {
-      auto &perm_input_operand = _graph.operands().at(perm_input_index);
-      perm_input_operand.removeUse(output_def);
+      auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind);
+      // Assumes enclosing OpSequence contatins just this Permute operation
+      assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1);
+      _lowered_graph.op_seqs().remove(op_seq_ind);
+      _graph.operations().remove(_op_ind);
     }
 
-    // set model's new output, which was input of permutation
-    model_outputs.replace(out_index, perm_inputs.at(0));
-
-    // remove model's output, which is also output of permutation
-    _graph.removeOperand(out_index);
-
-    // remove permutation operation
-    assert(_lowered_graph.op_seqs().containsOperation(output_def));
-    auto op_seq_idx = _lowered_graph.op_seqs().getOperation(output_def);
-    _lowered_graph.op_seqs().remove(op_seq_idx);
-    _graph.operations().remove(output_def);
-
-    VERBOSE(PermutationEliminationPass::EliminateOutput)
-        << out_index.value() << " is model's output and is removed. New output is "
-        << perm_inputs.at(0).value() << "\n"
-        << output_def.value() << " is removed permutation operation\n";
+    _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
+      if (!op_seq.getInputs().contains(in_operand))
+        return;
+
+      op_seq.replaceInputs(in_operand, out_operand);
+      for (auto op : op_seq.operations())
+      {
+        auto &operation_obj = _graph.operations().at(op);
+        if (operation_obj.getInputs().contains(in_operand))
+        {
+          operation_obj.replaceInputs(in_operand, out_operand);
+          out_operand_obj.insertUse(op);
+        }
+      }
+    });
+
+    VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl;
+    VERBOSE(removePermute) << "  - Input (removed) Operand : " << in_operand << std::endl;
+    VERBOSE(removePermute) << "  - Output(kept)    Operand : " << out_operand << std::endl;
   }
-}
-
-bool PermutationEliminationPass::isPermuteLayerToEliminate(const OperandIndexSequence &inp_indexes,
-                                                           const OperandIndexSequence &out_indexes,
-                                                           bool is_for_model_input)
-{
-  auto input_def_factors = _lowered_graph.getLowerInfo(inp_indexes.at(0))->def_factors();
-  auto output_def_factors = _lowered_graph.getLowerInfo(out_indexes.at(0))->def_factors();
-
-  auto input_layout = input_def_factors.getOnlyElement().layout();
-  auto output_layout = output_def_factors.getOnlyElement().layout();
-
-  if (input_def_factors.size() != 1 || output_def_factors.size() != 1)
-  {
-    return false;
-  }
-
-  // all operands' factor must be the same
-  for (auto index : inp_indexes)
-  {
-    auto op_factor_set = _lowered_graph.getLowerInfo(index)->def_factors();
-    if (op_factor_set.size() != 1 ||
-        input_layout != _lowered_graph.getLowerInfo(index)->def_factors().getOnlyElement().layout())
-    {
-      return false;
-    }
-  }
-  // all operands' factor must be the same
-  for (auto index : out_indexes)
+  else
   {
-    auto op_factor_set = _lowered_graph.getLowerInfo(index)->def_factors();
-    if (op_factor_set.size() != 1 ||
-        output_layout !=
-            _lowered_graph.getLowerInfo(index)->def_factors().getOnlyElement().layout())
+    // Otherwise keep the input and remove the output
+
+    auto &in_operand_obj = _graph.operands().at(in_operand);
+    in_operand_obj.removeUse(_op_ind);
+
+    // Make OpSequences(that use the output) use the input
+    _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
+      if (!op_seq.getInputs().contains(out_operand))
+        return;
+
+      op_seq.replaceInputs(out_operand, in_operand);
+      for (auto op : op_seq.operations())
+      {
+        auto &operation_obj = _graph.operations().at(op);
+        if (operation_obj.getInputs().contains(out_operand))
+        {
+          operation_obj.replaceInputs(out_operand, in_operand);
+          in_operand_obj.insertUse(op);
+        }
+      }
+    });
+
+    // Remove Permute operation, enclosing OpSequence and the operand
     {
-      return false;
+      _graph.removeOperand(out_operand);
+
+      auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind);
+      // Assumes enclosing OpSequence contatins just this Permute operation
+      assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1);
+      _lowered_graph.op_seqs().remove(op_seq_ind);
+      _graph.operations().remove(_op_ind);
     }
-  }
 
-  if (is_for_model_input)
-  {
-    // check if this is NHWC_TO_NCHW permutation: must have single input, which is model's input
-    return (inp_indexes.size() == 1 && input_layout == Layout::NHWC &&
-            output_layout == Layout::NCHW);
+    VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl;
+    VERBOSE(removePermute) << "  - Input (kept)    Operand : " << in_operand << std::endl;
+    VERBOSE(removePermute) << "  - Output(removed) Operand : " << out_operand << std::endl;
   }
-
-  // check if this is NCHW_TO_NHWC permutation: must have single output, which is model's output
-  return (out_indexes.size() == 1 && input_layout == Layout::NCHW && output_layout == Layout::NHWC);
 }
 
 } // namespace pass