[neurun] Get only CPU shape as Input/Output (#3208)
author김수진/동작제어Lab(SR)/Engineer/삼성전자 <sjsujin.kim@samsung.com>
Wed, 17 Oct 2018 05:14:31 +0000 (14:14 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 17 Oct 2018 05:14:31 +0000 (14:14 +0900)
This commit gets only CPU shape as Input/Output(it would insert Permute ops).

Signed-off-by: sjsujinkim <sjsujin.kim@samsung.com>
runtimes/neurun/src/graph/Graph.cc
runtimes/neurun/src/graph/Graph.h

index 3f9da84..40a0682 100644 (file)
@@ -69,6 +69,12 @@ operation::Index Graph::insertPermute(const operand::Index &operand_index,
   // Generate output operand and permute operation
   auto out_operand_index = addOperand(operand.shape(), operand.typeInfo());
   auto &out_operand = _operands.at(out_operand_index);
+  // change model output if operand_index is model output index
+  auto &model_outputs = getOutputs();
+  if (model_outputs.contains(operand_index))
+  {
+    model_outputs.replace(operand_index, out_operand_index);
+  }
   out_operand.setAsOperationOutput();
   auto out_operand_li = nnfw::make_unique<operand::LowerInfo>(operand::asShape4D(operand.shape()));
   out_operand_li->addDefBackend(backend);
@@ -168,6 +174,27 @@ void Graph::lower(void)
       }
     });
 
+    // Add def backend to model input/output operand as default backend
+    for (uint32_t n = 0; n < getInputs().size(); ++n)
+    {
+      operand::IO::Index input_index{n};
+
+      operand::Index index{getInputs().at(input_index)};
+      auto &&lower_info = operands_lower_info.at(index);
+
+      lower_info->addDefBackend(_backend_resolver->getDefaultBackend());
+    }
+
+    for (uint32_t n = 0; n < getOutputs().size(); ++n)
+    {
+      operand::IO::Index output_index{n};
+
+      operand::Index index{getOutputs().at(output_index)};
+      auto &&lower_info = operands_lower_info.at(index);
+
+      lower_info->addUseBackend(_backend_resolver->getDefaultBackend());
+    }
+
     // Set LowerInfo for each operand from the operand::LowerInfo holder
     _operands.iterate([&](const operand::Index &index, operand::Object &object) {
       object.lower_info(std::move(operands_lower_info[index]));
index 304b9ea..3632f7b 100644 (file)
@@ -119,6 +119,7 @@ private:
 public:
   const operand::IndexSet &getInputs() const { return _inputs; }
   const operand::IndexSet &getOutputs() const { return _outputs; }
+  operand::IndexSet &getOutputs() { return _outputs; }
   const operand::Set &operands() const { return _operands; }
   operand::Set &operands() { return _operands; } // TODO Remove this non-const accessor
   const operation::Set &operations() const { return _operations; }