overview("interpreter option: set input node in Computational Graph"),
std::string(),
optional(true));
-Option<std::string> interOutNode(optname("--output-node"),
+Option<std::vector<std::string>> interOutNode(optname("--output-node"),
overview("interpreter option: set output node in Computational Graph"),
- std::string(),
+ std::vector<std::string>{},
optional(true));
-Option<std::string> interResFileName(optname("--res-filename"),
- overview("interpreter option: file for result tensor"),
- std::string(),
- optional(true));
} // namespace clopt
} // namespace contrib
* @param tensorName - name, by wich tensor will be saved
* @param fileName - path to file, in which tensor will be saved
*/
-static void writeTensorToHDF5File(TensorVariant *tensor, std::string tensorName, std::string fileName)
+static void writeTensorToHDF5File(TensorVariant *tensor, std::string tensorName, std::string destPath)
{
// Prepare shape, rank, dims, numElems
Shape shape = tensor->getShape();
// Backslashes are not allowed in tensor names
std::replace(tensorName.begin(), tensorName.end(), '/', '_');
+ std::string fileName = destPath + "/" + tensorName + ".hdf5";
// Write to .hdf5 file
H5::H5File h5File(fileName, H5F_ACC_TRUNC);
throw PassException("input node <" + clopt::interInNode +"> not found" );
}
- auto outputNode = g->getOutput(clopt::interOutNode);
- if (outputNode == nullptr) {
- throw PassException("output node <" + clopt::interOutNode +"> not found" );
- }
auto input = loadInput(inputNode->getOperation()->getOutputShape(0));
interpreter.setInput(clopt::interInNode, input);
g->accept(&interpreter);
- _out = new TensorVariant(interpreter.getResult(outputNode)[0]);
+ // Check nodes
+ for (auto &tensorName : clopt::interOutNode) {
+ auto outputNode = interpreter.getOperationResult(tensorName);
+ if (outputNode.empty()) {
+ throw PassException("output node <" + tensorName + "> not found");
+ } else {
+ std::cout << "OutNode <" + tensorName + "> found" << std::endl;
+ }
+ }
+
+ bool is_several_outs = (clopt::interOutNode.size() > 1);
+
+ nncc::contrib::core::ADT::TensorVariant *out = nullptr;
+ for (auto &tensorName : clopt::interOutNode) {
+ out = new TensorVariant(interpreter.getOperationResult(tensorName)[0]);
#ifdef NNC_HDF5_SUPPORTED
- std::string outputFile = clopt::interResFileName.empty() ? "out.hdf5" : clopt::interResFileName.getRawValue();
- writeTensorToHDF5File(_out, clopt::interOutNode, outputFile);
+ writeTensorToHDF5File(out, tensorName, clopt::artifactDir);
#else
- std::cout << "Result wasn't saved, due to lack of HDF5" << std::endl;
+ std::cout << "Result <" << tensorName << "> wasn't saved, due to lack of HDF5" << std::endl;
#endif // NNC_HDF5_SUPPORTED
+ if ( is_several_outs )
+ delete out;
+ }
+
+ _out = is_several_outs ? nullptr : out;
return _out;
}