[mir] Remove Operation::Input class (#8294)
authorSergei Barannikov/Engineer/AI Tools Lab /SRR/Samsung Electronics <s.barannikov@samsung.com>
Fri, 18 Oct 2019 20:33:38 +0000 (23:33 +0300)
committerAlexander Efimov/./AI Tools Lab/Samsung Electronics <a.efimov@samsung.com>
Fri, 18 Oct 2019 20:33:38 +0000 (23:33 +0300)
Remove `Operation::Input` class and add lightweight `Use` class.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir/include/mir/Attributes.h
compiler/mir/include/mir/Operation.h
compiler/mir/src/Graph.cpp
compiler/mir/src/GraphPatternMatcher.cpp
compiler/mir/src/IrDotDumper.cpp
compiler/mir/src/Operation.cpp

index c065ad9..4262280 100644 (file)
@@ -77,6 +77,6 @@ struct PadOpAttributes
   std::vector<std::int32_t> padding_after;
   float padding_value;
 };
-}
+} // namespace mir
 
-#endif
\ No newline at end of file
+#endif
index 4e55577..851ecde 100644 (file)
@@ -21,8 +21,8 @@
 
 #include <deque>
 #include <string>
-#include <unordered_set>
 #include <limits>
+#include <vector>
 
 namespace mir
 {
@@ -39,7 +39,23 @@ public:
 #undef HANDLE_OP
   };
 
-  class Input;
+  /// @brief Represents a use of an operation output.
+  struct Use
+  {
+    Use(Operation *node, std::size_t index) : _node(node), _index(index) {}
+
+    Operation *getNode() const { return _node; }
+
+    std::size_t getIndex() const { return _index; }
+
+    // TODO Remove (needed for transition period).
+    operator Use *() const { return std::addressof(*const_cast<Use *>(this)); }
+    bool operator==(const Use &other) { return _node == other._node && _index == other._index; }
+
+  private:
+    Operation *_node;
+    std::size_t _index;
+  };
 
   /// @brief Represents an output of a node.
   class Output
@@ -58,17 +74,23 @@ public:
     Operation *getNode() { return _node; }
     const Operation *getNode() const { return _node; }
 
-    /// @brief Returns the index of this output among all the ouptputs of the node.
+    /// @brief Returns the index of this output among all the outputs of the node.
     std::size_t getIndex() const { return _index; }
 
     /// @brief Returns the inputs that consume this output.
-    const std::unordered_set<Input *> &getConsumers() const { return _consumers; }
+    const std::vector<Use> &getUses() const { return _uses; }
 
-    /// @brief Adds the specified input to the consumers of this output.
-    void addConsumer(Input *consumer) { _consumers.emplace(consumer); }
+    // TODO Remove (needed for transition period).
+    const std::vector<Use> &getConsumers() const { return _uses; }
 
-    /// @brief Removes the specified input from the consumers of this output.
-    void removeConsumer(Input *consumer) { _consumers.erase(consumer); }
+    /// @brief Adds the specified use to the uses of this output.
+    void addUse(Use use) { _uses.push_back(use); }
+
+    /// @brief Removes the specified use from the uses of this output.
+    void removeUse(Use use);
+
+    /// @brief Replace the defs of all uses of this output with the specified def.
+    void replaceAllUsesWith(Output *new_def);
 
     /// @brief Gets the type of this output.
     const TensorType &getType() const { return _type; }
@@ -96,52 +118,13 @@ public:
   private:
     Operation *_node;
     std::size_t _index;
-    std::unordered_set<Input *> _consumers;
+    std::vector<Use> _uses;
     TensorType _type;
     std::string _name;
   };
 
-  /// @brief Represents an input of a node.
-  class Input
-  {
-  public:
-    Input(Operation *node, std::size_t index, Output *producer)
-        : _node(node), _index(index), _producer(producer)
-    {
-      _producer->addConsumer(this);
-    }
-
-    ~Input() = default;
-
-    Input(const Input &) = delete;
-    Input(Input &&) = delete;
-    Input &operator=(const Input &) = delete;
-    Input &operator=(Input &&) = delete;
-
-    /// @brief Returns the node this is the input of.
-    Operation *getNode() { return _node; }
-    const Operation *getNode() const { return _node; }
-
-    /// @brief Returns the index of this output among all the inputs of the node.
-    std::size_t getIndex() const { return _index; }
-
-    /// @brief Returns the output that produces data for this input.
-    Output *getProducer() const { return _producer; }
-    operator Output *() const { return _producer; }
-
-    /// @brief Replaces the output that produces data for this input with the specified one.
-    void replaceProducer(Output *producer)
-    {
-      _producer->removeConsumer(this);
-      producer->addConsumer(this);
-      _producer = producer;
-    }
-
-  private:
-    Operation *_node;
-    std::size_t _index;
-    Output *_producer;
-  };
+  // TODO Remove (needed for transition period).
+  using Input = Use;
 
   virtual ~Operation() = default;
 
@@ -153,8 +136,8 @@ public:
   std::size_t getNumInputs() const { return _inputs.size(); }
   std::size_t getNumOutputs() const { return _outputs.size(); }
 
-  std::deque<Input> &getInputs() { return _inputs; }
-  const std::deque<Input> &getInputs() const { return _inputs; }
+  std::deque<Output *> &getInputs() { return _inputs; }
+  const std::deque<Output *> &getInputs() const { return _inputs; }
 
   std::deque<Output> &getOutputs() { return _outputs; }
   const std::deque<Output> &getOutputs() const { return _outputs; }
@@ -162,27 +145,13 @@ public:
   Output *getInput(std::size_t index)
   {
     assert(index < _inputs.size());
-    return _inputs[index].getProducer();
+    return _inputs[index];
   }
 
   const Output *getInput(std::size_t index) const
   {
     assert(index < _inputs.size());
-    return _inputs[index].getProducer();
-  }
-
-  // TODO Remove after replacing uses with `getInput`.
-  Output *getInputProducer(std::size_t index)
-  {
-    assert(index < _inputs.size());
-    return _inputs[index].getProducer();
-  }
-
-  // TODO Remove after replacing uses with `getInput`.
-  const Output *getInputProducer(std::size_t index) const
-  {
-    assert(index < _inputs.size());
-    return _inputs[index].getProducer();
+    return _inputs[index];
   }
 
   Output *getOutput(std::size_t index)
@@ -216,7 +185,7 @@ protected:
 private:
   Type _type;
   std::size_t _id = std::numeric_limits<std::size_t>::max();
-  std::deque<Input> _inputs;
+  std::deque<Output *> _inputs;
   std::deque<Output> _outputs;
 };
 
index 505d1b5..b7c3179 100644 (file)
@@ -34,14 +34,8 @@ static void replaceUsages(Operation *op, Operation *with)
   assert(op->getNumOutputs() == with->getNumOutputs());
   for (std::size_t i = 0; i < op->getNumOutputs(); ++i)
   {
-    auto *output = op->getOutput(i);
-    // The copy is intended here.
-    const auto consumers =
-        output->getConsumers(); // NOLINT(performance-unnecessary-copy-initialization)
-    for (auto *consumer : consumers)
-    {
-      consumer->replaceProducer(with->getOutput(i));
-    }
+    Operation::Output *output = op->getOutput(i);
+    output->replaceAllUsesWith(with->getOutput(i));
   }
 }
 
@@ -67,15 +61,15 @@ void Graph::accept(IVisitor *visitor)
     src_node->accept(visitor);
     for (auto &src_output : src_node->getOutputs())
     {
-      for (auto *consumer : src_output.getConsumers())
+      for (const auto use : src_output.getUses())
       {
-        Operation *dst_node = consumer->getNode();
+        Operation *dst_node = use.getNode();
         if (known_ops.count(dst_node) == 0)
         {
           bool all_inputs_resolved = true;
-          for (auto &dst_input : dst_node->getInputs())
+          for (Operation::Output *dst_input : dst_node->getInputs())
           {
-            if (known_ops.count(dst_input.getProducer()->getNode()) == 0)
+            if (known_ops.count(dst_input->getNode()) == 0)
             {
               all_inputs_resolved = false;
             }
@@ -121,13 +115,13 @@ void Graph::removeNode(Operation *op)
 #ifndef NDEBUG
   for (const auto &output : op->getOutputs())
   {
-    assert(output.getConsumers().empty() && "Trying to remove a node that has uses.");
+    assert(output.getUses().empty() && "Trying to remove a node that has uses.");
   }
 #endif
 
-  for (auto &input : op->getInputs())
+  for (std::size_t i = 0; i < op->getNumInputs(); ++i)
   {
-    input.getProducer()->removeConsumer(&input);
+    op->getInput(i)->removeUse(Operation::Use(op, i));
   }
 
   if (op->getType() == Operation::Type::input)
index 31b8551..78ea1fa 100644 (file)
@@ -33,9 +33,9 @@ GraphPatternMatcher::matchEdge(GraphPatternMatcher::Predicate p1, GraphPatternMa
     {
       for (auto &out : start->getOutputs())
       {
-        for (auto *consumer : out.getConsumers())
+        for (auto use : out.getUses())
         {
-          Operation *end = consumer->getNode();
+          Operation *end = use.getNode();
           if (p2(end))
           {
             matches.emplace_back(std::make_pair(start, end));
@@ -57,16 +57,15 @@ GraphPatternMatcher::matchUpBush(mir::GraphPatternMatcher::Predicate p1,
   {
     if (p2(root))
     {
-      auto &prev_nodes = root->getInputs();
-      if (std::all_of(prev_nodes.begin(), prev_nodes.end(), [p1](const Operation::Input &input) {
-            return p1(input.getProducer()->getNode());
-          }))
+      const auto &inputs = root->getInputs();
+      if (std::all_of(inputs.begin(), inputs.end(),
+                      [p1](const Operation::Output *input) { return p1(input->getNode()); }))
       {
         std::vector<Operation *> tops;
-        tops.reserve(prev_nodes.size());
-        for (auto &pr : prev_nodes)
+        tops.reserve(inputs.size());
+        for (Operation::Output *pr : inputs)
         {
-          tops.emplace_back(pr.getProducer()->getNode());
+          tops.emplace_back(pr->getNode());
         }
         matches.emplace_back(std::make_pair(tops, root));
       }
index e40255c..0c3f4df 100644 (file)
@@ -29,9 +29,9 @@ void dumpGraph(const Graph *graph, std::ostream &stream)
   for (const auto *node : graph->getNodes())
   {
     dot_graph.addNode(DotNodeBuilder(*node).getDotNode());
-    for (const auto &input : node->getInputs())
+    for (const Operation::Output *input : node->getInputs())
     {
-      dot_graph.addEdge({input.getProducer()->getNode()->getId(), node->getId()});
+      dot_graph.addEdge({input->getNode()->getId(), node->getId()});
     }
   }
 
index d926705..6f72acb 100644 (file)
 #include "mir/Visitor.h"
 #include "mir/OpDefs.h"
 
+#include <algorithm>
+
 namespace mir
 {
 
+void Operation::Output::removeUse(Operation::Use use)
+{
+  auto it = std::remove(_uses.begin(), _uses.end(), use);
+  _uses.erase(it);
+}
+
+void Operation::Output::replaceAllUsesWith(mir::Operation::Output *new_def)
+{
+  for (auto use : _uses)
+  {
+    use.getNode()->_inputs[use.getIndex()] = new_def;
+    new_def->addUse(use);
+  }
+  _uses.clear();
+}
+
 Operation::Operation(Type type, const std::vector<Output *> &inputs, std::size_t num_outputs)
     : _type(type)
 {
   for (std::size_t i = 0; i < inputs.size(); ++i)
   {
-    _inputs.emplace_back(this, i, inputs[i]);
+    inputs[i]->addUse(Use(this, i));
+    _inputs.push_back(inputs[i]);
   }
   for (std::size_t i = 0; i < num_outputs; ++i)
   {