From: Sergei Barannikov/Engineer/AI Tools Lab /SRR/Samsung Electronics Date: Fri, 18 Oct 2019 20:33:38 +0000 (+0300) Subject: [mir] Remove Operation::Input class (#8294) X-Git-Tag: submit/tizen/20191205.083104~688 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c54e9a74b1f47467717201a08c647b535daeb3de;p=platform%2Fcore%2Fml%2Fnnfw.git [mir] Remove Operation::Input class (#8294) Remove `Operation::Input` class and add lightweight `Use` class. Signed-off-by: Sergei Barannikov --- diff --git a/compiler/mir/include/mir/Attributes.h b/compiler/mir/include/mir/Attributes.h index c065ad9..4262280 100644 --- a/compiler/mir/include/mir/Attributes.h +++ b/compiler/mir/include/mir/Attributes.h @@ -77,6 +77,6 @@ struct PadOpAttributes std::vector padding_after; float padding_value; }; -} +} // namespace mir -#endif \ No newline at end of file +#endif diff --git a/compiler/mir/include/mir/Operation.h b/compiler/mir/include/mir/Operation.h index 4e55577..851ecde 100644 --- a/compiler/mir/include/mir/Operation.h +++ b/compiler/mir/include/mir/Operation.h @@ -21,8 +21,8 @@ #include #include -#include #include +#include namespace mir { @@ -39,7 +39,23 @@ public: #undef HANDLE_OP }; - class Input; + /// @brief Represents a use of an operation output. + struct Use + { + Use(Operation *node, std::size_t index) : _node(node), _index(index) {} + + Operation *getNode() const { return _node; } + + std::size_t getIndex() const { return _index; } + + // TODO Remove (needed for transition period). + operator Use *() const { return std::addressof(*const_cast(this)); } + bool operator==(const Use &other) { return _node == other._node && _index == other._index; } + + private: + Operation *_node; + std::size_t _index; + }; /// @brief Represents an output of a node. class Output @@ -58,17 +74,23 @@ public: Operation *getNode() { return _node; } const Operation *getNode() const { return _node; } - /// @brief Returns the index of this output among all the ouptputs of the node. + /// @brief Returns the index of this output among all the outputs of the node. std::size_t getIndex() const { return _index; } /// @brief Returns the inputs that consume this output. - const std::unordered_set &getConsumers() const { return _consumers; } + const std::vector &getUses() const { return _uses; } - /// @brief Adds the specified input to the consumers of this output. - void addConsumer(Input *consumer) { _consumers.emplace(consumer); } + // TODO Remove (needed for transition period). + const std::vector &getConsumers() const { return _uses; } - /// @brief Removes the specified input from the consumers of this output. - void removeConsumer(Input *consumer) { _consumers.erase(consumer); } + /// @brief Adds the specified use to the uses of this output. + void addUse(Use use) { _uses.push_back(use); } + + /// @brief Removes the specified use from the uses of this output. + void removeUse(Use use); + + /// @brief Replace the defs of all uses of this output with the specified def. + void replaceAllUsesWith(Output *new_def); /// @brief Gets the type of this output. const TensorType &getType() const { return _type; } @@ -96,52 +118,13 @@ public: private: Operation *_node; std::size_t _index; - std::unordered_set _consumers; + std::vector _uses; TensorType _type; std::string _name; }; - /// @brief Represents an input of a node. - class Input - { - public: - Input(Operation *node, std::size_t index, Output *producer) - : _node(node), _index(index), _producer(producer) - { - _producer->addConsumer(this); - } - - ~Input() = default; - - Input(const Input &) = delete; - Input(Input &&) = delete; - Input &operator=(const Input &) = delete; - Input &operator=(Input &&) = delete; - - /// @brief Returns the node this is the input of. - Operation *getNode() { return _node; } - const Operation *getNode() const { return _node; } - - /// @brief Returns the index of this output among all the inputs of the node. - std::size_t getIndex() const { return _index; } - - /// @brief Returns the output that produces data for this input. - Output *getProducer() const { return _producer; } - operator Output *() const { return _producer; } - - /// @brief Replaces the output that produces data for this input with the specified one. - void replaceProducer(Output *producer) - { - _producer->removeConsumer(this); - producer->addConsumer(this); - _producer = producer; - } - - private: - Operation *_node; - std::size_t _index; - Output *_producer; - }; + // TODO Remove (needed for transition period). + using Input = Use; virtual ~Operation() = default; @@ -153,8 +136,8 @@ public: std::size_t getNumInputs() const { return _inputs.size(); } std::size_t getNumOutputs() const { return _outputs.size(); } - std::deque &getInputs() { return _inputs; } - const std::deque &getInputs() const { return _inputs; } + std::deque &getInputs() { return _inputs; } + const std::deque &getInputs() const { return _inputs; } std::deque &getOutputs() { return _outputs; } const std::deque &getOutputs() const { return _outputs; } @@ -162,27 +145,13 @@ public: Output *getInput(std::size_t index) { assert(index < _inputs.size()); - return _inputs[index].getProducer(); + return _inputs[index]; } const Output *getInput(std::size_t index) const { assert(index < _inputs.size()); - return _inputs[index].getProducer(); - } - - // TODO Remove after replacing uses with `getInput`. - Output *getInputProducer(std::size_t index) - { - assert(index < _inputs.size()); - return _inputs[index].getProducer(); - } - - // TODO Remove after replacing uses with `getInput`. - const Output *getInputProducer(std::size_t index) const - { - assert(index < _inputs.size()); - return _inputs[index].getProducer(); + return _inputs[index]; } Output *getOutput(std::size_t index) @@ -216,7 +185,7 @@ protected: private: Type _type; std::size_t _id = std::numeric_limits::max(); - std::deque _inputs; + std::deque _inputs; std::deque _outputs; }; diff --git a/compiler/mir/src/Graph.cpp b/compiler/mir/src/Graph.cpp index 505d1b5..b7c3179 100644 --- a/compiler/mir/src/Graph.cpp +++ b/compiler/mir/src/Graph.cpp @@ -34,14 +34,8 @@ static void replaceUsages(Operation *op, Operation *with) assert(op->getNumOutputs() == with->getNumOutputs()); for (std::size_t i = 0; i < op->getNumOutputs(); ++i) { - auto *output = op->getOutput(i); - // The copy is intended here. - const auto consumers = - output->getConsumers(); // NOLINT(performance-unnecessary-copy-initialization) - for (auto *consumer : consumers) - { - consumer->replaceProducer(with->getOutput(i)); - } + Operation::Output *output = op->getOutput(i); + output->replaceAllUsesWith(with->getOutput(i)); } } @@ -67,15 +61,15 @@ void Graph::accept(IVisitor *visitor) src_node->accept(visitor); for (auto &src_output : src_node->getOutputs()) { - for (auto *consumer : src_output.getConsumers()) + for (const auto use : src_output.getUses()) { - Operation *dst_node = consumer->getNode(); + Operation *dst_node = use.getNode(); if (known_ops.count(dst_node) == 0) { bool all_inputs_resolved = true; - for (auto &dst_input : dst_node->getInputs()) + for (Operation::Output *dst_input : dst_node->getInputs()) { - if (known_ops.count(dst_input.getProducer()->getNode()) == 0) + if (known_ops.count(dst_input->getNode()) == 0) { all_inputs_resolved = false; } @@ -121,13 +115,13 @@ void Graph::removeNode(Operation *op) #ifndef NDEBUG for (const auto &output : op->getOutputs()) { - assert(output.getConsumers().empty() && "Trying to remove a node that has uses."); + assert(output.getUses().empty() && "Trying to remove a node that has uses."); } #endif - for (auto &input : op->getInputs()) + for (std::size_t i = 0; i < op->getNumInputs(); ++i) { - input.getProducer()->removeConsumer(&input); + op->getInput(i)->removeUse(Operation::Use(op, i)); } if (op->getType() == Operation::Type::input) diff --git a/compiler/mir/src/GraphPatternMatcher.cpp b/compiler/mir/src/GraphPatternMatcher.cpp index 31b8551..78ea1fa 100644 --- a/compiler/mir/src/GraphPatternMatcher.cpp +++ b/compiler/mir/src/GraphPatternMatcher.cpp @@ -33,9 +33,9 @@ GraphPatternMatcher::matchEdge(GraphPatternMatcher::Predicate p1, GraphPatternMa { for (auto &out : start->getOutputs()) { - for (auto *consumer : out.getConsumers()) + for (auto use : out.getUses()) { - Operation *end = consumer->getNode(); + Operation *end = use.getNode(); if (p2(end)) { matches.emplace_back(std::make_pair(start, end)); @@ -57,16 +57,15 @@ GraphPatternMatcher::matchUpBush(mir::GraphPatternMatcher::Predicate p1, { if (p2(root)) { - auto &prev_nodes = root->getInputs(); - if (std::all_of(prev_nodes.begin(), prev_nodes.end(), [p1](const Operation::Input &input) { - return p1(input.getProducer()->getNode()); - })) + const auto &inputs = root->getInputs(); + if (std::all_of(inputs.begin(), inputs.end(), + [p1](const Operation::Output *input) { return p1(input->getNode()); })) { std::vector tops; - tops.reserve(prev_nodes.size()); - for (auto &pr : prev_nodes) + tops.reserve(inputs.size()); + for (Operation::Output *pr : inputs) { - tops.emplace_back(pr.getProducer()->getNode()); + tops.emplace_back(pr->getNode()); } matches.emplace_back(std::make_pair(tops, root)); } diff --git a/compiler/mir/src/IrDotDumper.cpp b/compiler/mir/src/IrDotDumper.cpp index e40255c..0c3f4df 100644 --- a/compiler/mir/src/IrDotDumper.cpp +++ b/compiler/mir/src/IrDotDumper.cpp @@ -29,9 +29,9 @@ void dumpGraph(const Graph *graph, std::ostream &stream) for (const auto *node : graph->getNodes()) { dot_graph.addNode(DotNodeBuilder(*node).getDotNode()); - for (const auto &input : node->getInputs()) + for (const Operation::Output *input : node->getInputs()) { - dot_graph.addEdge({input.getProducer()->getNode()->getId(), node->getId()}); + dot_graph.addEdge({input->getNode()->getId(), node->getId()}); } } diff --git a/compiler/mir/src/Operation.cpp b/compiler/mir/src/Operation.cpp index d926705..6f72acb 100644 --- a/compiler/mir/src/Operation.cpp +++ b/compiler/mir/src/Operation.cpp @@ -18,15 +18,34 @@ #include "mir/Visitor.h" #include "mir/OpDefs.h" +#include + namespace mir { +void Operation::Output::removeUse(Operation::Use use) +{ + auto it = std::remove(_uses.begin(), _uses.end(), use); + _uses.erase(it); +} + +void Operation::Output::replaceAllUsesWith(mir::Operation::Output *new_def) +{ + for (auto use : _uses) + { + use.getNode()->_inputs[use.getIndex()] = new_def; + new_def->addUse(use); + } + _uses.clear(); +} + Operation::Operation(Type type, const std::vector &inputs, std::size_t num_outputs) : _type(type) { for (std::size_t i = 0; i < inputs.size(); ++i) { - _inputs.emplace_back(this, i, inputs[i]); + inputs[i]->addUse(Use(this, i)); + _inputs.push_back(inputs[i]); } for (std::size_t i = 0; i < num_outputs; ++i) {