namespace tf
{
+struct ModelSignature
+{
+public:
+ void add_input(const std::string &input) { _inputs.push_back(input); }
+ void add_output(const std::string &output) { _outputs.push_back(output); }
+
+ const std::vector<std::string> &inputs() const { return _inputs; }
+ const std::vector<std::string> &outputs() const { return _outputs; }
+
+private:
+ std::vector<std::string> _inputs; // graph inputs
+ std::vector<std::string> _outputs; // graph outputs
+};
+
class Frontend
{
public:
Frontend();
public:
- void add_input(const std::string &input) { _inputs.push_back(input); }
- void add_output(const std::string &output) { _outputs.push_back(output); }
-
-public:
- std::unique_ptr<loco::Graph> load(const char *, FileType) const;
-
-private:
- std::vector<std::string> _inputs; // graph inputs
- std::vector<std::string> _outputs; // graph outputs
+ std::unique_ptr<loco::Graph> load(const ModelSignature &, const char *, FileType) const;
};
} // namespace tf
// DO NOTHING
}
-std::unique_ptr<loco::Graph> Frontend::load(const char *modelfile, FileType type) const
+std::unique_ptr<loco::Graph> Frontend::load(const ModelSignature &signature, const char *modelfile,
+ FileType type) const
{
tensorflow::GraphDef tf_graph_def;
TEST(TensorFlowFrontend, load_model)
{
moco::tf::Frontend frontend;
+ moco::tf::ModelSignature signature;
// TODO fix not to use "../../.."
- frontend.load("../../../test/tf/Placeholder_000.pbtxt", moco::tf::Frontend::FileType::Text);
- frontend.load("../../../test/tf/Placeholder_000.pb", moco::tf::Frontend::FileType::Binary);
+ frontend.load(signature, "../../../test/tf/Placeholder_000.pbtxt",
+ moco::tf::Frontend::FileType::Text);
+ frontend.load(signature, "../../../test/tf/Placeholder_000.pb",
+ moco::tf::Frontend::FileType::Binary);
}