[neurun] Add PermutationEliminationPass to eliminate permutation of I/O (#3552)
authorДилшоджон Умронхонович Пошшоев/AI Tools Lab /SRR/Engineer/삼성전자 <d.poshshoev@samsung.com>
Fri, 23 Nov 2018 07:30:07 +0000 (10:30 +0300)
committer박세희/동작제어Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 23 Nov 2018 07:30:07 +0000 (16:30 +0900)
Eliminates permutation after model's input and before output
This is just for NHWC_TO_NCHW in input and NCHW_TO_NHWC in output.
Permutation of input/output will be done during source/sink in next commit

Signed-off-by: Poshshoev Dilshodzhon <d.poshshoev@samsung.com>
runtimes/neurun/src/graph/Graph.h
runtimes/neurun/src/graph/pass/PermutationEliminationPass.cc

index 2fc42a9..b6d93ac 100644 (file)
@@ -116,6 +116,7 @@ private:
   // Accessors
 public:
   const operand::IndexSet &getInputs() const { return _model->inputs; }
+  operand::IndexSet &getInputs() { return _model->inputs; }
   const operand::IndexSet &getOutputs() const { return _model->outputs; }
   operand::IndexSet &getOutputs() { return _model->outputs; }
   const operand::Set &operands() const { return _model->operands; }
index 0ed1774..0ee4a15 100644 (file)
@@ -28,9 +28,6 @@ namespace graph
 {
 namespace pass
 {
-
-#define UNUSED(x) (void)x
-
 void PermutationEliminationPass::callback(const operand::Index &inp_index, operand::Object &object)
 {
   if (_graph.getInputs().contains(inp_index))
@@ -46,28 +43,147 @@ void PermutationEliminationPass::callback(const operand::Index &inp_index, opera
 void PermutationEliminationPass::eliminateInput(const operand::Index &inp_index,
                                                 operand::Object &object)
 {
-  UNUSED(inp_index;);
-  UNUSED(object);
-  /* TODO */
+  auto &model_inputs = _graph.getInputs();
+
+  // get uses of the model's given input
+  auto uses = object.getUses();
+
+  // input must be used just by permutation
+  if (uses.size() != 1)
+  {
+    return;
+  }
+
+  for (auto input_use : uses.list())
+  {
+    auto &perm_operation = _graph.operations().at(input_use);
+    auto perm_inputs = perm_operation.getInputs();
+
+    auto perm_outputs = perm_operation.getOutputs();
+
+    if (!isPermuteLayerToEliminate(perm_inputs, perm_outputs, true))
+    {
+      return;
+    }
+
+    assert(perm_inputs.at(0) == inp_index);
+
+    VERBOSE(PermutationEliminationPass::EliminateInput) << "remove NHWC_TO_NCHW permutation\n";
+
+    // set model's new input, which was output of permutation
+    model_inputs.replace(inp_index, perm_outputs.at(0));
+
+    // remove model's input, which is also input of permutation
+    _graph.removeOperand(inp_index);
+
+    // remove permutation operation
+    _graph.operations().remove(input_use);
+
+    VERBOSE(PermutationEliminationPass::EliminateInput)
+        << inp_index.value() << " is model's input and is removed. New input is "
+        << perm_outputs.at(0).value() << "\n"
+        << input_use.value() << " is removed permutation operation\n";
+  }
 }
 
 void PermutationEliminationPass::eliminateOutput(const operand::Index &out_index,
                                                  operand::Object &object)
 {
-  UNUSED(out_index);
-  UNUSED(object);
-  /* TODO */
+  auto &model_outputs = _graph.getOutputs();
+
+  // get defs of the model's given output
+  auto defs = object.getDef();
+
+  // output must use just permutation
+  if (defs.size() != 1)
+  {
+    return;
+  }
+
+  for (auto output_def : defs.list())
+  {
+    auto &perm_operation = _graph.operations().at(output_def);
+    auto perm_outputs = perm_operation.getOutputs();
+
+    auto perm_inputs = perm_operation.getInputs();
+    if (!isPermuteLayerToEliminate(perm_inputs, perm_outputs, false))
+    {
+      return;
+    }
+
+    assert(perm_outputs.at(0) == out_index);
+
+    VERBOSE(PermutationEliminationPass::EliminateOutput) << "remove NCHW_TO_NHWC permutation\n";
+
+    // Update operations' output that is used by permute operand
+    for (auto perm_input_index : perm_inputs)
+    {
+      auto &perm_input_operand = _graph.operands().at(perm_input_index);
+      perm_input_operand.removeUse(output_def);
+    }
+
+    // set model's new output, which was input of permutation
+    model_outputs.replace(out_index, perm_inputs.at(0));
+
+    // remove model's output, which is also output of permutation
+    _graph.removeOperand(out_index);
+
+    // remove permutation operation
+    _graph.operations().remove(output_def);
+
+    VERBOSE(PermutationEliminationPass::EliminateOutput)
+        << out_index.value() << " is model's output and is removed. New output is "
+        << perm_inputs.at(0).value() << "\n"
+        << output_def.value() << " is removed permutation operation\n";
+  }
 }
 
 bool PermutationEliminationPass::isPermuteLayerToEliminate(const operand::IndexSet &inp_indexes,
                                                            const operand::IndexSet &out_indexes,
                                                            bool is_for_model_input)
 {
-  UNUSED(inp_indexes);
-  UNUSED(out_indexes);
-  UNUSED(is_for_model_input);
-  /* TODO */
-  return true;
+  auto input_def_backends = _graph.operands().at(inp_indexes.at(0)).lower_info()->def_backends();
+  auto output_def_backends = _graph.operands().at(out_indexes.at(0)).lower_info()->def_backends();
+
+  auto input_layout = input_def_backends.getOnlyElement()->config()->getOperandLayout();
+  auto output_layout = output_def_backends.getOnlyElement()->config()->getOperandLayout();
+
+  if (input_def_backends.size() != 1 || output_def_backends.size() != 1)
+  {
+    return false;
+  }
+
+  // all operands' backend must be the same
+  for (auto index : inp_indexes)
+  {
+    auto op_backend_set = _graph.operands().at(index).lower_info()->def_backends();
+    if (op_backend_set.size() != 1 ||
+        input_layout != op_backend_set.getOnlyElement()->config()->getOperandLayout())
+    {
+      return false;
+    }
+  }
+  // all operands' backend must be the same
+  for (auto index : out_indexes)
+  {
+    auto op_backend_set = _graph.operands().at(index).lower_info()->def_backends();
+    if (op_backend_set.size() != 1 ||
+        output_layout != op_backend_set.getOnlyElement()->config()->getOperandLayout())
+    {
+      return false;
+    }
+  }
+
+  if (is_for_model_input)
+  {
+    // check if this is NHWC_TO_NCHW permutation: must have single input, which is model's input
+    return (inp_indexes.size() == 1 && input_layout == graph::operand::Layout::NHWC &&
+            output_layout == graph::operand::Layout::NCHW);
+  }
+
+  // check if this is NCHW_TO_NHWC permutation: must have single output, which is model's output
+  return (out_indexes.size() == 1 && input_layout == graph::operand::Layout::NCHW &&
+          output_layout == graph::operand::Layout::NHWC);
 }
 
 } // namespace pass