feature: add multiple interpreter outputs (#1543)
authorVitaliy Cherepanov/AI Tools Lab /SRR/Engineer/삼성전자 <v.cherepanov@samsung.com>
Mon, 17 Sep 2018 19:15:57 +0000 (22:15 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Mon, 17 Sep 2018 19:15:57 +0000 (22:15 +0300)
--res-filename removed redundant
--output-dir used for outputs
--output-node can take multiple values

Signed-off-by: Vitaliy Cherepanov <v.cherepanov@samsung.com>
contrib/nnc/driver/Options.cpp
contrib/nnc/include/option/Options.h
contrib/nnc/passes/interpreter/interpreter_pass.cpp

index d1b16fe..5b8e70b 100644 (file)
@@ -106,14 +106,10 @@ Option<std::string> interInNode(optname("--input-node"),
                                 overview("interpreter option: set input node in Computational Graph"),
                                 std::string(),
                                 optional(true));
-Option<std::string> interOutNode(optname("--output-node"),
+Option<std::vector<std::string>> interOutNode(optname("--output-node"),
                                  overview("interpreter option: set output node in Computational Graph"),
-                                 std::string(),
+                                 std::vector<std::string>{},
                                  optional(true));
-Option<std::string> interResFileName(optname("--res-filename"),
-                                     overview("interpreter option: file for result tensor"),
-                                     std::string(),
-                                     optional(true));
 
 } // namespace clopt
 } // namespace contrib
index 6c0fc70..57afbd9 100644 (file)
@@ -40,8 +40,7 @@ extern Option<std::string> artifactName;  // name of artifact
  */
 extern Option<std::string> interInputData;  // input data for model
 extern Option<std::string> interInNode;     // name of input node in computational graph
-extern Option<std::string> interOutNode;    // name of output node in computational graph
-extern Option<std::string> interResFileName;// output file
+extern Option<std::vector<std::string>> interOutNode;    // name of output nodes in computational graph
 
 } // namespace clopt
 } // namespace contrib
index 5d8edf6..79e3c55 100644 (file)
@@ -55,7 +55,7 @@ Pass &InterpreterPass::getInstance() {
  * @param tensorName - name, by wich tensor will be saved
  * @param fileName - path to file, in which tensor will be saved
  */
-static void writeTensorToHDF5File(TensorVariant *tensor, std::string tensorName, std::string fileName)
+static void writeTensorToHDF5File(TensorVariant *tensor, std::string tensorName, std::string destPath)
 {
   // Prepare shape, rank, dims, numElems
   Shape shape = tensor->getShape();
@@ -77,6 +77,7 @@ static void writeTensorToHDF5File(TensorVariant *tensor, std::string tensorName,
 
   // Backslashes are not allowed in tensor names
   std::replace(tensorName.begin(), tensorName.end(), '/', '_');
+  std::string fileName = destPath + "/" + tensorName + ".hdf5";
 
   // Write to .hdf5 file
   H5::H5File h5File(fileName, H5F_ACC_TRUNC);
@@ -105,23 +106,37 @@ PassData InterpreterPass::run(PassData data)
     throw PassException("input node <" + clopt::interInNode +"> not found" );
   }
 
-  auto outputNode = g->getOutput(clopt::interOutNode);
-  if (outputNode == nullptr) {
-    throw PassException("output node <" + clopt::interOutNode +"> not found" );
-  }
 
   auto input = loadInput(inputNode->getOperation()->getOutputShape(0));
   interpreter.setInput(clopt::interInNode, input);
   g->accept(&interpreter);
 
-  _out = new TensorVariant(interpreter.getResult(outputNode)[0]);
+  // Check nodes
+  for (auto &tensorName : clopt::interOutNode) {
+    auto outputNode = interpreter.getOperationResult(tensorName);
+    if (outputNode.empty()) {
+      throw PassException("output node <" + tensorName + "> not found");
+    } else {
+      std::cout << "OutNode <" + tensorName + "> found" << std::endl;
+    }
+  }
+  
+  bool is_several_outs = (clopt::interOutNode.size() > 1);
+
+  nncc::contrib::core::ADT::TensorVariant *out = nullptr;
+  for (auto &tensorName : clopt::interOutNode) {
+    out = new TensorVariant(interpreter.getOperationResult(tensorName)[0]);
 
 #ifdef NNC_HDF5_SUPPORTED
-  std::string outputFile = clopt::interResFileName.empty() ? "out.hdf5" : clopt::interResFileName.getRawValue();
-  writeTensorToHDF5File(_out, clopt::interOutNode, outputFile);
+    writeTensorToHDF5File(out, tensorName, clopt::artifactDir);
 #else
-  std::cout << "Result wasn't saved, due to lack of HDF5" << std::endl;
+    std::cout << "Result <" << tensorName << "> wasn't saved, due to lack of HDF5" << std::endl;
 #endif  // NNC_HDF5_SUPPORTED
+    if ( is_several_outs )
+      delete out;
+  }
+
+  _out = is_several_outs ? nullptr : out;
 
   return _out;
 }