[mir_caffe2] Inherit tensor names (#6853)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Thu, 22 Aug 2019 15:11:47 +0000 (00:11 +0900)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Thu, 22 Aug 2019 15:11:47 +0000 (18:11 +0300)
Set `Operation::Output`s names based on model tensor names.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir-caffe2-importer/caffe2_importer.cpp
compiler/mir-caffe2-importer/caffe2_importer.h

index 2d95146..7432589 100644 (file)
@@ -99,7 +99,8 @@ std::unique_ptr<mir::Graph> Caffe2Importer::createIR()
   // TODO Caffe2 does not provide a way to detect model inputs. For now assume that the first input
   // of the first operation is the only input to the model.
   const auto &input_name = _predict_net->op(0).input(0);
-  _blobNameToOutput[input_name] = _opCreator->createInput(input_name, _inputShapes[0]);
+  auto input = _opCreator->createInput(input_name, _inputShapes[0]);
+  setOutputForTensor(input_name, input);
 
   for (const auto &op : _predict_net->op())
     createMIRNodesFromOp(op);
@@ -201,9 +202,7 @@ void Caffe2Importer::createMIRNodesFromOp(const OperatorDef &op)
 
   for (size_t i = 0; i < outputs.size(); ++i)
   {
-    // caffe2 input blob name could be same as output blob name, and next line will overwrite
-    // '_blobNameToOpOutput' element, but in all networks that I saw it was not a problem
-    _blobNameToOutput[op.output(i)] = outputs.at(i);
+    setOutputForTensor(op.output(i), outputs[i]);
   }
 
   // `outputs` can be empty if constant input was not processed.
@@ -219,14 +218,30 @@ std::vector<mir::Operation::Output *> Caffe2Importer::getInputMIROps(const Opera
 
   for (const auto &input_name : op.input())
   {
-    if (_blobNameToOutput.find(input_name) == _blobNameToOutput.end())
-      throw std::runtime_error("Cannot find blob \"" + input_name + "\".");
-    inputs.push_back(_blobNameToOutput[input_name]);
+    inputs.push_back(getOutputForTensor(input_name));
   }
 
   return inputs;
 }
 
+void Caffe2Importer::setOutputForTensor(const std::string &tensor_name, Operation::Output *output)
+{
+  auto it = _blobNameToOutput.find(tensor_name);
+  if (it != _blobNameToOutput.cend())
+  {
+    // caffe2 input blob name could be same as output blob name, and next line will overwrite
+    // '_blobNameToOpOutput' element, but in all networks that I saw it was not a problem
+    it->second->setName("");
+  }
+  output->setName(tensor_name);
+  _blobNameToOutput[tensor_name] = output;
+}
+
+mir::Operation::Output *Caffe2Importer::getOutputForTensor(const std::string &name) const
+{
+  return _blobNameToOutput.at(name);
+}
+
 void Caffe2Importer::setGraphOutputs()
 {
   // For now, we assume that:
index 782462d..ccff1c9 100644 (file)
@@ -73,6 +73,9 @@ private:
    */
   std::vector<mir::Operation::Output *> getInputMIROps(const ::caffe2::OperatorDef &op);
 
+  void setOutputForTensor(const std::string &tensor_name, Operation::Output *output);
+  mir::Operation::Output *getOutputForTensor(const std::string &name) const;
+
   /**
    * @brief Mark output MIR nodes
    */