#include <deque>
#include <string>
-#include <unordered_set>
#include <limits>
+#include <vector>
namespace mir
{
#undef HANDLE_OP
};
- class Input;
+ /// @brief Represents a use of an operation output.
+ struct Use
+ {
+ Use(Operation *node, std::size_t index) : _node(node), _index(index) {}
+
+ Operation *getNode() const { return _node; }
+
+ std::size_t getIndex() const { return _index; }
+
+ // TODO Remove (needed for transition period).
+ operator Use *() const { return std::addressof(*const_cast<Use *>(this)); }
+ bool operator==(const Use &other) { return _node == other._node && _index == other._index; }
+
+ private:
+ Operation *_node;
+ std::size_t _index;
+ };
/// @brief Represents an output of a node.
class Output
Operation *getNode() { return _node; }
const Operation *getNode() const { return _node; }
- /// @brief Returns the index of this output among all the ouptputs of the node.
+ /// @brief Returns the index of this output among all the outputs of the node.
std::size_t getIndex() const { return _index; }
/// @brief Returns the inputs that consume this output.
- const std::unordered_set<Input *> &getConsumers() const { return _consumers; }
+ const std::vector<Use> &getUses() const { return _uses; }
- /// @brief Adds the specified input to the consumers of this output.
- void addConsumer(Input *consumer) { _consumers.emplace(consumer); }
+ // TODO Remove (needed for transition period).
+ const std::vector<Use> &getConsumers() const { return _uses; }
- /// @brief Removes the specified input from the consumers of this output.
- void removeConsumer(Input *consumer) { _consumers.erase(consumer); }
+ /// @brief Adds the specified use to the uses of this output.
+ void addUse(Use use) { _uses.push_back(use); }
+
+ /// @brief Removes the specified use from the uses of this output.
+ void removeUse(Use use);
+
+ /// @brief Replace the defs of all uses of this output with the specified def.
+ void replaceAllUsesWith(Output *new_def);
/// @brief Gets the type of this output.
const TensorType &getType() const { return _type; }
private:
Operation *_node;
std::size_t _index;
- std::unordered_set<Input *> _consumers;
+ std::vector<Use> _uses;
TensorType _type;
std::string _name;
};
- /// @brief Represents an input of a node.
- class Input
- {
- public:
- Input(Operation *node, std::size_t index, Output *producer)
- : _node(node), _index(index), _producer(producer)
- {
- _producer->addConsumer(this);
- }
-
- ~Input() = default;
-
- Input(const Input &) = delete;
- Input(Input &&) = delete;
- Input &operator=(const Input &) = delete;
- Input &operator=(Input &&) = delete;
-
- /// @brief Returns the node this is the input of.
- Operation *getNode() { return _node; }
- const Operation *getNode() const { return _node; }
-
- /// @brief Returns the index of this output among all the inputs of the node.
- std::size_t getIndex() const { return _index; }
-
- /// @brief Returns the output that produces data for this input.
- Output *getProducer() const { return _producer; }
- operator Output *() const { return _producer; }
-
- /// @brief Replaces the output that produces data for this input with the specified one.
- void replaceProducer(Output *producer)
- {
- _producer->removeConsumer(this);
- producer->addConsumer(this);
- _producer = producer;
- }
-
- private:
- Operation *_node;
- std::size_t _index;
- Output *_producer;
- };
+ // TODO Remove (needed for transition period).
+ using Input = Use;
virtual ~Operation() = default;
std::size_t getNumInputs() const { return _inputs.size(); }
std::size_t getNumOutputs() const { return _outputs.size(); }
- std::deque<Input> &getInputs() { return _inputs; }
- const std::deque<Input> &getInputs() const { return _inputs; }
+ std::deque<Output *> &getInputs() { return _inputs; }
+ const std::deque<Output *> &getInputs() const { return _inputs; }
std::deque<Output> &getOutputs() { return _outputs; }
const std::deque<Output> &getOutputs() const { return _outputs; }
Output *getInput(std::size_t index)
{
assert(index < _inputs.size());
- return _inputs[index].getProducer();
+ return _inputs[index];
}
const Output *getInput(std::size_t index) const
{
assert(index < _inputs.size());
- return _inputs[index].getProducer();
- }
-
- // TODO Remove after replacing uses with `getInput`.
- Output *getInputProducer(std::size_t index)
- {
- assert(index < _inputs.size());
- return _inputs[index].getProducer();
- }
-
- // TODO Remove after replacing uses with `getInput`.
- const Output *getInputProducer(std::size_t index) const
- {
- assert(index < _inputs.size());
- return _inputs[index].getProducer();
+ return _inputs[index];
}
Output *getOutput(std::size_t index)
private:
Type _type;
std::size_t _id = std::numeric_limits<std::size_t>::max();
- std::deque<Input> _inputs;
+ std::deque<Output *> _inputs;
std::deque<Output> _outputs;
};
assert(op->getNumOutputs() == with->getNumOutputs());
for (std::size_t i = 0; i < op->getNumOutputs(); ++i)
{
- auto *output = op->getOutput(i);
- // The copy is intended here.
- const auto consumers =
- output->getConsumers(); // NOLINT(performance-unnecessary-copy-initialization)
- for (auto *consumer : consumers)
- {
- consumer->replaceProducer(with->getOutput(i));
- }
+ Operation::Output *output = op->getOutput(i);
+ output->replaceAllUsesWith(with->getOutput(i));
}
}
src_node->accept(visitor);
for (auto &src_output : src_node->getOutputs())
{
- for (auto *consumer : src_output.getConsumers())
+ for (const auto use : src_output.getUses())
{
- Operation *dst_node = consumer->getNode();
+ Operation *dst_node = use.getNode();
if (known_ops.count(dst_node) == 0)
{
bool all_inputs_resolved = true;
- for (auto &dst_input : dst_node->getInputs())
+ for (Operation::Output *dst_input : dst_node->getInputs())
{
- if (known_ops.count(dst_input.getProducer()->getNode()) == 0)
+ if (known_ops.count(dst_input->getNode()) == 0)
{
all_inputs_resolved = false;
}
#ifndef NDEBUG
for (const auto &output : op->getOutputs())
{
- assert(output.getConsumers().empty() && "Trying to remove a node that has uses.");
+ assert(output.getUses().empty() && "Trying to remove a node that has uses.");
}
#endif
- for (auto &input : op->getInputs())
+ for (std::size_t i = 0; i < op->getNumInputs(); ++i)
{
- input.getProducer()->removeConsumer(&input);
+ op->getInput(i)->removeUse(Operation::Use(op, i));
}
if (op->getType() == Operation::Type::input)
{
for (auto &out : start->getOutputs())
{
- for (auto *consumer : out.getConsumers())
+ for (auto use : out.getUses())
{
- Operation *end = consumer->getNode();
+ Operation *end = use.getNode();
if (p2(end))
{
matches.emplace_back(std::make_pair(start, end));
{
if (p2(root))
{
- auto &prev_nodes = root->getInputs();
- if (std::all_of(prev_nodes.begin(), prev_nodes.end(), [p1](const Operation::Input &input) {
- return p1(input.getProducer()->getNode());
- }))
+ const auto &inputs = root->getInputs();
+ if (std::all_of(inputs.begin(), inputs.end(),
+ [p1](const Operation::Output *input) { return p1(input->getNode()); }))
{
std::vector<Operation *> tops;
- tops.reserve(prev_nodes.size());
- for (auto &pr : prev_nodes)
+ tops.reserve(inputs.size());
+ for (Operation::Output *pr : inputs)
{
- tops.emplace_back(pr.getProducer()->getNode());
+ tops.emplace_back(pr->getNode());
}
matches.emplace_back(std::make_pair(tops, root));
}