Introduce ConstantInsertionPass into neurun (#9312)
author장지섭/On-Device Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Mon, 2 Dec 2019 09:11:05 +0000 (18:11 +0900)
committer이한종/On-Device Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Mon, 2 Dec 2019 09:11:05 +0000 (18:11 +0900)
This commit introduces ConstantInsertionPass into neurun.
  - Creates new operands as a constant operand to be used on other backend
  - Makes to enable using a constant operand on other backend

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtime/neurun/core/include/graph/Graph.h
runtime/neurun/core/src/graph/Graph.cc
runtime/neurun/core/src/graph/pass/ConstantInsertionPass.cc [new file with mode: 0644]
runtime/neurun/core/src/graph/pass/ConstantInsertionPass.h [new file with mode: 0644]

index 36f67dc..fcfb8e1 100644 (file)
@@ -182,6 +182,7 @@ public:
   operand::LowerInfo *getLowerInfo(const model::OperandIndex &index);
   void setLowerInfo(const model::OperandIndex &index,
                     std::unique_ptr<operand::LowerInfo> &&lower_info);
+  void removeLowerInfo(const model::OperandIndex &index);
   model::Subgraphs &subgraphs()
   {
     assert(_subgraphs);
index 89ff907..cace60f 100644 (file)
@@ -29,6 +29,7 @@
 #include "operand/Shape4DConvert.h"
 #include "compiler/BackendResolver.h"
 #include "backend/IConfig.h"
+#include "pass/ConstantInsertionPass.h"
 #include "pass/PermutationInsertionPass.h"
 #include "pass/PermutationEliminationPass.h"
 #include "pass/PermutationOperationPass.h"
@@ -118,6 +119,9 @@ void Graph::lower(void)
 
     _subgraphs->dump("merged and sorted operations without permutation");
 
+    pass::ConstantInsertionPass ci_pass(*this);
+    ci_pass.run();
+
     // Set LowerInfo for each operand from the operand::LowerInfo holder
     manipulateLowerInfo(operands_lower_info);
 
@@ -222,6 +226,11 @@ void Graph::setLowerInfo(const model::OperandIndex &index,
   _lower_info_map->operand.insert(std::make_pair(index, std::move(lower_info)));
 }
 
+void Graph::removeLowerInfo(const model::OperandIndex &index)
+{
+  _lower_info_map->operand.erase(index);
+}
+
 void Graph::makeSubgraphs(
     model::OperandIndexMap<std::unique_ptr<operand::LowerInfo>> &operands_lower_info)
 {
@@ -397,24 +406,6 @@ void Graph::manipulateLowerInfo(
     }
   }
 
-  // Add DefFactor constants same as UseFactor
-  // NOTE This assumes a constant operand is used by only one operation
-  _model->operations.iterate([&](const model::OperationIndex &, model::Operation &node) {
-    // LowerInfo for input operands
-    for (auto operand : node.getInputs())
-    {
-      auto &&lower_info = operands_lower_info.at(operand);
-      if (lower_info->def_factors().empty())
-      {
-        // If it is a constant
-        if (!_model->inputs.contains(operand))
-        {
-          lower_info->addDefPermuteFactor(lower_info->use_factors().getOnlyElement());
-        }
-      }
-    }
-  });
-
   // Set LowerInfo for each operand from the operand::LowerInfo holder
   _model->operands.iterate([&](const model::OperandIndex &index, model::Operand &) {
     setLowerInfo(index, std::move(operands_lower_info[index]));
diff --git a/runtime/neurun/core/src/graph/pass/ConstantInsertionPass.cc b/runtime/neurun/core/src/graph/pass/ConstantInsertionPass.cc
new file mode 100644 (file)
index 0000000..03f6384
--- /dev/null
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConstantInsertionPass.h"
+
+#include "backend/Backend.h"
+#include <graph/Graph.h>
+#include "graph/operand/Shape4DConvert.h"
+#include <util/Utils.h>
+
+namespace neurun
+{
+namespace graph
+{
+namespace pass
+{
+
+void ConstantInsertionPass::callback(const model::OperationIndex &node_index,
+                                     model::Operation &node)
+{
+  const auto &subgraph_index = _graph.subgraphs().getOperation(node_index);
+  const auto subg_lower_info = _graph.getLowerInfo(subgraph_index);
+  const auto backend = subg_lower_info->backend();
+  const auto layout = subg_lower_info->layout();
+  const auto factor = graph::operand::PermuteFactor{backend, layout};
+
+  for (const auto input : node.getInputs())
+  {
+    auto &object = _graph.operands().at(input);
+
+    if (object.isConstant())
+    {
+      const auto key = ReplaceKey{input, factor};
+      if (_replace_operands_map.count(key) == 0)
+      {
+        auto new_object = object;
+        // TODO Remove const_case
+        const_cast<std::list<model::OperationIndex> &>(new_object.getDef().list()).clear();
+        const_cast<std::list<model::OperationIndex> &>(new_object.getUses().list()).clear();
+        const auto new_index = _graph.operands().emplace(new_object);
+        _replace_operands_map[key] = new_index;
+
+        _graph.setLowerInfo(new_index, nnfw::cpp14::make_unique<graph::operand::LowerInfo>(
+                                           graph::operand::asShape4D(new_object.shape())));
+        _graph.getLowerInfo(new_index)->addDefPermuteFactor(factor);
+      }
+
+      const auto replaced_input = _replace_operands_map[key];
+      // Update subgraph
+      if (_graph.subgraphs().at(subgraph_index).getInputs().contains(input))
+      {
+        _graph.subgraphs().at(subgraph_index).replaceInput(input, replaced_input);
+      }
+
+      // Update node
+      node.replaceInput(input, replaced_input);
+
+      // Update operand
+      auto &replaced_object = _graph.operands().at(replaced_input);
+      replaced_object.appendUse(node_index);
+
+      // Update lower_info
+      auto replaced_lower_info = _graph.getLowerInfo(replaced_input);
+      replaced_lower_info->addUsePermuteFactor(factor);
+
+      // Remove this node from def and uses of origin operand
+      if (object.getDef().contains(node_index))
+      {
+        object.removeDef(node_index);
+      }
+      object.removeUse(node_index);
+
+      // Remove origin operand
+      if (object.getDef().size() == 0 && object.getUses().size() == 0)
+      {
+        _graph.removeOperand(input);
+        _graph.removeLowerInfo(input);
+      }
+    }
+  }
+
+  // Now this runtime does not support the node making output as constant
+  for (const auto &output : node.getOutputs())
+  {
+    UNUSED_RELEASE(output);
+    assert(!_graph.operands().at(output).isConstant());
+  }
+}
+
+} // namespace pass
+} // namespace graph
+} // namespace neurun
diff --git a/runtime/neurun/core/src/graph/pass/ConstantInsertionPass.h b/runtime/neurun/core/src/graph/pass/ConstantInsertionPass.h
new file mode 100644 (file)
index 0000000..6a3d415
--- /dev/null
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NEURUN_GRAPH_PASS_CONSTANT_INSERTION_PASS_H__
+#define __NEURUN_GRAPH_PASS_CONSTANT_INSERTION_PASS_H__
+
+#include <graph/operand/PermuteFactor.h>
+#include <model/Index.h>
+#include "OperationPass.h"
+#include <unordered_map>
+#include <utility>
+
+namespace neurun
+{
+namespace graph
+{
+namespace pass
+{
+
+class ConstantInsertionPass : public OperationPass
+{
+public:
+  using OperationPass::OperationPass;
+
+public:
+  std::string id() final { return "ConstantInsertionPass"; }
+
+public:
+  void callback(const model::OperationIndex &index, model::Operation &node) final;
+
+private:
+  struct ReplaceKey
+  {
+    model::OperandIndex index;
+    graph::operand::PermuteFactor factor;
+
+    bool operator==(const ReplaceKey &other) const
+    {
+      return index == other.index && factor == other.factor;
+    }
+  };
+
+  /**
+   * @brief Structure that provides hash function of ReplaceKey
+   */
+  struct KeyHasher
+  {
+    std::size_t operator()(const ReplaceKey &key) const noexcept
+    {
+      using std::hash;
+      return hash<model::OperandIndex>()(key.index) ^
+             (hash<graph::operand::PermuteFactor>()(key.factor) << 1);
+    }
+  };
+
+  std::unordered_map<ReplaceKey, model::OperandIndex, KeyHasher> _replace_operands_map;
+};
+
+} // namespace pass
+} // namespace graph
+} // namespace neurun
+
+#endif // __NEURUN_GRAPH_PASS_CONSTANT_INSERTION_PASS_H__