#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
-#include <fcntl.h>
#include <cerrno>
#include <cstring>
#include <fstream>
#include <memory>
+#include <fcntl.h>
+#include <map>
using namespace std;
using namespace nncc::contrib;
static const char *cpp_header_types =
#include "cpp_header_types.def"
;
+
+ static const char *cpp_operations =
+ #include "cpp_operations.def"
+ ;
#undef S
void CPPCodeGenerator::materializeHeader(ostream &out, const ModelAnalyzer &ma)
out << "};\n";
}
+// print allocation of temporary tensors
+static void printTmpTensors(ostream &out, const ModelAnalyzer::OpDescr &op)
+{
+ for (const ModelAnalyzer::TensorDescription &td: op._outputs)
+ {
+ if (td._isNNOutput)
+ continue;
+ out << " Tensor " << td._name << ";\n";
+ }
+}
+
+// generate function output arguments
+static void gatherOperationCallOutputs(const ModelAnalyzer::OpDescr &op, vector<string> &args)
+{
+ for (const ModelAnalyzer::TensorDescription &td: op._outputs)
+ {
+ const string &tensorName = td._name;
+ if (td._isNNOutput)
+ args.push_back("*" + tensorName);
+ else
+ args.push_back(tensorName);
+ }
+}
+
+// generate function input arguments
+static void gatherOperationCallInputs(const ModelAnalyzer::OpDescr &op, map<INode*,
+ const ModelAnalyzer::OpDescr *> &node2Descr, vector<string> &args)
+{
+ for (const INode::IODescriptor &d: op._node->getPrevNodes())
+ {
+ size_t idx = d.index;
+ INode *node = d.node;
+ assert(node2Descr.find(node) != node2Descr.end());
+ const ModelAnalyzer::OpDescr &descr = *node2Descr[node];
+ const ModelAnalyzer::TensorDescription &tDescr = descr._outputs[idx];
+ const string &tensorName = tDescr._name;
+ if (tDescr._isNNOutput)
+ args.push_back("*" + tensorName);
+ else
+ args.push_back(tensorName);
+ }
+}
+
// Write the arguments to `out` as a comma-separated list: "a, b, c".
static void printOperationArgs(ostream &out, const vector<string> &args)
{
  for (size_t i = 0; i < args.size(); ++i)
  {
    if (i != 0)
      out << ", ";
    out << args[i];
  }
}
+
+// generate inference sequence
+static void materializeCPPInferenceSequence(ostream &out, const ModelAnalyzer &ma)
+{
+ using OpDescr = ModelAnalyzer::OpDescr;
+ map<INode*, const OpDescr *> node2Descr;
+ for (const ModelAnalyzer::OpDescr &op: ma.getInferenceSequence())
+ {
+ node2Descr.insert(pair<INode *, const OpDescr *>(op._node, &op));
+ using Type = OpDescr::Type;
+ using TensorDescription = ModelAnalyzer::TensorDescription;
+ if (op._type == Type::IN)
+ continue;
+ // create temporary tensors
+ printTmpTensors(out, op);
+ // materialize call
+ out << " " << op._opName << "(";
+ const auto &prevNodes = op._node->getPrevNodes();
+ const auto &outTensors = op._outputs;
+ vector<string> args;
+ args.reserve(prevNodes.size() + outTensors.size() + 1);
+ // gather output arguments
+ gatherOperationCallOutputs(op, args);
+ // parameters offset
+ args.push_back(to_string(op._paramStartOffset));
+ // gather input arguments
+ gatherOperationCallInputs(op, node2Descr, args);
+ // put arguments into stream
+ printOperationArgs(out, args);
+ out << ");\n";
+ }
+}
+
+// TODO think about better string formatting to make code more readable
void CPPCodeGenerator::materializeCode(ostream &out, const ModelAnalyzer &ma)
{
- // TODO emit C++ code to out stream
+ string className = ma.getModelName() + "Model";
+ out << cpp_operations;
+
+ // gen NN constructor
+ out << className << "::" << className << "(const string ¶metersPath)\n"
+ "{\n"
+ " readParameters(_parameters, parametersPath, " <<
+ ma.getFormatVersion() << ", " << ma.getModelHash() << ");"
+ "}\n";
+
+ // gen input setters
+ for (const string &inName: ma.getInputs())
+ {
+ out << "void " << className << "::set_" << inName << "(const Tensor& t)\n"
+ "{\n"
+ " _" << inName << " = t;"
+ "}\n";
+ }
+
+ // gen output getters
+ for (const string &outName: ma.getOutputs())
+ {
+ out << "shared_ptr<Tensor> " << className <<"::get_" << outName << "()\n"
+ "{\n"
+ " return _" << outName << ";"
+ "}\n";
+ }
+ out << "void " << className << "::doInference()\n"
+ "{\n";
+ for (const string &outName: ma.getOutputs())
+ {
+ out << " _" << outName << ".reset(new Tensor());\n";
+ }
+
+ // gen inference sequence
+ materializeCPPInferenceSequence(out, ma);
+ out << "}";
}
} // namespace soft