From 8b93a927fe50af4c3464cd8eef7fb525e1815c05 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EA=B9=80=EC=88=98=EC=A7=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Wed, 17 Oct 2018 14:14:31 +0900 Subject: [PATCH] [neurun] Get only CPU shape as Input/Output (#3208) This commit gets only CPU shape as Input/Output(it would insert Permute ops). Signed-off-by: sjsujinkim --- runtimes/neurun/src/graph/Graph.cc | 27 +++++++++++++++++++++++++++ runtimes/neurun/src/graph/Graph.h | 1 + 2 files changed, 28 insertions(+) diff --git a/runtimes/neurun/src/graph/Graph.cc b/runtimes/neurun/src/graph/Graph.cc index 3f9da84..40a0682 100644 --- a/runtimes/neurun/src/graph/Graph.cc +++ b/runtimes/neurun/src/graph/Graph.cc @@ -69,6 +69,12 @@ operation::Index Graph::insertPermute(const operand::Index &operand_index, // Generate output operand and permute operation auto out_operand_index = addOperand(operand.shape(), operand.typeInfo()); auto &out_operand = _operands.at(out_operand_index); + // change model output if operand_index is model output index + auto &model_outputs = getOutputs(); + if (model_outputs.contains(operand_index)) + { + model_outputs.replace(operand_index, out_operand_index); + } out_operand.setAsOperationOutput(); auto out_operand_li = nnfw::make_unique(operand::asShape4D(operand.shape())); out_operand_li->addDefBackend(backend); @@ -168,6 +174,27 @@ void Graph::lower(void) } }); + // Add def backend to model input/output operand as default backend + for (uint32_t n = 0; n < getInputs().size(); ++n) + { + operand::IO::Index input_index{n}; + + operand::Index index{getInputs().at(input_index)}; + auto &&lower_info = operands_lower_info.at(index); + + lower_info->addDefBackend(_backend_resolver->getDefaultBackend()); + } + + for (uint32_t n = 0; n < getOutputs().size(); ++n) + { + operand::IO::Index output_index{n}; + + operand::Index index{getOutputs().at(output_index)}; + auto &&lower_info = operands_lower_info.at(index); + + lower_info->addUseBackend(_backend_resolver->getDefaultBackend()); + } + // Set LowerInfo for each operand from the operand::LowerInfo holder _operands.iterate([&](const operand::Index &index, operand::Object &object) { object.lower_info(std::move(operands_lower_info[index])); diff --git a/runtimes/neurun/src/graph/Graph.h b/runtimes/neurun/src/graph/Graph.h index 304b9ea..3632f7b 100644 --- a/runtimes/neurun/src/graph/Graph.h +++ b/runtimes/neurun/src/graph/Graph.h @@ -119,6 +119,7 @@ private: public: const operand::IndexSet &getInputs() const { return _inputs; } const operand::IndexSet &getOutputs() const { return _outputs; } + operand::IndexSet &getOutputs() { return _outputs; } const operand::Set &operands() const { return _operands; } operand::Set &operands() { return _operands; } // TODO Remove this non-const accessor const operation::Set &operations() const { return _operations; } -- 2.7.4