#include <moco/tf/Names.h>
#include <loco.h>
+#include <angkor/TensorShape.h>
#include <tensorflow/core/framework/graph.pb.h>
namespace tf
{
+using TensorShape = angkor::TensorShape;
+
+/**
+ * @brief Class to store information to run a model. Normally this info comes from users
+ * via CLI params or configuration file.
+ */
struct ModelSignature
{
public:
const std::vector<TensorName> &inputs() const { return _inputs; }
const std::vector<TensorName> &outputs() const { return _outputs; }
+ /**
+ * @brief Adds customop op type (not name of node) provided from user
+ */
+ void add_customop(const std::string &op);
+ const std::vector<std::string> &customops() const { return _customops; }
+
+ /**
+ * @brief Adds node name and its shape provided from user
+ */
+ void shape(const std::string &node_name, const TensorShape &shape);
+ const TensorShape *shape(const std::string &node_name) const;
+
+ /**
+ * @brief Adds node name and its dtype provided from user
+ */
+ void dtype(const std::string &node_name, loco::DataType dtype);
+ loco::DataType dtype(const std::string &node_name) const;
+
private:
std::vector<TensorName> _inputs; // graph inputs
std::vector<TensorName> _outputs; // graph outputs
+
+ // For custom op types passed from user (e.g., via CLI)
+ std::vector<std::string> _customops;
+
+ // For and node names and shapes passed from user (e.g., via CLI)
+ std::map<std::string, TensorShape> _shapes;
+
+ // For and node names and dtype passed from user (e.g., via CLI)
+ std::map<std::string, loco::DataType> _dtypes;
};
class Frontend
namespace tf
{
+void ModelSignature::add_customop(const std::string &op)
+{
+ if (std::find(_customops.begin(), _customops.end(), op) == _customops.end())
+ _customops.emplace_back(op);
+ else
+ throw std::runtime_error{"Duplicated custom op: " + op};
+}
+
+void ModelSignature::shape(const std::string &node_name,
+ const nncc::core::ADT::tensor::Shape &shape)
+{
+ if (_shapes.find(node_name) != _shapes.end())
+ throw std::runtime_error{"Duplicated node name: " + node_name};
+
+ _shapes[node_name] = shape;
+}
+
+const nncc::core::ADT::tensor::Shape *ModelSignature::shape(const std::string &node_name) const
+{
+ auto res = _shapes.find(node_name);
+ if (res == _shapes.end())
+ return nullptr;
+ else
+ return &res->second;
+}
+
+void ModelSignature::dtype(const std::string &node_name, loco::DataType dtype)
+{
+ if (_dtypes.find(node_name) != _dtypes.end())
+ throw std::runtime_error{"Duplicated node name: " + node_name};
+
+ _dtypes[node_name] = dtype;
+}
+
+loco::DataType ModelSignature::dtype(const std::string &node_name) const
+{
+ auto res = _dtypes.find(node_name);
+ if (res == _dtypes.end())
+ return loco::DataType::Unknown;
+ else
+ return res->second;
+}
+
Frontend::Frontend()
{
// DO NOTHING