From a438c08169dc730483b7d7573de41c7265061389 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=A1=D0=B5=D1=80=D0=B3=D0=B5=D0=B9=20=D0=91=D0=B0=D1=80?= =?utf8?q?=D0=B0=D0=BD=D0=BD=D0=B8=D0=BA=D0=BE=D0=B2/AI=20Tools=20Lab=20/S?= =?utf8?q?RR/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 23 Aug 2019 00:11:47 +0900 Subject: [PATCH] [mir_caffe2] Inherit tensor names (#6853) Set `Operation::Output`s names based on model tensor names. Signed-off-by: Sergei Barannikov --- compiler/mir-caffe2-importer/caffe2_importer.cpp | 29 ++++++++++++++++++------ compiler/mir-caffe2-importer/caffe2_importer.h | 3 +++ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/compiler/mir-caffe2-importer/caffe2_importer.cpp b/compiler/mir-caffe2-importer/caffe2_importer.cpp index 2d95146..7432589 100644 --- a/compiler/mir-caffe2-importer/caffe2_importer.cpp +++ b/compiler/mir-caffe2-importer/caffe2_importer.cpp @@ -99,7 +99,8 @@ std::unique_ptr Caffe2Importer::createIR() // TODO Caffe2 does not provide a way to detect model inputs. For now assume that the first input // of the first operation is the only input to the model. const auto &input_name = _predict_net->op(0).input(0); - _blobNameToOutput[input_name] = _opCreator->createInput(input_name, _inputShapes[0]); + auto input = _opCreator->createInput(input_name, _inputShapes[0]); + setOutputForTensor(input_name, input); for (const auto &op : _predict_net->op()) createMIRNodesFromOp(op); @@ -201,9 +202,7 @@ void Caffe2Importer::createMIRNodesFromOp(const OperatorDef &op) for (size_t i = 0; i < outputs.size(); ++i) { - // caffe2 input blob name could be same as output blob name, and next line will overwrite - // '_blobNameToOpOutput' element, but in all networks that I saw it was not a problem - _blobNameToOutput[op.output(i)] = outputs.at(i); + setOutputForTensor(op.output(i), outputs[i]); } // `outputs` can be empty if constant input was not processed. @@ -219,14 +218,30 @@ std::vector Caffe2Importer::getInputMIROps(const Opera for (const auto &input_name : op.input()) { - if (_blobNameToOutput.find(input_name) == _blobNameToOutput.end()) - throw std::runtime_error("Cannot find blob \"" + input_name + "\"."); - inputs.push_back(_blobNameToOutput[input_name]); + inputs.push_back(getOutputForTensor(input_name)); } return inputs; } +void Caffe2Importer::setOutputForTensor(const std::string &tensor_name, Operation::Output *output) +{ + auto it = _blobNameToOutput.find(tensor_name); + if (it != _blobNameToOutput.cend()) + { + // caffe2 input blob name could be same as output blob name, and next line will overwrite + // '_blobNameToOpOutput' element, but in all networks that I saw it was not a problem + it->second->setName(""); + } + output->setName(tensor_name); + _blobNameToOutput[tensor_name] = output; +} + +mir::Operation::Output *Caffe2Importer::getOutputForTensor(const std::string &name) const +{ + return _blobNameToOutput.at(name); +} + void Caffe2Importer::setGraphOutputs() { // For now, we assume that: diff --git a/compiler/mir-caffe2-importer/caffe2_importer.h b/compiler/mir-caffe2-importer/caffe2_importer.h index 782462d..ccff1c9 100644 --- a/compiler/mir-caffe2-importer/caffe2_importer.h +++ b/compiler/mir-caffe2-importer/caffe2_importer.h @@ -73,6 +73,9 @@ private: */ std::vector getInputMIROps(const ::caffe2::OperatorDef &op); + void setOutputForTensor(const std::string &tensor_name, Operation::Output *output); + mir::Operation::Output *getOutputForTensor(const std::string &name) const; + /** * @brief Mark output MIR nodes */ -- 2.7.4