[nnc] Redesign IR. Part 1. (#2998)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Fri, 8 Feb 2019 10:01:38 +0000 (13:01 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 8 Feb 2019 10:01:38 +0000 (13:01 +0300)
Introduce Input and Output classes that serve as inputs and outputs of a graph node.
Refactor other parts of the compiler to account for this change.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
26 files changed:
contrib/nnc/core/modelIR/Graph.cpp
contrib/nnc/core/modelIR/GraphPatternMatcher.cpp
contrib/nnc/core/modelIR/IrDotDumper.cpp
contrib/nnc/core/modelIR/Operation.cpp
contrib/nnc/core/modelIR/ir_dot_builder.cpp
contrib/nnc/include/core/modelIR/Graph.h
contrib/nnc/include/core/modelIR/Operation.h
contrib/nnc/include/passes/interpreter/Interpreter.h
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp
contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp
contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp
contrib/nnc/passes/caffe_frontend/caffe_importer.cpp
contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/interpreter/interpreter_pass.cpp
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h
contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
contrib/nnc/passes/soft_backend/CPPGenerator.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
contrib/nnc/unittests/core/NodeReplacer.cpp
contrib/nnc/unittests/core/operation.cpp
contrib/nnc/unittests/soft_backend/CPPOperations.cpp
contrib/nnc/unittests/soft_backend/ModelAnalyzer.cpp

index 6a0cc6c..726ef75 100644 (file)
@@ -29,27 +29,16 @@ namespace mir {
  * @param op the operation to replace
  * @param with the operation to use as a replacement
  */
-static void replaceUsages(const Operation* op, Operation* with) {
-
-  //For each output replace prev references to `node` by `with`
-  for (auto out : op->getNextNodes()) {
-    for (auto& prev : out->getMutablePrevNodes()) {
-      if (prev.op == op)
-        prev.op = with;
-    }
-  }
-
-  with->getMutableNextNodes() = op->getNextNodes();
-
-  //For each input replace next references to `node` by `with`
-  for (auto& in : op->getPrevNodes()) {
-    for (auto& next : in.op->getMutableNextNodes()) {
-      if (next == op)
-        next = with;
+static void replaceUsages(Operation* op, Operation* with) {
+  assert(op->getNumOutputs() == with->getNumOutputs());
+  for (std::size_t i = 0; i < op->getNumOutputs(); ++i) {
+    auto* output = op->getOutput(i);
+    // The copy is intended here.
+    const auto consumers = output->getConsumers();
+    for (auto* consumer : consumers) {
+      consumer->replaceProducer(with->getOutput(i));
     }
   }
-
-  with->getMutablePrevNodes() = op->getPrevNodes();
 }
 
 void Graph::accept(IVisitor* visitor) {
@@ -65,21 +54,23 @@ void Graph::accept(IVisitor* visitor) {
 
   //BFS
   while (!q.empty()) {
-    auto n = q.front();
+    Operation* src_node = q.front();
     q.pop_front();
-    n->accept(visitor);
-    for (auto out : n->getNextNodes()) {
-      if (known_ops.count(out) == 0) {
-
-        bool allInputsResolved = true;
-        for (auto in : out->getPrevNodes()) {
-          if (known_ops.count(in.op) == 0) {
-            allInputsResolved = false;
+    src_node->accept(visitor);
+    for (const auto& src_output : src_node->getOutputs()) {
+      for (const auto* consumer : src_output.getConsumers()) {
+        Operation* dst_node = consumer->getNode();
+        if (known_ops.count(dst_node) == 0) {
+          bool allInputsResolved = true;
+          for (const auto& dst_input : dst_node->getInputs()) {
+            if (known_ops.count(dst_input.getProducer()->getNode()) == 0) {
+              allInputsResolved = false;
+            }
+          }
+          if (allInputsResolved) {
+            known_ops.insert(dst_node);
+            q.push_back(dst_node);
           }
-        }
-        if (allInputsResolved) {
-          known_ops.insert(out);
-          q.push_back(out);
         }
       }
     }
@@ -104,34 +95,16 @@ void Graph::registerOp(Operation* op) {
 
 void Graph::replaceNode(Operation* op, Operation* with) {
   replaceUsages(op, with);
-
-  _inputs.erase(std::remove_if(_inputs.begin(), _inputs.end(), [op](ops::InputOp* n) {
-    return n == op;
-  }), _inputs.end());
-
-  _outputs.erase(std::remove_if(_outputs.begin(), _outputs.end(), [op](ops::OutputOp* n) {
-    return n == op;
-  }), _outputs.end());
-
-  _ops.erase(op);
-
+  removeNode(op);
 }
 
 ops::InputOp* Graph::replaceWithInputNode(Operation* op) {
-  assert(op->getNumOutputs() <= 1
+  assert(op->getNumOutputs() == 1
          && "Only operations with single output value can be replaced with input node");
-  assert(op->getNextNodes().size() <= 1
-         && "Node with multiple outputs cannot be changed into input");
 
   auto in = create<ops::InputOp>(op->getName(), op->getOutputShape(0));
   replaceNode(op, in);
 
-  //replaceNode adds all connections of original node,
-  //but for input node we don't need input connections
-  in->getMutablePrevNodes().clear();
-
-  delete op;
-
   return dynamic_cast<ops::InputOp*>(in);
 }
 
@@ -146,16 +119,28 @@ void Graph::replaceInputNodes(const std::vector<std::string>& new_inputs) {
     }
   }
 
-  _inputs.clear();
-
   for (auto& op : ops_to_replace) {
     replaceWithInputNode(op);
   }
 }
 
 void Graph::removeNode(Operation* op) {
-  op->removeFromPrev();
-  op->removeFromNext();
+#ifndef NDEBUG
+  for (const auto& output : op->getOutputs()) {
+    assert(output.getConsumers().empty() && "Trying to remove a node that has uses.");
+  }
+#endif
+
+  for (auto& input : op->getInputs()) {
+    input.getProducer()->removeConsumer(&input);
+  }
+
+  if (op->getType() == Operation::Type::input)
+    _inputs.erase(std::remove(_inputs.begin(), _inputs.end(), op));
+
+  if (op->getType() == Operation::Type::output)
+    _outputs.erase(std::remove(_outputs.begin(), _outputs.end(), op));
+
   _ops.erase(op);
   delete op;
 }
index 94ef068..61d1c65 100644 (file)
@@ -29,11 +29,13 @@ std::vector<std::pair<Operation*, Operation*>> GraphPatternMatcher::matchEdge(
   std::vector<std::pair<Operation*, Operation*>> matches;
   for (auto* start: _g->getNodes()) {
     if (p1(start)) {
-      const auto& next_nodes = start->getNextNodes();
-      for (auto* end: next_nodes) {
-        if (p2(end)) {
-          matches.emplace_back(std::make_pair(start, end));
-          break;
+      for (const auto& out : start->getOutputs()) {
+        for (const auto* consumer : out.getConsumers()) {
+          Operation* end = consumer->getNode();
+          if (p2(end)) {
+            matches.emplace_back(std::make_pair(start, end));
+            break;
+          }
         }
       }
     }
index 96d704e..826cd0f 100644 (file)
@@ -49,6 +49,7 @@
 #include "core/modelIR/operations/TransposeOp.h"
 
 #include <iostream>
+#include <map>
 
 namespace nnc {
 namespace mir {
index c1da0b6..fcbb420 100644 (file)
 namespace nnc {
 namespace mir {
 
-Operation::Operation(Type type, const std::vector<IODescriptor>& args)
-  : _type(type), _num_inputs(args.size()), _num_outputs(1) {
-  _inputs.resize(_num_inputs);
-  for (std::size_t i = 0; i < _num_inputs; ++i) {
-    args[i].op->_outputs.push_back(this);
-    _inputs[i] = args[i];
+Operation::Operation(Type type, const std::vector<IODescriptor>& inputs, std::size_t num_outputs)
+    : _type(type) {
+  for (std::size_t i = 0; i < inputs.size(); ++i) {
+    _inputs.emplace_back(this, i, inputs[i]);
+  }
+  for (std::size_t i = 0; i < num_outputs; ++i) {
+    _outputs.emplace_back(this, i);
   }
-}
-
-const IODescriptor Operation::getOutput(std::size_t index) {
-  return IODescriptor{.op = this, .index = index};
-}
-
-const Shape& Operation::getInputShape(std::size_t index) const {
-  return getInput(index).getShape();
-}
-
-const Shape& Operation::getOutputShape(std::size_t index) const {
-  assert(index < getNumOutputs());
-  return _outputShapes.at(index);
-}
-
-void Operation::setOutputShape(std::size_t index, const Shape& shape) {
-  assert(index < getNumOutputs());
-  _outputShapes[index] = shape;
-}
-
-void Operation::setInput(const IODescriptor& descr, size_t i) {
-  descr.op->_outputs.emplace_back(this);
-  _inputs[i] = descr;
 }
 
 void Operation::accept(IVisitor* v) {
@@ -95,24 +73,5 @@ void Operation::accept(IVisitor* v) {
   }
 }
 
-void Operation::removeFromPrev() {
-  for (const auto& prev : _inputs) {
-    auto& mutable_next = prev.op->_outputs;
-    mutable_next.erase(std::find(std::begin(mutable_next), std::end(mutable_next), this));
-  }
-  _inputs.clear();
-}
-
-void Operation::removeFromNext() {
-  for (auto* next : _outputs) {
-    auto& mutable_prev = next->_inputs;
-    mutable_prev.erase(
-      std::remove_if(mutable_prev.begin(), mutable_prev.end(), [this](IODescriptor n) {
-        return n.op == this;
-      }), mutable_prev.end());
-  }
-  _outputs.clear();
-}
-
 } // namespace mir
 } // namespace nnc
index 2337527..7d44607 100644 (file)
@@ -24,9 +24,9 @@ namespace mir
 void IrDotBuilder::updateWithOp(Operation* op, const DotIrNodeInfo& irNodeInfo)
 {
   addNode(op, irNodeInfo);
-  for (auto &prev : op->getPrevNodes())
+  for (auto &prev : op->getInputs())
   {
-    addEdge(prev.op, op);
+    addEdge(prev.getProducer()->getNode(), op);
   }
 }
 
index 87e5f77..6f9ed18 100644 (file)
@@ -103,6 +103,7 @@ private:
 
   std::unordered_set<Operation*> _ops;
   size_t _lastNodeId = 0;
+  // TODO Change these to unordered_sets.
   std::vector<ops::InputOp*> _inputs;
   std::vector<ops::OutputOp*> _outputs;
 };
index 2c653ad..d0f5ec8 100644 (file)
 #ifndef _NNC_CORE_IR_MODEL_OPERATION_H_
 #define _NNC_CORE_IR_MODEL_OPERATION_H_
 
-#include <string>
-#include <map>
-#include "TensorVariant.h"
+#include "core/modelIR/Shape.h"
 #include "core/modelIR/Visitor.h"
-
-#include "Shape.h"
+#include <deque>
+#include <string>
+#include <unordered_set>
 
 namespace nnc {
 namespace mir {
 
-class Operation;
-
-struct IODescriptor {
-  Operation* op;
-  std::size_t index;
-  const Shape& getShape() const;
-};
-
 class Operation {
 public:
   enum class Type {
@@ -43,6 +34,83 @@ public:
 #undef HANDLE_OP
   };
 
+  class Input;
+
+  /// @brief Represents an output of a node.
+  class Output {
+  public:
+    Output(Operation* node, std::size_t index) : _node(node), _index(index) {}
+
+    ~Output() = default;
+
+    Output(const Output&) = delete;
+    Output(Output&&) = delete;
+    Output& operator=(const Output&) = delete;
+    Output& operator=(Output&&) = delete;
+
+    /// @brief Returns the node this is an output of.
+    Operation* getNode() const { return _node; }
+
+    /// @brief Returns the index of this output among all the ouptputs of the node.
+    std::size_t getIndex() const { return _index; }
+
+    /// @brief Returns the inputs that consume this output.
+    const std::unordered_set<Input*>& getConsumers() const { return _consumers; }
+
+    /// @brief Adds the specified input to the consumers of this output.
+    void addConsumer(Input* consumer) { _consumers.emplace(consumer); }
+
+    /// @brief Removes the specified input from the consumers of this output.
+    void removeConsumer(Input* consumer) { _consumers.erase(consumer); }
+
+    const Shape& getShape() const { return _shape; }
+    void setShape(const Shape& shape) { _shape = shape; }
+
+  private:
+    Operation* _node;
+    std::size_t _index;
+    std::unordered_set<Input*> _consumers;
+    Shape _shape;
+  };
+
+/// @brief Represents an input of a node.
+  class Input {
+  public:
+    Input(Operation* node, std::size_t index, Output* producer)
+        : _node(node), _index(index), _producer(producer) {
+      _producer->addConsumer(this);
+    }
+
+    ~Input() = default;
+
+    Input(const Input&) = delete;
+    Input(Input&&) = delete;
+    Input& operator=(const Input&) = delete;
+    Input& operator=(Input&&) = delete;
+
+    /// @brief Returns the node this is the input of.
+    Operation* getNode() const { return _node; }
+
+    /// @brief Returns the index of this output among all the inputs of the node.
+    std::size_t getIndex() const { return _index; }
+
+    /// @brief Returns the output that produces data for this input.
+    Output* getProducer() const { return _producer; }
+
+    /// @brief Replaces the output that produces data for this input with the specified one.
+    void replaceProducer(Output* producer) {
+      _producer->removeConsumer(this);
+      producer->addConsumer(this);
+      _producer = producer;
+    }
+
+  private:
+    Operation* _node;
+    std::size_t _index;
+    Output* _producer;
+  };
+
+
   virtual ~Operation() = default;
 
   Type getType() const { return _type; }
@@ -53,58 +121,62 @@ public:
   const std::string& getName() const { return _name; }
   void setName(const std::string& name) { _name = name; }
 
-  std::size_t getNumInputs() const { return _num_inputs; }
-  std::size_t getNumOutputs() const { return _num_outputs; }
+  std::size_t getNumInputs() const { return _inputs.size(); }
+  std::size_t getNumOutputs() const { return _outputs.size(); }
 
-  IODescriptor getInput(std::size_t index) const {
-    assert(index < _inputs.size());
-    return _inputs[index];
-  }
+  std::deque<Input>& getInputs() { return _inputs; }
+  const std::deque<Input>& getInputs() const { return _inputs; }
 
-  const IODescriptor getOutput(std::size_t index);
+  std::deque<Output>& getOutputs() { return _outputs; }
+  const std::deque<Output>& getOutputs() const { return _outputs; }
 
-  const std::vector<IODescriptor>& getPrevNodes() const { return _inputs; }
-  const std::vector<Operation*>& getNextNodes() const { return _outputs; }
+  Input* getInput(std::size_t index) {
+    assert(index < _inputs.size());
+    return &_inputs[index];
+  }
 
-  std::vector<IODescriptor>& getMutablePrevNodes() { return _inputs; }
-  std::vector<Operation*>& getMutableNextNodes() { return _outputs; }
+  const Input* getInput(std::size_t index) const {
+    assert(index < _inputs.size());
+    return &_inputs[index];
+  }
 
-  const nnc::mir::Shape& getInputShape(std::size_t index) const;
-  const nnc::mir::Shape& getOutputShape(std::size_t index) const;
+  Output* getOutput(std::size_t index) {
+    assert(index < _outputs.size());
+    return &_outputs[index];
+  }
 
-  /// @brief Removes links to this node from it's parents
-  void removeFromPrev();
+  const Output* getOutput(std::size_t index) const {
+    assert(index < _outputs.size());
+    return &_outputs[index];
+  }
 
-  /// @brief Removes links to this node from it's children
-  void removeFromNext();
+  const Shape& getInputShape(std::size_t index) const {
+    return getInput(index)->getProducer()->getShape();
+  }
 
-  /**
-   * @brief Set `descr` as `i`-th input of this node
-   * @param descr the tensor to be set as input
-   * @param i input index
-   */
-  void setInput(const IODescriptor& descr, size_t i);
+  const Shape& getOutputShape(std::size_t index) const {
+    return getOutput(index)->getShape();
+  }
 
   void accept(IVisitor* v);
 
 protected:
-  Operation(Type type, const std::vector<IODescriptor>& args);
-  void setOutputShape(std::size_t index, const nnc::mir::Shape& shape);
+  Operation(Type type, const std::vector<Output*>& inputs, std::size_t num_outputs = 1);
+
+  void setOutputShape(std::size_t index, const Shape& shape) {
+    getOutput(index)->setShape(shape);
+  }
 
 private:
   Type _type;
   std::size_t _id;
   std::string _name;
-  std::size_t _num_inputs;
-  std::size_t _num_outputs;
-  std::vector<IODescriptor> _inputs;
-  std::vector<Operation*> _outputs;
-  std::map<size_t, nnc::mir::Shape> _outputShapes;
+  std::deque<Input> _inputs;
+  std::deque<Output> _outputs;
 };
 
-inline const Shape& IODescriptor::getShape() const {
-  return op->getOutputShape(index);
-}
+// Convenient type alias for the duration of the transition process.
+using IODescriptor = Operation::Output*;
 
 } // namespace mir
 } // namespace nnc
index f70d56b..6c93678 100644 (file)
@@ -82,7 +82,7 @@ private:
   std::unordered_map<std::string, TensorVariant> _inputTensors;
 
   /// @brief Mapping of operations to their computed results.
-  std::unordered_map<std::size_t, std::vector<TensorVariant>> _opResults;
+  std::unordered_map<const Operation*, std::vector<TensorVariant>> _opResults;
 };
 
 } // namespace mir
index 7f5ec8c..61de4bb 100644 (file)
@@ -105,7 +105,7 @@ const ArtifactModule& AclCppOpGenerator::generate(mir::Graph* g) {
 }
 
 void AclCppOpGenerator::visit(ops::ConcatOp& op) {
-  const auto& ir_inputs = op.getPrevNodes();
+  const auto& ir_inputs = op.getInputs();
   IODescriptor ir_output = op.getOutput(0);
 
   static const char* axis_names[] = {"arm_compute::DataLayoutDimension::BATCHES",
@@ -123,8 +123,8 @@ void AclCppOpGenerator::visit(ops::ConcatOp& op) {
   auto inputs_var = _constrBlock->var("std::vector<arm_compute::ICLTensor*>", prefix + "_inputs");
   auto inputs = inputs_var->use();
 
-  for (IODescriptor ir_input : ir_inputs)
-    _constrBlock->call("push_back", {AF::ref(AF::id(tensorName(ir_input)))}, inputs);
+  for (const auto& ir_input : ir_inputs)
+    _constrBlock->call("push_back", {AF::ref(AF::id(tensorName(ir_input.getProducer())))}, inputs);
 
   auto layer = genLayer("arm_compute::CLConcatenateLayer", prefix,
                         {inputs, AF::ref(out), AF::lit(axis_name)});
@@ -143,12 +143,12 @@ void AclCppOpGenerator::visit(ops::DepthwiseConv2DOp& op) {
 
 void AclCppOpGenerator::visit(ops::SoftmaxOp& op) {
   assert(op.getNumInputs() == 1);
-  IODescriptor ir_input = op.getInput(0);
+  IODescriptor ir_input = op.getInput(0)->getProducer();
   IODescriptor ir_output = op.getOutput(0);
 
   auto in = AF::id(tensorName(ir_input));
 
-  int rank = ir_output.getShape().rank();
+  int rank = ir_output->getShape().rank();
   // CLPermute does not support all kinds of permutations now.
   // rank can be more than 2 in our models, so we can not use CLTranspose.
   // This means we can support tensors with no more then one axis > 1.
@@ -157,7 +157,7 @@ void AclCppOpGenerator::visit(ops::SoftmaxOp& op) {
   int nof_long_axes = 0;
 
   for (int i = 0; i < rank; ++i) {
-    if (ir_output.getShape().dim(i) > 1)
+    if (ir_output->getShape().dim(i) > 1)
       ++nof_long_axes;
   }
 
@@ -183,7 +183,7 @@ void AclCppOpGenerator::visit(ops::SoftmaxOp& op) {
     // Then we need two tensors for intermediate results. This is because we do a couple of auxiliary
     // reshapes: one to transform the input tensor to a unidimensional tensor and the second to
     // transorm the result of the softmax operation back to the original form.
-    Shape sm_shape(ir_output.getShape());
+    Shape sm_shape(ir_output->getShape());
 
     std::swap(sm_shape.dim(axis), sm_shape.dim(-1));
 
@@ -262,7 +262,7 @@ AclCppOpGenerator::genTransposeACLtoMIR(const string& name,
 
 void AclCppOpGenerator::visit(ops::PoolOp& op) {
   assert(op.getNumInputs() == 1);
-  IODescriptor ir_input = op.getInput(0);
+  IODescriptor ir_input = op.getInput(0)->getProducer();
   IODescriptor ir_output = op.getOutput(0);
 
   const char* pooling_type = nullptr;
@@ -286,7 +286,7 @@ void AclCppOpGenerator::visit(ops::PoolOp& op) {
   // Transpose data from MIR format to format compatible with ACL
   const string transposed_input_name = output_tensor_name + "transposed_input";
   shared_ptr<ArtifactId> transposed_input =
-      genTransposeMIRtoACL(transposed_input_name, ir_input.getShape(), in_id);
+      genTransposeMIRtoACL(transposed_input_name, ir_input->getShape(), in_id);
 
   const string layer_name = output_tensor_name + "_pooling_layer";
 
@@ -310,7 +310,7 @@ void AclCppOpGenerator::visit(ops::PoolOp& op) {
   shared_ptr<ArtifactId> pooling_info = pooling_info_var->use();
 
   // Generate auxiliary tensor to hold transposed output of pool in NCHW format
-  Shape transposed_output_shape = transposeShape<0, 3, 1, 2>(ir_output.getShape());
+  Shape transposed_output_shape = transposeShape<0, 3, 1, 2>(ir_output->getShape());
   shared_ptr<ArtifactId> transposed_output =
       genTensor(layer_name + "_out_transpose", transposed_output_shape);
 
@@ -329,11 +329,11 @@ void AclCppOpGenerator::visit(ops::PoolOp& op) {
 
 void AclCppOpGenerator::visit(ops::FullyConnectedOp& op) {
   assert(op.getNumInputs() == 2);
-  IODescriptor ir_input = op.getInput(0);
-  IODescriptor ir_weights = op.getInput(1);
+  IODescriptor ir_input = op.getInput(0)->getProducer();
+  IODescriptor ir_weights = op.getInput(1)->getProducer();
   IODescriptor ir_output = op.getOutput(0);
 
-  auto ir_weights_op = dynamic_cast<mir::ops::ConstantOp*>(ir_weights.op);
+  auto ir_weights_op = dynamic_cast<mir::ops::ConstantOp*>(ir_weights->getNode());
   if (ir_weights_op == nullptr)
     throw AclCppException("Unsupported operation type");
 
@@ -344,7 +344,7 @@ void AclCppOpGenerator::visit(ops::FullyConnectedOp& op) {
   auto in = AF::id(tensorName(ir_input));
 
   // Create the output tensor in the DOM.
-  if (ir_output.getShape().rank() != 2)
+  if (ir_output->getShape().rank() != 2)
     throw AclCppException("Unsupported number of dimensions in fc layer");
   auto out = genTensor(ir_output);
   string operation_name = out->name() + "_fully_connected_layer";
@@ -373,11 +373,11 @@ void AclCppOpGenerator::visit(ops::CappedReluOp& op) {
 
 void AclCppOpGenerator::visit(ops::BiasAddOp& op) {
   assert(op.getNumInputs() == 2);
-  IODescriptor ir_input = op.getInput(0);
-  IODescriptor ir_weights = op.getInput(1);
+  IODescriptor ir_input = op.getInput(0)->getProducer();
+  IODescriptor ir_weights = op.getInput(1)->getProducer();
   IODescriptor ir_output = op.getOutput(0);
 
-  auto ir_weights_op = dynamic_cast<ops::ConstantOp*>(ir_weights.op);
+  auto ir_weights_op = dynamic_cast<ops::ConstantOp*>(ir_weights->getNode());
   if (ir_weights_op == nullptr)
     throw AclCppException("Unsupported operation type");
 
@@ -394,7 +394,7 @@ void AclCppOpGenerator::visit(ops::BiasAddOp& op) {
   shared_ptr<ArtifactId> transposed_output;
 
   // Create the output tensor in the DOM and obtain its identifier.
-  const Shape& out_shape = ir_output.getShape();
+  const Shape& out_shape = ir_output->getShape();
   const string transposed_output_name = output_tensor_name + "_transposed_output";
 
   switch (out_shape.rank()) {
@@ -402,7 +402,7 @@ void AclCppOpGenerator::visit(ops::BiasAddOp& op) {
       // transpose input to NCHW format supported by ACL
       const string transposed_input_name = output_tensor_name + "_transposed_input";
       transposed_output_shape = transposeShape<0, 3, 1, 2>(out_shape);
-      transposed_input = genTransposeMIRtoACL(transposed_input_name, ir_input.getShape(), input);
+      transposed_input = genTransposeMIRtoACL(transposed_input_name, ir_input->getShape(), input);
 
       transposed_output =
           genTensor(transposed_output_name, transposed_output_shape);
@@ -422,7 +422,7 @@ void AclCppOpGenerator::visit(ops::BiasAddOp& op) {
   string layer_name = transposed_output->name() + "_bias_add_layer";
 
   // Reshape the IR biases tensor and generate the corresponding DOM tensor.
-  const auto& ir_input_shape = ir_input.getShape();
+  const auto& ir_input_shape = ir_input->getShape();
   Shape ir_biases_shape(ir_input_shape.rank());
 
   // ACL CLArithmeticAddition supports input tensors broadcasting.
@@ -463,26 +463,22 @@ void AclCppOpGenerator::visit(ops::InputOp& op) {
 static bool shouldSerializeConstant(ops::ConstantOp& op) {
   // Operations from 'self_serializing_ops_to_inputs' serializing tensors with appropriate index themselves,
   // so we don't serialize them here, also we don't serialize tensors from dangling ConstantOp
-  static std::map<Operation::Type, int> self_serializing_ops_to_inputs{
+  static std::map<Operation::Type, std::size_t> self_serializing_ops_to_inputs{
           {Operation::Type::scale, 1},
           {Operation::Type::conv2D, 1},
           {Operation::Type::fullyConnected, 1},
           {Operation::Type::biasAdd, 1}};
 
-  for (auto& next_node : op.getNextNodes()) {
-    auto self_serializing_op_it = self_serializing_ops_to_inputs.find(next_node->getType());
+  for (const auto* consumer : op.getOutput(0)->getConsumers()) {
+    auto self_serializing_op_it = self_serializing_ops_to_inputs.find(consumer->getNode()->getType());
     // Serialize if next_node type not from 'self_serializing_ops_to_inputs'
     if (self_serializing_op_it == self_serializing_ops_to_inputs.end())
       return true;
 
     // If next_node has current ConstantOp as it's previous node, but not with appropriate index -
     // serialize current ConstantOp
-    int serializing_input_index = self_serializing_op_it->second;
-    auto next_node_prev_nodes = static_cast<int>(next_node->getPrevNodes().size());
-    for (int i = 0; i < next_node_prev_nodes; ++i) {
-      if (next_node->getPrevNodes()[i].op == &op && i != serializing_input_index)
-        return true;
-    }
+    if (self_serializing_op_it->second != consumer->getIndex())
+      return true;
   }
 
   return false;
@@ -503,14 +499,14 @@ void AclCppOpGenerator::visit(ops::ReluOp& op) {
 
 void AclCppOpGenerator::visit(ops::ReshapeOp& op) {
   assert(op.getNumInputs() == 1);
-  IODescriptor ir_input = op.getInput(0);
+  IODescriptor ir_input = op.getInput(0)->getProducer();
   IODescriptor ir_output = op.getOutput(0);
 
   // Get the id of the input tensor in the generated artifact.
   auto in = AF::id(tensorName(ir_input));
 
   // Create the output tensor in the DOM and return its id.
-  const Shape& out_shape = ir_output.getShape();
+  const Shape& out_shape = ir_output->getShape();
 
   // This check confirms that we can "safely" reshape data
   // The only safe configuration of output shape is (1...1, N, 1 ... 1)
@@ -536,11 +532,11 @@ void AclCppOpGenerator::visit(ops::ScaleOp& op) {
   // May be not a perfect implementation, using the CLPixelWiseMultiplication ACL function taking
   // two input tensors with the same shapes.
   assert(op.getNumInputs() == 2);
-  IODescriptor ir_input = op.getInput(0);
-  IODescriptor ir_weights = op.getInput(1);
+  IODescriptor ir_input = op.getInput(0)->getProducer();
+  IODescriptor ir_weights = op.getInput(1)->getProducer();
   IODescriptor ir_output = op.getOutput(0);
 
-  auto ir_weights_op = dynamic_cast<ops::ConstantOp*>(ir_weights.op);
+  auto ir_weights_op = dynamic_cast<ops::ConstantOp*>(ir_weights->getNode());
   if (ir_weights_op == nullptr)
     throw AclCppException("Unsupported operation type");
 
@@ -555,10 +551,10 @@ void AclCppOpGenerator::visit(ops::ScaleOp& op) {
   // transpose input to NCHW format supported by ACL
   const string transposed_input_name = output_tensor_name + "_transposed_input";
   shared_ptr<ArtifactId> transposed_input =
-      genTransposeMIRtoACL(transposed_input_name, ir_input.getShape(), input);
+      genTransposeMIRtoACL(transposed_input_name, ir_input->getShape(), input);
 
   // Create the output tensor in the DOM and obtain its identifier.
-  const Shape& out_shape = ir_output.getShape();
+  const Shape& out_shape = ir_output->getShape();
   Shape transposed_output_shape;
   switch (out_shape.rank()) {
     case 4:
@@ -579,7 +575,7 @@ void AclCppOpGenerator::visit(ops::ScaleOp& op) {
   auto operation_name = transposed_output->name() + "_scale_layer";
 
   // Reshape the IR scales tensor and generate the corresponding DOM tensor.
-  const Shape ir_input_shape = transposeShape<0, 3, 1, 2>(ir_input.getShape());
+  const Shape ir_input_shape = transposeShape<0, 3, 1, 2>(ir_input->getShape());
   Shape ir_scales_shape(ir_input_shape.rank());
 
   // ACL CLArithmeticDivision supports input tensors broadcasting.
@@ -638,7 +634,7 @@ void AclCppOpGenerator::visit(ops::BatchNormOp&) {
 
 void AclCppOpGenerator::visit(ops::DropoutOp& op) {
   assert(op.getNumInputs() == 1);
-  IODescriptor ir_input = op.getInput(0);
+  IODescriptor ir_input = op.getInput(0)->getProducer();
   IODescriptor ir_output = op.getOutput(0);
 
   // Just copy input tensor to the output one.
@@ -658,7 +654,7 @@ void AclCppOpGenerator::visit(ops::TanhOp& op) {
 
 void AclCppOpGenerator::visit(ops::ElementwiseOp& op) {
   assert(op.getNumInputs() >= 2);
-  const auto& ir_inputs = op.getPrevNodes();
+  const auto& ir_inputs = op.getInputs();
   IODescriptor ir_output = op.getOutput(0);
 
   // Create the output tensor in the DOM and obtain its identifier.
@@ -666,10 +662,10 @@ void AclCppOpGenerator::visit(ops::ElementwiseOp& op) {
   addToPersistentTensors(out);
 
   // Get the identifier of the first input tensor in the DOM.
-  auto in1 = AF::id(tensorName(ir_inputs[0]));
+  auto in1 = AF::id(tensorName(ir_inputs[0].getProducer()));
 
   for (size_t i = 1; i < ir_inputs.size(); ++i) {
-    IODescriptor ir_input = ir_inputs[i];
+    IODescriptor ir_input = ir_inputs[i].getProducer();
 
     // Get the identifier of the second input tensor in the DOM.
     auto in2 = AF::id(tensorName(ir_input));
@@ -679,11 +675,11 @@ void AclCppOpGenerator::visit(ops::ElementwiseOp& op) {
     // Different ACL layers used to implement different types of elementwise operations.
     switch (op.getOpType()) {
       case ops::ElementwiseOp::OpType::mul:
-        in1 = genMultiplication(out->name() + "_" + "multiplication", i - 1, ir_input.getShape(),
+        in1 = genMultiplication(out->name() + "_" + "multiplication", i - 1, ir_input->getShape(),
                                 in1, in2, i == ir_inputs.size() - 1 ? out : nullptr);
         break;
       case ops::ElementwiseOp::OpType::add:
-        in1 = genAddition(out->name() + "_" + "addition", i - 1, ir_input.getShape(),
+        in1 = genAddition(out->name() + "_" + "addition", i - 1, ir_input->getShape(),
                           in1, in2, i == ir_inputs.size() - 1 ? out : nullptr);
         break;
       default:
@@ -702,7 +698,7 @@ void AclCppOpGenerator::visit(ops::EluOp&) {
 
 void AclCppOpGenerator::visit(ops::PadOp& op) {
   assert(op.getNumInputs() == 1);
-  IODescriptor ir_input = op.getInput(0);
+  IODescriptor ir_input = op.getInput(0)->getProducer();
   IODescriptor ir_output = op.getOutput(0);
 
   // Get the id of the input tensor.
@@ -733,11 +729,11 @@ void AclCppOpGenerator::visit(ops::PadOp& op) {
 
 template <typename Op>
 void AclCppOpGenerator::genConvolution(Op& op, const string& acl_func_name, const string& suffix) {
-  IODescriptor ir_input = op.getInput(0);
-  IODescriptor ir_weights = op.getInput(1);
+  IODescriptor ir_input = op.getInput(0)->getProducer();
+  IODescriptor ir_weights = op.getInput(1)->getProducer();
   IODescriptor ir_output = op.getOutput(0);
 
-  auto ir_weights_op = dynamic_cast<ops::ConstantOp*>(ir_weights.op);
+  auto ir_weights_op = dynamic_cast<ops::ConstantOp*>(ir_weights->getNode());
   if (ir_weights_op == nullptr)
     throw AclCppException("Unsupported operation type");
 
@@ -759,11 +755,11 @@ void AclCppOpGenerator::genConvolution(Op& op, const string& acl_func_name, cons
 
   // Generate auxiliary tensor to hold transposed input of convolution in NCHW format
   shared_ptr<ArtifactId> transposed_input =
-      genTransposeMIRtoACL(output_tensor_name + "_transposed_input", ir_input.getShape(), input);
+      genTransposeMIRtoACL(output_tensor_name + "_transposed_input", ir_input->getShape(), input);
 
   // Create the transposed output tensor in the DOM.
   const string transposed_output_name = output_tensor_name + "_transposed_output";
-  Shape transposed_output_shape = transposeShape<0, 3, 1, 2>(ir_output.getShape());
+  Shape transposed_output_shape = transposeShape<0, 3, 1, 2>(ir_output->getShape());
   shared_ptr<ArtifactId> transposed_output =
       genTensor(transposed_output_name, transposed_output_shape);
 
@@ -810,7 +806,7 @@ void AclCppOpGenerator::genConvolution(Op& op, const string& acl_func_name, cons
 void AclCppOpGenerator::genActivation(mir::Operation& op, const std::string& activation_name,
                                       float a, float b) {
   assert(op.getNumInputs() == 1);
-  IODescriptor ir_input = op.getInput(0);
+  IODescriptor ir_input = op.getInput(0)->getProducer();
   IODescriptor ir_output = op.getOutput(0);
 
   // Get the id of the input tensor.
@@ -927,7 +923,7 @@ string AclCppOpGenerator::tensorName(IODescriptor ir_tensor) const {
   string tensor_name;
 
   // TODO Use the tensor name instead of the operation name.
-  const auto& op_name = ir_tensor.op->getName();
+  const auto& op_name = ir_tensor->getNode()->getName();
 
   if (!op_name.empty()) {
     tensor_name = "_" + op_name;
@@ -935,7 +931,7 @@ string AclCppOpGenerator::tensorName(IODescriptor ir_tensor) const {
                tensor_name.end(),
                [](char c) { return std::isalnum(c) == 0; }, '_');
   } else {
-    tensor_name = "tensor_" + to_string(ir_tensor.op->getId());
+    tensor_name = "tensor_" + to_string(ir_tensor->getNode()->getId());
   }
 
   return tensor_name;
@@ -985,23 +981,26 @@ shared_ptr<ArtifactId> AclCppOpGenerator::genTensor(const string& name,
 }
 
 shared_ptr<ArtifactId> AclCppOpGenerator::genTensor(IODescriptor ir_tensor) {
-  return genTensor(tensorName(ir_tensor), ir_tensor.getShape(), !ir_tensor.op->getName().empty());
+  return genTensor(tensorName(ir_tensor), ir_tensor->getShape(),
+                   !ir_tensor->getNode()->getName().empty());
 }
 
 void AclCppOpGenerator::genNamed(Graph* graph) {
   const auto& inputs = graph->getInputs();
   if (inputs.size() == 1) {
+    auto* input_op = inputs[0];
     auto f = _artifactClass->func(true, "arm_compute::CLTensor&", "getInput");
     auto b = f->getBlock();
-    auto id = AF::id(tensorName(inputs[0]->getOutput(0)));
+    auto id = AF::id(tensorName(input_op->getOutput(0)));
     b->ret(id);
   }
 
   const auto& outputs = graph->getOutputs();
   if (outputs.size() == 1) {
+    auto* output_op = outputs[0];
     auto f = _artifactClass->func(true, "arm_compute::CLTensor&", "getOutput");
     auto b = f->getBlock();
-    auto id = AF::id(tensorName(outputs[0]->getInput(0)));
+    auto id = AF::id(tensorName(output_op->getInput(0)->getProducer()));
     b->ret(id);
   }
 }
@@ -1133,7 +1132,7 @@ void AclCppOpGenerator::genTranspose(const std::shared_ptr<nnc::ArtifactId>& inp
 
 void AclCppOpGenerator::visit(mir::ops::TransposeOp& op) {
   assert(op.getNumInputs() == 1);
-  IODescriptor ir_input = op.getInput(0);
+  IODescriptor ir_input = op.getInput(0)->getProducer();
   IODescriptor ir_output = op.getOutput(0);
 
   // Get the input node tensor id in the DOM.
@@ -1141,7 +1140,7 @@ void AclCppOpGenerator::visit(mir::ops::TransposeOp& op) {
   const vector<size_t>& mir_axis_order = op.getAxisOrder();
 
   // Create the output tensor in the DOM.
-  if (ir_output.getShape().rank() != 4)
+  if (ir_output->getShape().rank() != 4)
     throw AclCppException("Unsupported number of dimensions in transpose operation");
   // TODO replace transpose shape
   shared_ptr<ArtifactId> output = genTensor(ir_output);
index 46a64e0..a063967 100644 (file)
@@ -227,7 +227,7 @@ void Caffe2Importer::createMIRNodesFromOp(const OperatorDef& op) {
     _blobNameToIODescriptor[op.output(i)] = outputs.at(i);
   }
 
-  _lastMIROp = outputs.at(0).op;
+  _lastMIROp = outputs.at(0)->getNode();
 }
 
 mir::TensorVariant Caffe2Importer::createTensor(const OperatorDef& op) {
index d2f249b..d119a93 100644 (file)
@@ -132,7 +132,7 @@ static Shape getWindowShape(const ::caffe2::OperatorDef& op,
 
   int kernel_h(0), kernel_w(0);
   if (is_global_pooling) {
-    const auto& input_shape = inputs[0].getShape();
+    const auto& input_shape = inputs[0]->getShape();
     assert(input_shape.rank() == 4 && "getWindowShape() inputs must be of rank 4");
     kernel_h = input_shape.dim(2);
     kernel_w = input_shape.dim(3);
@@ -237,7 +237,7 @@ Caffe2OpCreator::convertAdd(const std::vector<mir::IODescriptor>& inputs,
 
   std::vector<mir::IODescriptor> add_input;
   for (const auto& i : inputs)
-    add_input.push_back(convertCaffeToMIR(i.op->getOutput(0)));
+    add_input.push_back(convertCaffeToMIR(i));
 
   // check mir tensors contain operand
   if (mir_tensors.find(op.input(1)) != mir_tensors.end()) {
@@ -333,7 +333,7 @@ Caffe2OpCreator::convertFullyConnected(const std::vector<IODescriptor>& inputs,
                                        const MIRTensors& mir_tensors) {
   auto weights_tensor = transposeTensor<1, 0>(mir_tensors.at(op.input(1)));
 
-  const auto& input_shape = inputs[0].getShape();
+  const auto& input_shape = inputs[0]->getShape();
   // Transform input into 2-D tensor by flattening axes
   Shape shape{input_shape.dim(0), input_shape.numElements() / input_shape.dim(0)};
 
@@ -371,7 +371,7 @@ Caffe2OpCreator::convertMul(const std::vector<mir::IODescriptor>& inputs,
 
   std::vector<IODescriptor> input_descriptors;
   for (const auto& i: inputs)
-    input_descriptors.push_back(convertCaffeToMIR(i.op->getOutput(0)));
+    input_descriptors.push_back(convertCaffeToMIR(i));
 
   // TODO: replace ConstantOp on inputs
   if (mir_tensors.find(op.input(1)) != mir_tensors.end()) {
@@ -395,7 +395,7 @@ Caffe2OpCreator::convertResizeNearest(const std::vector<IODescriptor>& inputs,
                                       const ::caffe2::OperatorDef& op) {
   // assume NCHW and convert to MIR (NHWC)
   std::vector<float> scales(4);
-  assert(inputs[0].getShape().rank() == 4 && "only 4d tensors is supported");
+  assert(inputs[0]->getShape().rank() == 4 && "only 4d tensors is supported");
   scales[0] = 1;
   // default to noop
   scales[1] = getSingleArgument(op, "height_scale", 1.0f);
@@ -454,9 +454,9 @@ Caffe2OpCreator::convertSpatialBN(const std::vector<mir::IODescriptor>& inputs,
 }
 
 std::vector<IODescriptor> Caffe2OpCreator::convertSum(const std::vector<IODescriptor>& inputs) {
-  const auto& input_shape = inputs[0].getShape();
+  const auto& input_shape = inputs[0]->getShape();
   for (auto& in : inputs)
-    assert(input_shape == in.getShape() && "All Sum inputs must have same shape");
+    assert(input_shape == in->getShape() && "All Sum inputs must have same shape");
 
   auto op = createOp<ops::ElementwiseOp>("Elementwise_Add", inputs, ops::ElementwiseOp::OpType::add);
   return {op->getOutput(0)};
index cb0f4f9..a715139 100644 (file)
@@ -222,8 +222,8 @@ void CaffeImporter::setGraphOutputs() {
   //   - there is exactly one output;
   //   - the output is from the last layer.
   auto output = _blobNameToIODescriptor[last_layer.top(0)];
-  _graph->create<mir::ops::OutputOp>(output.op->getName(), output);
-  output.op->setName("");
+  _graph->create<mir::ops::OutputOp>(output->getNode()->getName(), output);
+  output->getNode()->setName("");
 }
 
 void CaffeImporter::cleanup() {
index 32bd548..3c77fe4 100644 (file)
@@ -85,7 +85,7 @@ mir::IODescriptor CaffeOpCreator::createMul(mir::IODescriptor arg1, mir::IODescr
 /// @brief Split arg into @p num_parts equal parts along @p axis axis.
 std::vector<mir::IODescriptor>
 CaffeOpCreator::createSplit(mir::IODescriptor arg, int32_t num_parts, int32_t axis) {
-  const auto& arg_shape = arg.getShape();
+  const auto& arg_shape = arg->getShape();
 
   assert(axis >= 0 && axis < arg_shape.rank());
   int32_t part_size = arg_shape.dim(axis) / num_parts;
@@ -109,8 +109,8 @@ IODescriptor
 CaffeOpCreator::createFullyConnected(const mir::IODescriptor& input,
                                      const mir::IODescriptor& weights,
                                      int32_t axis) {
-  const auto& input_shape = input.getShape();
-  const auto& weights_shape = weights.getShape();
+  const auto& input_shape = input->getShape();
+  const auto& weights_shape = weights->getShape();
 
   assert(axis >= 0 && axis < input_shape.rank());
   assert(weights_shape.rank() == 2);
@@ -405,7 +405,7 @@ CaffeOpCreator::convertPooling(const caffe::LayerParameter& layer,
   Shape strides;
   std::vector<int32_t> padding_before, padding_after;
 
-  const auto& input_shape = inputs[0].getShape();
+  const auto& input_shape = inputs[0]->getShape();
   convertPoolingParam(opts, input_shape, window_shape, strides, padding_before, padding_after);
 
   ops::PoolOp::PoolingType pool_type = getPoolingType(opts);
@@ -434,7 +434,7 @@ CaffeOpCreator::convertSoftmax(const caffe::LayerParameter& layer,
 
   // CPP and ACL backends are able to perform Softmax only along the last axis.
   // FIXME Do it in backends.
-  if (inputs[0].getShape().rank() == 4) {
+  if (inputs[0]->getShape().rank() == 4) {
     // For now, we only account for the most common case.
     if (params.axis() != 1)
       throw PassException("Softmax: unsupported axis");
@@ -729,7 +729,7 @@ CaffeOpCreator::convertLSTM(const caffe::LayerParameter& layer,
   auto cont = inputs[1];
   assert(inputs.size() == 2);
 
-  const auto& x_shape = x.getShape();
+  const auto& x_shape = x->getShape();
   const int32_t seq_length = x_shape.dim(0);
   const int32_t batch_size = x_shape.dim(1);
   const int32_t hidden_size = params.num_output();
index 16d857a..6fb78ff 100644 (file)
@@ -81,51 +81,15 @@ using namespace nnc::mir;
 std::vector<std::reference_wrapper<const TensorVariant>>
 NNInterpreter::getInputTensors(const Operation& op) {
   std::vector<std::reference_wrapper<const TensorVariant>> tensors;
-  for (IODescriptor ir_tensor : op.getPrevNodes())
-    tensors.emplace_back(_opResults.at(ir_tensor.op->getId()).at(ir_tensor.index));
+  for (const auto& input : op.getInputs()) {
+    const auto* producer = input.getProducer();
+    tensors.emplace_back(_opResults.at(producer->getNode()).at(producer->getIndex()));
+  }
   return tensors;
 }
 
 void NNInterpreter::setOutputTensors(const Operation& op, std::vector<TensorVariant>&& outputs) {
-  _opResults.emplace(op.getId(), std::move(outputs));
-}
-
-static void dumpIndex(Index ndx) {
-  for (int i = 0; i < ndx.rank(); i++) {
-    std::cout << (i ? "," : "(") << ndx.at(i);
-  }
-  std::cout << ")\t";
-}
-
-#if(0)
-#define DUMP(x, y) dump(x, (y))
-#else
-#define DUMP(x, y)
-#endif
-
-void NNInterpreter::dump(Operation& op, bool all) {
-  // TODO: in theory there could be several outputs from the given 'op'.
-  TensorVariant tensor = _opResults.at(op.getId()).at(0);
-  auto shape = tensor.getShape();
-  std::cout << "Tensor '" <<
-            (op.getNextNodes().size() ? op.getNextNodes()[0]->getName() : "output") <<
-            "' DType = " << (int)tensor.getDataType() << ", ElementSize = " <<
-            tensor.getElementSize() << ", Shape" << shape;
-  std::cout << " ElementsNumber " << shape.numElements() << "\n";
-  static bool do_it = false;
-  if (do_it || all) {
-    auto last_idx = shape.rank() - 1;
-    for (auto idx : ShapeRange(shape)) {
-      if (!(idx.at(last_idx) % 15))
-        std::cout << "\n";
-      dumpIndex(idx);
-      if (tensor.getDataType() == DTYPE::FLOAT32)
-        std::cout << *(float_t*)tensor.at(idx) << "\t";
-      else
-        std::cout << *(int32_t*)tensor.at(idx) << "\t";
-    }
-    std::cout << "\n";
-  }
+  _opResults.emplace(&op, std::move(outputs));
 }
 
 void NNInterpreter::setInput(const std::string &name, const TensorVariant& t) {
@@ -133,7 +97,7 @@ void NNInterpreter::setInput(const std::string &name, const TensorVariant& t) {
 }
 
 TensorVariant NNInterpreter::getResult(IODescriptor tensor) {
-  return _opResults.at(tensor.op->getId()).at(tensor.index);
+  return _opResults.at(tensor->getNode()).at(tensor->getIndex());
 }
 
 void NNInterpreter::visit(ops::InputOp& op) {
@@ -153,21 +117,18 @@ void NNInterpreter::visit(ops::ConcatOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = Concat<float>(inputs, op.getOutputShape(0), op.getAxis())();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::Conv2DOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = Conv2D(inputs[0], inputs[1], op)();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, true);
 }
 
 void NNInterpreter::visit(ops::ReshapeOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = Reshape<float>(inputs[0], op.getOutputShape(0))();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::ReluOp& op) {
@@ -176,7 +137,6 @@ void NNInterpreter::visit(ops::ReluOp& op) {
   auto outputs = Fill<float>(op.getOutputShape(0),
                              [&input](const Index& id) { return std::max(input.at(id), 0.0f); })();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::SigmoidOp& op) {
@@ -192,14 +152,12 @@ void NNInterpreter::visit(ops::SoftmaxOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = Softmax(op.getInputShape(0), inputs[0], op.getAxis())();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::PoolOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = Pool(inputs[0], op)();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::FullyConnectedOp& op) {
@@ -227,28 +185,24 @@ void NNInterpreter::visit(ops::DepthwiseConv2DOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = DepthwiseConv2D(inputs[0], inputs[1], op)();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, true);
 }
 
 void NNInterpreter::visit(ops::BiasAddOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = BiasAdd(inputs[0], inputs[1])();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::BatchNormOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = BatchNorm<float>(inputs[0], op)();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::ScaleOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = Scale(inputs[0], inputs[1])();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::SliceOp& op) {
@@ -267,7 +221,6 @@ void NNInterpreter::visit(ops::DropoutOp& op) {
   // TODO implement this
   auto outputs = Dropout<float>(input, op)();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::TanhOp& op) {
@@ -322,14 +275,12 @@ void NNInterpreter::visit(ops::ElementwiseOp& op) {
     return acc;
   })();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::DeConv2DOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = DeConv2D(inputs[0], inputs[1], op)();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::EluOp& op) {
@@ -342,7 +293,6 @@ void NNInterpreter::visit(ops::EluOp& op) {
       return op.getAlpha() * (expf(input.at(id)) - 1);
   })();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::SqueezeOp& op) {
@@ -350,14 +300,12 @@ void NNInterpreter::visit(ops::SqueezeOp& op) {
   // Squeeze is just a special case of reshape.
   auto outputs = Reshape<float>(inputs[0], op.getOutputShape(0))();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::PadOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = Pad(inputs[0], op)();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::SqrtOp& op) {
@@ -383,8 +331,6 @@ void NNInterpreter::visit(ops::ResizeOp& op) {
     return input.at(in_idx);
   })();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
-
 }
 
 void NNInterpreter::visit(ops::ReduceFOp& op) {
@@ -403,14 +349,12 @@ void NNInterpreter::visit(ops::ReduceFOp& op) {
     return out_t.at(id) / reduction_area;
   })();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::TransposeOp& op) {
   auto inputs = getInputTensors(op);
   auto outputs = Transpose(inputs[0], op)();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::GatherOp& op) {
@@ -428,7 +372,6 @@ void NNInterpreter::visit(ops::LeakyReluOp& op) {
     return val > 0.0f ? val : val * alpha;
   })();
   setOutputTensors(op, std::move(outputs));
-  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::OutputOp&) {
index 6205d25..5d770cc 100644 (file)
@@ -133,7 +133,7 @@ PassData InterpreterPass::run(PassData data) {
   g->accept(&interpreter);
 
   for (auto out_node : g->getOutputs()) {
-    const auto& tensor = interpreter.getResult(out_node->getInput(0));
+    const auto& tensor = interpreter.getResult(out_node->getInput(0)->getProducer());
 
 #ifdef NNC_HDF5_SUPPORTED
     writeTensorToHDF5File(tensor, out_node->getName(), cli::artifactDir);
index 6929292..cd44765 100644 (file)
@@ -166,57 +166,6 @@ void ONNXImporterImpl::createGraphInputs() {
   }
 }
 
-void ONNXImporterImpl::dump(const std::vector<mir::IODescriptor>& input_descrs,
-                            const std::vector<mir::IODescriptor>& out_descrs,
-                            const onnx::NodeProto& onnx_node) {
-  for (auto out_descr : out_descrs) {
-    auto op = out_descr.op;
-    std::cout << onnx_node.op_type() << " '" << op->getName() << "'";
-    if (input_descrs[0].op->getNumInputs() > 0) {
-      std::cout << "Input Shape: " << input_descrs[0].getShape();
-    }
-    std::cout << " Output Shape: " << op->getOutputShape(0);
-    auto* onnx_op_type = ONNXPerfectHash::getONNXOpType(onnx_node.op_type().c_str(), onnx_node.op_type().size());
-    switch (onnx_op_type->opCode) {
-      case ONNXOpCode::opConv: {
-        assert(dynamic_cast<mir::ops::TransposeOp*>(op) != nullptr);
-        if (auto* conv = dynamic_cast<mir::ops::Conv2DOp*>(op->getPrevNodes()[0].op)) {
-          std::cout << " (Conv2D)Weights" << conv->getInputShape(1) << " Strides" <<
-                    conv->getStrides() << " Padding(" << conv->getPaddingBefore()[0] <<
-                    " " << conv->getPaddingBefore()[1] << ")" << ":(" <<
-                    conv->getPaddingAfter()[0] << " " << conv->getPaddingAfter()[1] << ")";
-        } else {
-          auto* dept = dynamic_cast<mir::ops::DepthwiseConv2DOp*>(op->getPrevNodes()[0].op);
-          assert(dept);
-          std::cout << " (DepthwiseConv2D)Weights" << dept->getInputShape(1) << " Strides" <<
-                    dept->getStrides() << " Padding(" << dept->getPaddingBefore()[0] <<
-                    " " << dept->getPaddingBefore()[1] << ")" << ":(" <<
-                    dept->getPaddingAfter()[0] << " " << dept->getPaddingAfter()[1] << ")";
-        }
-        break;
-      }
-      case ONNXOpCode::opGlobalAveragePool:
-      case ONNXOpCode::opAveragePool:
-      case ONNXOpCode::opMaxPool: {
-        auto *pool = dynamic_cast<mir::ops::PoolOp *>(op);
-        if (pool == nullptr) {
-          assert(dynamic_cast<mir::ops::TransposeOp *>(op) != nullptr);
-          pool = dynamic_cast<mir::ops::PoolOp *>(op->getPrevNodes()[0].op);
-        }
-        assert(pool);
-        std::cout << " Kernel " << pool->getWindowShape() << " Strides  " << pool->getStrides();
-        std::cout << " Padding before:  " << pool->getPaddingBefore()[0] << " " <<
-                  pool->getPaddingBefore()[1];
-        std::cout << " After:  " << pool->getPaddingAfter()[0] << " " << pool->getPaddingAfter()[1];
-        break;
-      }
-      default:
-        break;
-    }
-    std::cout << "\n";
-  }
-}
-
 mir::Graph *ONNXImporterImpl::createIR() {
   GOOGLE_PROTOBUF_VERIFY_VERSION;
 
@@ -323,23 +272,20 @@ mir::Graph *ONNXImporterImpl::createIR() {
     }
     // Set outputs' names
     for (int i = 0; i < outputs.size(); i++) {
-      outputs[i].op->setName(onnx_node.output(i));
-      auto result = _tensorNameToIODescriptor.emplace(outputs[i].op->getName(), outputs[i]);
+      outputs[i]->getNode()->setName(onnx_node.output(i));
+      auto result = _tensorNameToIODescriptor.emplace(outputs[i]->getNode()->getName(), outputs[i]);
       if(!result.second)
-        throw PassException("Name duplication: " + outputs[i].op->getName());
+        throw PassException("Name duplication: " + outputs[i]->getNode()->getName());
     }
     assert (outputs.size());
     // FIXME: it should be done properly via the given graph outputs
     _graphOutputs.assign(outputs.begin(), outputs.end());
-#if 0
-    dump(inputs, outputs, onnx_node);
-#endif
   }
   // set graph outputs
   // TODO: it should be done with onnx graph outputs
   for (auto output : _graphOutputs) {
-    _graph->create<mir::ops::OutputOp>(output.op->getName(), output);
-    output.op->setName("");
+    _graph->create<mir::ops::OutputOp>(output->getNode()->getName(), output);
+    output->getNode()->setName("");
   }
 
   return _graph;
index 31f4919..0028ee5 100644 (file)
@@ -39,9 +39,7 @@ public:
 
   void import() {};
   mir::Graph *createIR() override;
-  void dump(const std::vector<mir::IODescriptor>& input_descrs,
-            const std::vector<mir::IODescriptor>& out_descrs,
-            const onnx::NodeProto& onnx_node);
+
   static mir::TensorVariant createTensor(const onnx::TensorProto* tensor);
 
   private:
index f99b01d..83735f9 100644 (file)
@@ -137,7 +137,7 @@ ONNXOpCreator::convertConv2D(const std::vector<mir::IODescriptor>& inputs,
   KernelStridesPadding cdata;
   getKernelStridesPadding(onnx_node, cdata);
   // FIXME: It can be non-constant value.
-  auto* in_weights = dynamic_cast<mir::ops::ConstantOp*>(inputs[1].op);
+  auto* in_weights = dynamic_cast<mir::ops::ConstantOp*>(inputs[1]->getNode());
   assert(in_weights && "Weights could be a constant tensor only");
   const auto& in_weights_tensor = in_weights->getValue();
   // We should transpose ONNX MC(IO)HW to HWOI
@@ -152,7 +152,7 @@ ONNXOpCreator::convertConv2D(const std::vector<mir::IODescriptor>& inputs,
   bool is_depthwise = (num_groups != 1) && (in_group_size == 1) && (out_channels == num_groups);
 
   mir::Operation* result;
-  auto transposed_input = convertONNXToMIR(inputs[0].op->getOutput(0));
+  auto transposed_input = convertONNXToMIR(inputs[0]);
   if (is_depthwise) {
     // TODO handle properly kernel with layer multiplier
     auto transposed_tensor = mir::transposeTensor<0, 1, 3, 2>(kernel_tensor);
@@ -222,7 +222,7 @@ ONNXOpCreator::convertPad(const std::vector<mir::IODescriptor>& inputs,
     vec[i] = pair;
   }
   auto result =
-    createOp<ops::PadOp>(inputs[0], inputs[0].getShape().rank(), vec, scalar);
+    createOp<ops::PadOp>(inputs[0], inputs[0]->getShape().rank(), vec, scalar);
   return {result->getOutput(0)};
 }
 
@@ -243,7 +243,7 @@ ONNXOpCreator::convertPool(const std::vector<mir::IODescriptor>& inputs,
       pool_type = ops::PoolOp::PoolingType::AVG;
       // GlobalAveragePool is equivalent to AveragePool with kernel size equal
       // to the spatial dimension of input tensor
-      cdata.kernel_shape = {t_input.getShape().dim(1), t_input.getShape().dim(2)};
+      cdata.kernel_shape = {t_input->getShape().dim(1), t_input->getShape().dim(2)};
       cdata.strides_shape = {1, 1};
       break;
     }
@@ -279,11 +279,11 @@ ONNXOpCreator::convertSoftmax(const std::vector<mir::IODescriptor>& inputs,
 std::vector<IODescriptor>
 ONNXOpCreator::convertReshape(const std::vector<mir::IODescriptor>& inputs) {
   // The original shape
-  auto in_shape = inputs[0].getShape();
+  const auto& in_shape = inputs[0]->getShape();
 
   // Input tensor describing the new shape
   // TODO: could it be not a constant?
-  auto* op = dynamic_cast<mir::ops::ConstantOp*>(inputs[1].op);
+  auto* op = dynamic_cast<mir::ops::ConstantOp*>(inputs[1]->getNode());
   assert(op && "We support constants only");
   auto shape_tensor = op->getValue();
   Shape shape_tensor_shape = (shape_tensor).getShape();
@@ -315,7 +315,7 @@ ONNXOpCreator::convertUnsqueeze(const std::vector<mir::IODescriptor>& inputs,
                                 const onnx::NodeProto& onnx_node) {
   auto* axes = findAttribute(onnx_node, "axes");
   assert(axes && axes->ints_size());
-  const Shape& input_shape = inputs[0].getShape();
+  const Shape& input_shape = inputs[0]->getShape();
   const int out_rank = input_shape.rank() + axes->ints_size();
   Shape out_shape(out_rank);
   auto ints_iterator = axes->ints().begin();
@@ -365,10 +365,10 @@ ONNXOpCreator::convertUpsample(const std::vector<mir::IODescriptor>& inputs,
 
   // relies on attributes being lifted to constants (ONNX optimization pass)
   assert(inputs.size() > 1);
-  auto* scales = dynamic_cast<mir::ops::ConstantOp*>(inputs[1].op);
+  auto* scales = dynamic_cast<mir::ops::ConstantOp*>(inputs[1]->getNode());
   assert(scales && "Weights could be a constant tensor only");
   auto scales_tensor = Tensor<float>(scales->getValue());
-  int rank = inputs[0].getShape().rank();
+  int rank = inputs[0]->getShape().rank();
   assert(scales_tensor.getShape().numElements() == rank &&
          "The number of elements of 'scales' should be the same as the rank of input 'X'"
   );
@@ -394,10 +394,10 @@ ONNXOpCreator::convertBatchNorm(const std::vector<mir::IODescriptor>& inputs,
   float epsilon = found ? value : 1e-05f;
 
   // TODO: it's better to do it via inputs
-  const auto& scale_tensor = input_tensors.at(inputs[1].op->getName());
-  const auto& bias_tensor = input_tensors.at(inputs[2].op->getName());
-  const auto& mean_tensor = input_tensors.at(inputs[3].op->getName());
-  const auto& var_tensor = input_tensors.at(inputs[4].op->getName());
+  const auto& scale_tensor = input_tensors.at(inputs[1]->getNode()->getName());
+  const auto& bias_tensor = input_tensors.at(inputs[2]->getNode()->getName());
+  const auto& mean_tensor = input_tensors.at(inputs[3]->getNode()->getName());
+  const auto& var_tensor = input_tensors.at(inputs[4]->getNode()->getName());
 
   // res1 = X - mean
   Tensor<float> bias_data(mean_tensor);
@@ -441,7 +441,7 @@ ONNXOpCreator::convertScale(const std::vector<mir::IODescriptor>& inputs,
   float value;
   std::tie(found, value) = getFloatAttribute(onnx_node, "scale");
   float scale_val = found ? value : 1.0;
-  const auto& shape = inputs[0].getShape();
+  const auto& shape = inputs[0]->getShape();
   auto scale_tensor = createTensor(scale_val, shape);
   auto scale = createOp<ops::ConstantOp>(scale_tensor)->getOutput(0);
   auto result = createOp<ops::ScaleOp>(inputs[0], scale);
@@ -450,7 +450,7 @@ ONNXOpCreator::convertScale(const std::vector<mir::IODescriptor>& inputs,
 
 std::vector<IODescriptor>
 ONNXOpCreator::convertShape(const std::vector<mir::IODescriptor>& inputs) {
-  const auto& input_shape = inputs[0].getShape();
+  const auto& input_shape = inputs[0]->getShape();
   int size = input_shape.rank();
   Shape output_shape{size};
   std::vector<float> data(static_cast<std::size_t>(size));
@@ -518,13 +518,13 @@ ONNXOpCreator::convertGemm(const std::vector<mir::IODescriptor>& inputs,
 
   // 1. Prepare input matrix A
   // Flatten the shape by dim(0)
-  const auto& in_shape = inputs[0].getShape();
+  const auto& in_shape = inputs[0]->getShape();
   mir::Shape shape0{in_shape.dim(0), in_shape.numElements() / in_shape.dim(0)};
   auto input_a = createOp<ops::ReshapeOp>(inputs[0], shape0)->getOutput(0);
   if (trans_a)
     input_a = createOp<ops::TransposeOp>(input_a, std::vector<std::size_t>{1, 0})->getOutput(0);
   if (alpha_val != 1.0) {
-    auto alpha_tensor = createTensor(alpha_val, input_a.getShape());
+    auto alpha_tensor = createTensor(alpha_val, input_a->getShape());
     auto alpha = createOp<ops::ConstantOp>(alpha_tensor)->getOutput(0);
     input_a = createOp<ops::ScaleOp>(input_a, alpha)->getOutput(0);
   }
@@ -535,21 +535,21 @@ ONNXOpCreator::convertGemm(const std::vector<mir::IODescriptor>& inputs,
   if (trans_b)
     input_b = createOp<ops::TransposeOp>(input_b, std::vector<std::size_t>{1, 0})->getOutput(0);
   // Number of cols in tensor A must be equal to number of rows in tensor B
-  assert(input_a.getShape().dim(1) == input_b.getShape().dim(0));
-  Shape mult_a_b{input_a.getShape().dim(0), input_b.getShape().dim(1)};
+  assert(input_a->getShape().dim(1) == input_b->getShape().dim(0));
+  Shape mult_a_b{input_a->getShape().dim(0), input_b->getShape().dim(1)};
 
   // 3. Prepare input matrix C
   //
   auto input_c = inputs[2];
-  auto beta_tensor = createTensor(beta_val, input_c.getShape());
-  if ((mult_a_b.rank() == 2) && (input_c.getShape().rank() == 1)) {
+  auto beta_tensor = createTensor(beta_val, input_c->getShape());
+  if ((mult_a_b.rank() == 2) && (input_c->getShape().rank() == 1)) {
     beta_tensor = TensorVariant(beta_tensor, mult_a_b);
   }
   auto beta = createOp<ops::ConstantOp>(beta_tensor)->getOutput(0);
   std::vector<IODescriptor> descriptors = {beta, input_c};
   auto c_mult = createOp<ops::ElementwiseOp>(descriptors,
                                              ops::ElementwiseOp::OpType::mul)->getOutput(0);
-  assert(c_mult.getShape() == mult_a_b);
+  assert(c_mult->getShape() == mult_a_b);
   auto result = createOp<ops::GemmOp>(input_a, input_b, c_mult);
   return {result->getOutput(0)};
 }
index ed3194c..8693de1 100644 (file)
@@ -206,7 +206,7 @@ void CPPCodeGenerator::materializeCall(ostream& out, const ModelAnalyzer& ma,
     return;
   // materialize call
   out << "  " << call->funcName << "(";
-  const auto& prev_nodes = call->mirOp->getPrevNodes();
+  const auto& prev_nodes = call->mirOp->getInputs();
   const auto& out_tensors = call->outputs;
   vector<string> args;
   args.reserve(prev_nodes.size() + out_tensors.size() + 1);
index 9cc1a6f..5d46bad 100644 (file)
@@ -94,9 +94,9 @@ void ModelAnalyzer::appendOperationToInference(
 
   // process operation inputs
   vector<size_t> node_input_tensors;
-  for (const IODescriptor& d: op->getPrevNodes()) {
-    size_t idx = d.index;
-    Operation* prev_op = d.op;
+  for (const auto& input: op->getInputs()) {
+    size_t idx = input.getProducer()->getIndex();
+    const Operation* prev_op = input.getProducer()->getNode();
     assert(_opToDescr.find(prev_op) != _opToDescr.end());
     const CallFunction* call = dynamic_cast<const CallFunction*>(_opToDescr[prev_op]);
     assert(call);
@@ -274,7 +274,13 @@ void ModelAnalyzer::analyze(const mir::Graph* g) {
       auto& top = s.top();
       Operation* node = top.first;
       auto edge = top.second++;
-      auto next_nodes = node->getNextNodes();
+      // FIXME Refactor me.
+      std::vector<Operation*> next_nodes;
+      for (const auto& out : node->getOutputs()) {
+        const auto& consumers = out.getConsumers();
+        std::transform(consumers.begin(), consumers.end(), std::back_inserter(next_nodes),
+                       [](const Operation::Input* input) { return input->getNode(); });
+      }
       if (edge == next_nodes.size()) {
         // this node is fully analyzed, push it into RPO and pop from stack
         post_order.push_back(node);
@@ -361,7 +367,7 @@ void ModelAnalyzer::visit(ops::ConstantOp& op) {
 
   // FIXME This is to work around deserializeTensors not being able to deserialize tensors of type
   // other than float32.
-  if (op.getNextNodes().empty())
+  if (op.getOutput(0)->getConsumers().empty())
     return;
 
   appendOperationToInference(&op, "constant");
index a83dbed..b2fa2cc 100644 (file)
@@ -333,8 +333,8 @@ mir::TensorVariant TfliteImporter::createTensor(const Tensor* t, const Buffer* b
 void TfliteImporter::setGraphOutputs() {
   for (auto output_idx : _graphOutputs) {
     auto output = _tensorMap[output_idx];
-    _graph->create<mir::ops::OutputOp>(output.op->getName(), output);
-    output.op->setName("");
+    _graph->create<mir::ops::OutputOp>(output->getNode()->getName(), output);
+    output->getNode()->setName("");
   }
 }
 
@@ -344,7 +344,7 @@ void TfliteImporter::setIrNodeNames() {
   // turns into IR Conv2D->BiasAdd->ReLU), so not all of the nodes will have names.
   for (auto iter : _tensorMap) {
     const Tensor* tensor = (*_tensors)[iter.first];
-    iter.second.op->setName(tensor->name()->c_str());
+    iter.second->getNode()->setName(tensor->name()->c_str());
   }
 }
 
index 184683d..5b48b3e 100644 (file)
@@ -90,7 +90,7 @@ static std::vector<VectorT> convertIntTensorToVector(const mir::Tensor<int32_t>&
 }
 
 static const mir::TensorVariant& extractTensor(mir::IODescriptor descr) {
-  auto constant_op = dynamic_cast<ops::ConstantOp*>(descr.op);
+  auto constant_op = dynamic_cast<ops::ConstantOp*>(descr->getNode());
   if (constant_op == nullptr)
     throw PassException("Non-constant input is not supported.");
   return constant_op->getValue();
@@ -114,8 +114,8 @@ TFLiteOpCreator::convertConv2D(const Conv2DOptions* opts,
   std::vector<int32_t> padding_before(2);
   std::vector<int32_t> padding_after(2);
 
-  const auto& input_shape = input.getShape();
-  const auto& kernel_shape = kernel.getShape();
+  const auto& input_shape = input->getShape();
+  const auto& kernel_shape = kernel->getShape();
   Shape window_shape{kernel_shape.dim(1), kernel_shape.dim(2)};
   calculatePadding(opts->padding(), input_shape, window_shape,
                    strides, padding_before, padding_after);
@@ -146,8 +146,8 @@ TFLiteOpCreator::convertDepthwiseConv2D(const DepthwiseConv2DOptions* opts,
   std::vector<int32_t> padding_before(2);
   std::vector<int32_t> padding_after(2);
 
-  const auto& input_shape = input.getShape();
-  const auto& kernel_shape = kernel.getShape();
+  const auto& input_shape = input->getShape();
+  const auto& kernel_shape = kernel->getShape();
   Shape window_shape{kernel_shape.dim(0), kernel_shape.dim(1)};
   calculatePadding(opts->padding(), input_shape, window_shape,
                    strides, padding_before, padding_after);
@@ -180,7 +180,7 @@ TFLiteOpCreator::convertMaxPool2D(const ::tflite::Pool2DOptions* opts,
                                   const std::vector<mir::IODescriptor>& inputs) {
   auto input = inputs.at(0);
 
-  const auto& input_shape = input.getShape();
+  const auto& input_shape = input->getShape();
   Shape window_shape{opts->filter_height(), opts->filter_width()};
   Shape strides{opts->stride_h(), opts->stride_w()};
   std::vector<int32_t> padding_before(2);
@@ -200,7 +200,7 @@ TFLiteOpCreator::convertAveragePool2D(const ::tflite::Pool2DOptions* opts,
                                       const std::vector<mir::IODescriptor>& inputs) {
   auto input = inputs.at(0);
 
-  const auto& input_shape = input.getShape();
+  const auto& input_shape = input->getShape();
   Shape window_shape{opts->filter_height(), opts->filter_width()};
   Shape strides{opts->stride_h(), opts->stride_w()};
   std::vector<int32_t> padding_before(2);
@@ -221,7 +221,7 @@ TFLiteOpCreator::convertSoftmax(const ::tflite::SoftmaxOptions* opts,
   auto input = inputs.at(0);
 
   // Softmax in TFLite is always 2-D.
-  assert(input.getShape().rank() == 2);
+  assert(input->getShape().rank() == 2);
   const int32_t axis = 1;
   auto result = createOp<ops::SoftmaxOp>(input, axis);
   return {result->getOutput(0)};
@@ -284,7 +284,7 @@ TFLiteOpCreator::convertResizeNearestNeighbor(const ::tflite::ResizeNearestNeigh
   auto input = inputs.at(0);
   mir::Tensor<int32_t> size_tensor(extractTensor(inputs.at(1)));
 
-  const auto& input_shape = input.getShape();
+  const auto& input_shape = input->getShape();
   Shape res_shape{input_shape.dim(0),
                   size_tensor.at(mir::Index{0}),
                   size_tensor.at(mir::Index{1}),
@@ -337,7 +337,7 @@ TFLiteOpCreator::convertFullyConnected(const ::tflite::FullyConnectedOptions* op
   auto bias = inputs.at(2);
 
   // Flatten input to 2-D shape.
-  const auto& input_shape = input.getShape();
+  const auto& input_shape = input->getShape();
   int32_t outer_size = input_shape.dim(0);
   int32_t inner_size = input_shape.numElements() / outer_size;
   auto flatten = createOp<ops::ReshapeOp>(input, Shape{outer_size, inner_size});
@@ -395,7 +395,7 @@ TFLiteOpCreator::convertPad(const ::tflite::PadOptions* opts,
   auto input = inputs.at(0);
   mir::Tensor<int32_t> paddings_tensor(extractTensor(inputs.at(1)));
 
-  const auto& input_shape = input.getShape();
+  const auto& input_shape = input->getShape();
   int32_t num_dims = input_shape.rank();
 
   std::vector<std::pair<int32_t, int32_t>> paddings;
@@ -488,7 +488,7 @@ TFLiteOpCreator::convertStridedSlice(const ::tflite::StridedSliceOptions* opts,
   int32_t end_mask = opts->end_mask();
   int32_t shrink_axis_mask = opts->shrink_axis_mask();
 
-  const auto& input_shape = input.getShape();
+  const auto& input_shape = input->getShape();
   int32_t num_dims = input_shape.rank();
 
   for (int32_t stride : strides) {
index 7b852cc..67c3051 100644 (file)
@@ -38,14 +38,13 @@ TEST(NodeMutatorTest, SimpleChainTest) {
   auto n5 = g->create<ops::ReluOp>("op5", n1->getOutput(0));
 
   g->replaceNode(n2, n5);
-  delete n2;
 
   std::stringstream ss;
   DumpVisitor d(ss);
   g->accept(&d);
 
   auto str = ss.str();
-  ASSERT_EQ(str, "iop1rop5rop3rop4");
+  ASSERT_TRUE(str == "iop1rop5rop3rop4" || str == "iop1rop5rop4rop3") << "str = " << str;
   delete g;
 }
 
index ad52172..9bbeb93 100644 (file)
@@ -31,7 +31,7 @@ TEST(Operation, ConnectionTest) {
   auto op2 = new ops::ReshapeOp(op1->getOutput(0), Shape{});
   op2->setId(1);
 
-  ASSERT_EQ(op1->getId(), op2->getInput(0).op->getId());
+  ASSERT_EQ(op1, op2->getInput(0)->getProducer()->getNode());
 
   delete op1;
   delete op2;
index 3557d15..daf1b0e 100644 (file)
@@ -221,7 +221,8 @@ getReferenceTensor(mir::Graph& g,
   for (int i = 0; i < static_cast<int>(input_ntensors.size()); ++i)
     interpreter.setInput("x" + to_string(i), *input_ntensors[i]);
   g.accept(&interpreter);
-  return interpreter.getResult(g.getOutputs()[0]->getInput(0));
+  const auto* output_op = g.getOutputs()[0];
+  return interpreter.getResult(output_op->getInput(0)->getProducer());
 };
 
 /**
index 7401719..1b3c36c 100644 (file)
@@ -66,9 +66,9 @@ TEST(ModelAnalyzer, linearization) {
   ASSERT_EQ(seq.size(), 6u);
   auto it = seq.begin();
   ASSERT_EQ(getCall(*(it++))->mirOp, input);
-  ASSERT_EQ(getCall(*(it++))->mirOp, head2);
-  ASSERT_EQ(getCall(*(it++))->mirOp, tail2);
   ASSERT_EQ(getCall(*(it++))->mirOp, head1);
   ASSERT_EQ(getCall(*(it++))->mirOp, tail1);
+  ASSERT_EQ(getCall(*(it++))->mirOp, head2);
+  ASSERT_EQ(getCall(*(it++))->mirOp, tail2);
   ASSERT_EQ(getCall(*(it++))->mirOp, join);
 }