[nnc] Initial optimization Pass (#2940)
authorАндрей Шедько/AI Tools Lab /SRR/Engineer/삼성전자 <a.shedko@samsung.com>
Thu, 31 Jan 2019 15:59:19 +0000 (18:59 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Thu, 31 Jan 2019 15:59:19 +0000 (18:59 +0300)
Initial support for transpose composition
Added a `Matcher` class for matching graph patterns

Signed-off-by: Andrei Shedko <a.shedko@samsung.com>
18 files changed:
contrib/nnc/core/CMakeLists.txt
contrib/nnc/core/modelIR/Graph.cpp
contrib/nnc/core/modelIR/GraphPatternMatcher.cpp [new file with mode: 0644]
contrib/nnc/core/modelIR/Operation.cpp
contrib/nnc/driver/Driver.cpp
contrib/nnc/driver/Driver.h
contrib/nnc/driver/Options.cpp
contrib/nnc/include/core/modelIR/Graph.h
contrib/nnc/include/core/modelIR/GraphPatternMatcher.h [new file with mode: 0644]
contrib/nnc/include/core/modelIR/Operation.h
contrib/nnc/include/option/Options.h
contrib/nnc/include/passes/optimizations/CombineTransposes.h [new file with mode: 0644]
contrib/nnc/passes/CMakeLists.txt
contrib/nnc/passes/dot_dumper/DumperPass.cpp
contrib/nnc/passes/interpreter/interpreter_pass.cpp
contrib/nnc/passes/optimizations/CMakeLists.txt [new file with mode: 0644]
contrib/nnc/passes/optimizations/CombineTransposes.cpp [new file with mode: 0644]
contrib/nnc/passes/soft_backend/SBSerializer.cpp

index 3251204..efef88a 100644 (file)
@@ -15,6 +15,7 @@ set(SOURCES "modelIR/operations/ConcatOp.cpp"
             "modelIR/Index.cpp"
             "modelIR/ir_dot_builder.cpp"
             "modelIR/IrDotDumper.cpp"
+            "modelIR/GraphPatternMatcher.cpp"
             "modelIR/ir_dot_node_info.cpp"
             "modelIR/Operation.cpp"
             "modelIR/Shape.cpp"
index c82ff64..6a0cc6c 100644 (file)
  * limitations under the License.
  */
 
+#include "core/modelIR/Graph.h"
+
 #include <deque>
 #include <set>
 #include <algorithm>
 
-#include "core/modelIR/Graph.h"
-
 namespace nnc {
 namespace mir {
 
@@ -93,7 +93,7 @@ Graph::~Graph() {
 }
 
 void Graph::registerOp(Operation* op) {
-  _ops.push_back(op);
+  _ops.emplace(op);
 
   if (auto* input_op = dynamic_cast<ops::InputOp*>(op))
     _inputs.emplace_back(input_op);
@@ -102,7 +102,7 @@ void Graph::registerOp(Operation* op) {
     _outputs.emplace_back(output_op);
 }
 
-void Graph::replaceNode(const Operation* op, Operation* with) {
+void Graph::replaceNode(Operation* op, Operation* with) {
   replaceUsages(op, with);
 
   _inputs.erase(std::remove_if(_inputs.begin(), _inputs.end(), [op](ops::InputOp* n) {
@@ -113,12 +113,11 @@ void Graph::replaceNode(const Operation* op, Operation* with) {
     return n == op;
   }), _outputs.end());
 
-  _ops.erase(std::remove_if(_ops.begin(), _ops.end(), [op](Operation* n) {
-    return n == op;
-  }), _ops.end());
+  _ops.erase(op);
+
 }
 
-ops::InputOp* Graph::replaceWithInputNode(const Operation* op) {
+ops::InputOp* Graph::replaceWithInputNode(Operation* op) {
   assert(op->getNumOutputs() <= 1
          && "Only operations with single output value can be replaced with input node");
   assert(op->getNextNodes().size() <= 1
@@ -154,5 +153,12 @@ void Graph::replaceInputNodes(const std::vector<std::string>& new_inputs) {
   }
 }
 
+void Graph::removeNode(Operation* op) {
+  op->removeFromPrev();
+  op->removeFromNext();
+  _ops.erase(op);
+  delete op;
+}
+
 } // namespace mir
 } // namespace nnc
diff --git a/contrib/nnc/core/modelIR/GraphPatternMatcher.cpp b/contrib/nnc/core/modelIR/GraphPatternMatcher.cpp
new file mode 100644 (file)
index 0000000..94ef068
--- /dev/null
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/GraphPatternMatcher.h"
+
+#include <algorithm>
+#include <vector>
+
+namespace nnc {
+namespace mir {
+
+std::vector<std::pair<Operation*, Operation*>> GraphPatternMatcher::matchEdge(
+  GraphPatternMatcher::Predicate p1,
+  GraphPatternMatcher::Predicate p2) {
+
+  std::vector<std::pair<Operation*, Operation*>> matches;
+  for (auto* start: _g->getNodes()) {
+    if (p1(start)) {
+      const auto& next_nodes = start->getNextNodes();
+      for (auto* end: next_nodes) {
+        if (p2(end)) {
+          matches.emplace_back(std::make_pair(start, end));
+          break;
+        }
+      }
+    }
+  }
+  return matches;
+}
+
+} // namespace nnc
+} // namespace mir
index 20a1ee6..47956a1 100644 (file)
@@ -79,6 +79,11 @@ void Operation::setOutputShape(std::size_t index, const Shape& shape) {
   _outputShapes[index] = shape;
 }
 
+void Operation::setInput(const IODescriptor& descr, size_t i) {
+  descr.op->_outputs.emplace_back(this);
+  _inputs[i] = descr;
+}
+
 void Operation::accept(IVisitor* v) {
   switch (getType()) {
 #define HANDLE_OP(OpType, OpClass) \
@@ -92,5 +97,24 @@ void Operation::accept(IVisitor* v) {
   }
 }
 
+void Operation::removeFromPrev() {
+  for (const auto& prev : _inputs) {
+    auto& mutable_next = prev.op->_outputs;
+    mutable_next.erase(std::find(std::begin(mutable_next), std::end(mutable_next), this));
+  }
+  _inputs.clear();
+}
+
+void Operation::removeFromNext() {
+  for (auto* next : _outputs) {
+    auto& mutable_prev = next->_inputs;
+    mutable_prev.erase(
+      std::remove_if(mutable_prev.begin(), mutable_prev.end(), [this](IODescriptor n) {
+        return n.op == this;
+      }), mutable_prev.end());
+  }
+  _outputs.clear();
+}
+
 } // namespace mir
 } // namespace nnc
index 2d641cf..3bf6ea4 100644 (file)
 #include "passes/soft_backend/CPPGenerator.h"
 #include "passes/dot_dumper/DumperPass.h"
 #include "passes/acl_soft_backend/AclCppGenerator.h"
+#include "passes/optimizations/CombineTransposes.h"
 #include "support/CommandLine.h"
 #include "Definitions.h"
 #include "option/Options.h"
 #include "Driver.h"
 
-
 namespace nnc {
 
 /**
@@ -110,10 +110,22 @@ void Driver::registerBackendPass() {
 
 } // registerBackendPass
 
+static void registerDumper(PassManager& pass_manager) {
+  if (cli::dumpGraph)
+    pass_manager.registerPass(std::unique_ptr<Pass>(new DumperPass()));
+}
+
+void Driver::registerOptimizationPass() {
+  if (cli::doOptimizationPass) {
+    _passManager.registerPass(std::unique_ptr<Pass>(new CombineTransposes()));
+  }
+} // registerOptimizationPass
+
 void Driver::runDriver() {
 
   // register passes
   registerFrontendPass();
+  registerOptimizationPass();
   registerBackendPass();
 
   // run registered passes
index 06059a6..7a4459d 100644 (file)
 
 #include "pass/PassManager.h"
 
-namespace nnc
-{
+namespace nnc {
 
 /**
  * @brief exceptions description class for compiler driver
  */
-class DriverException : public std::exception
-{
+class DriverException : public std::exception {
 public:
   DriverException() = default;
   explicit DriverException(const std::string& reason) : _msg(reason) {}
@@ -44,8 +42,7 @@ private:
 /**
  * @brief Compiler Driver manages the whole pipeline compilation process
  */
-class Driver
-{
+class Driver {
 public:
   /**
    * @brief main method to run compiler driver
@@ -57,6 +54,7 @@ public:
 private:
   void registerFrontendPass();
   void registerBackendPass();
+  void registerOptimizationPass();
   void runPasses();
 
   PassManager _passManager;
index c073d31..0a1bb7c 100644 (file)
@@ -142,6 +142,23 @@ Option<std::string> inputFile(optname("--nnmodel, -m"),
                               checkInFile);
 
 /**
+ * Options for *optimizer*
+ */
+Option<bool> doOptimizationPass(optname("-O"),
+                                overview("whether to optimize model or not"),
+                                false,
+                                optional(true), optvalues(""), nullptr,
+                                separators(""),
+                                showopt(true));
+
+Option<bool> dumpGraph(optname("--dump, -D"),
+                       overview("dump graph to dot files after optimization passes"),
+                       false,
+                       optional(true), optvalues(""), nullptr,
+                       separators(""),
+                       showopt(true));
+
+/**
  * Options for *backend*
  */
 // options for soft backend
index 2f82fb2..87e5f77 100644 (file)
@@ -20,6 +20,7 @@
 #include <string>
 #include <vector>
 #include <type_traits>
+#include <unordered_set>
 #include <unordered_map>
 #include <set>
 
@@ -55,26 +56,32 @@ class Graph {
    * @brief Returns all graph nodes
    * @return vector containing all graph nodes
    */
-  const std::vector<Operation*>& getNodes() const { return _ops; }
+  std::unordered_set<Operation*> getNodes() const { return _ops; }
 
   /**
    * @brief Returns all graph input nodes
    * @returns vector containing all graph input nodes
    */
-  const std::vector<ops::InputOp*>& getInputs() const { return _inputs; }
+  std::vector<ops::InputOp*> getInputs() const { return _inputs; }
 
   /**
    * @brief Returns all graph output nodes
    * @returns vector containing all graph output nodes
    */
-  const std::vector<ops::OutputOp*>& getOutputs() const { return _outputs; }
+  std::vector<ops::OutputOp*> getOutputs() const { return _outputs; }
+
+  /**
+   * @brief remove node from graph, along with its links in other nodes
+   * @param op node to be removed
+   */
+  void removeNode(Operation* op);
 
   /**
    * @brief Subsitude node in graph with another keeping all edges
    * @param op Node to subsitude
    * @param with Node to place instead
    */
-  void replaceNode(const Operation* op, Operation* with);
+  void replaceNode(Operation* op, Operation* with);
 
   /**
    * @brief Replaces referenced node with input(VariableOp) node
@@ -82,7 +89,7 @@ class Graph {
    * @return Input node which is placed in graph instead of passed node
    * @warning deletes passed node
    */
-  ops::InputOp* replaceWithInputNode(const Operation* op);
+  ops::InputOp* replaceWithInputNode(Operation* op);
 
   /**
    * @brief Change graph inputs to nodes with names in newInputs
@@ -94,7 +101,7 @@ class Graph {
 private:
   void registerOp(Operation* op);
 
-  std::vector<Operation*> _ops;
+  std::unordered_set<Operation*> _ops;
   size_t _lastNodeId = 0;
   std::vector<ops::InputOp*> _inputs;
   std::vector<ops::OutputOp*> _outputs;
diff --git a/contrib/nnc/include/core/modelIR/GraphPatternMatcher.h b/contrib/nnc/include/core/modelIR/GraphPatternMatcher.h
new file mode 100644 (file)
index 0000000..bdf939b
--- /dev/null
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_GRAPH_PATTERN_MATCHER_H
+#define NNCC_GRAPH_PATTERN_MATCHER_H
+
+#include "Graph.h"
+
+namespace nnc {
+namespace mir {
+
+class Operation;
+
+class GraphPatternMatcher {
+public:
+  using Predicate = bool(const Operation*);
+  explicit GraphPatternMatcher(Graph* g) : _g(g) {};
+
+  /**
+   * @brief Match an edge with 2 predicates for ends of the edge
+   * @param pattern
+   * @return Vector of topmost ops of all matches; empty if no mathces are found
+   */
+  std::vector<std::pair<Operation*, Operation*>> matchEdge(Predicate p1, Predicate p2);
+
+private:
+  Graph* _g;
+};
+
+#endif //NNCC_GRAPH_PATTERN_MATCHER_H
+
+} // namespace nnc
+} // namespace mir
index 9794ccc..6a96b79 100644 (file)
@@ -72,6 +72,19 @@ public:
   const nnc::mir::Shape& getInputShape(std::size_t index) const;
   const nnc::mir::Shape& getOutputShape(std::size_t index) const;
 
+  /// @brief Removes links to this node from it's parents
+  void removeFromPrev();
+
+  /// @brief Removes links to this node from it's children
+  void removeFromNext();
+
+  /**
+   * @brief Set `descr` as `i`-th input of this node
+   * @param descr the tensor to be set as input
+   * @param i input index
+   */
+  void setInput(const IODescriptor& descr, size_t i);
+
   void accept(IVisitor* v);
 
 protected:
index 803fdcc..d49b17e 100644 (file)
@@ -36,6 +36,9 @@ extern Option<bool> caffeFrontend;  // frontend for CAFFE AI framework
 extern Option<bool> tflFrontend;    // frontend for TensorFlow Lite AI framework
 extern Option<bool> onnxFrontend;   // frontend for ONNX AI framework
 
+extern Option<bool> doOptimizationPass; // enable optimization pass
+extern Option<bool> dumpGraph; // enable Dumping graph to .dot files
+
 // valid values for target option
 #define NNC_TARGET_ARM_CPP      "arm-c++"
 #define NNC_TARGET_X86_CPP      "x86-c++"
diff --git a/contrib/nnc/include/passes/optimizations/CombineTransposes.h b/contrib/nnc/include/passes/optimizations/CombineTransposes.h
new file mode 100644 (file)
index 0000000..311624e
--- /dev/null
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_COMBINE_TRANSPOSES_H
+#define NNCC_COMBINE_TRANSPOSES_H
+
+#include "pass/Pass.h"
+#include "pass/PassData.h"
+
+namespace nnc {
+
+/**
+ * @brief This pass combines sequential transposes and removes identity transposes if
+ * the combination results in an identity permutation.
+ */
+class CombineTransposes : public Pass {
+public:
+  PassData run(PassData data) override;
+
+private:
+};
+
+} // namespace nnc
+
+
+#endif //NNCC_COMBINE_TRANSPOSES_H
index 1fa2195..5a413bc 100644 (file)
@@ -23,6 +23,11 @@ if(NNC_FRONTEND_CAFFE2_ENABLED)
 endif()
 
 #
+# MIDDLE PASSES
+#
+add_subdirectory(optimizations)
+
+#
 # BACKENDs
 #
 add_subdirectory(interpreter)
index 9ee1328..9ed90ed 100644 (file)
@@ -19,7 +19,6 @@
 #include "pass/PassException.h"
 
 #include <fstream>
-#include <fstream>
 
 namespace nnc {
 
index aab24e9..6205d25 100644 (file)
@@ -140,7 +140,6 @@ PassData InterpreterPass::run(PassData data) {
 #else
     std::cout << "Result <" << out_node->getName()
               << "> wasn't saved, due to lack of HDF5" << std::endl;
-
 #endif  // NNC_HDF5_SUPPORTED
   }
 
diff --git a/contrib/nnc/passes/optimizations/CMakeLists.txt b/contrib/nnc/passes/optimizations/CMakeLists.txt
new file mode 100644 (file)
index 0000000..1a1e13a
--- /dev/null
@@ -0,0 +1,6 @@
+set(OPTIMIZATIONS_SRC "CombineTransposes.cpp")
+nnc_add_library(nnc_optimizations SHARED ${OPTIMIZATIONS_SRC})
+target_link_libraries(nnc_optimizations PRIVATE nnc_core nnc_support)
+
+# install optimizations library
+nnc_install_library(nnc_optimizations)
diff --git a/contrib/nnc/passes/optimizations/CombineTransposes.cpp b/contrib/nnc/passes/optimizations/CombineTransposes.cpp
new file mode 100644 (file)
index 0000000..0a6f2b9
--- /dev/null
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/CombineTransposes.h"
+#include "core/modelIR/operations/TransposeOp.h"
+#include "core/modelIR/Graph.h"
+#include "core/modelIR/GraphPatternMatcher.h"
+
+namespace nnc {
+
+using namespace mir;
+
+PassData CombineTransposes::run(PassData data) {
+  auto g = static_cast<Graph*>(data);
+  assert(g);
+  return g;
+}
+
+} //namespace nnc
index 202eca6..008536f 100644 (file)
@@ -51,7 +51,6 @@
 
 #include "pass/PassException.h"
 #include <algorithm>
-#include <core/modelIR/TensorUtil.h>
 
 #define UNUSED(x) ((void)(x))