namespace nntrainer {
void GraphCore::addGraphNode(std::shared_ptr<GraphNode> node) {
- node->setIndex(node_list.size());
node_list.push_back(node);
+ node_map[node->getName()] = node_list.size() - 1;
}
const std::shared_ptr<GraphNode> &GraphCore::getNode(unsigned int ith) const {
- if (ith >= size())
- throw std::invalid_argument("Exceed total number of nodes");
-
- if (node_list[ith]->getIndex() != ith)
- throw std::runtime_error("Graph internal index mismatch");
-
- return node_list[ith];
+ return node_list.at(ith);
}
const std::shared_ptr<GraphNode> &
GraphCore::getSortedNode(unsigned int ith) const {
- if (ith >= Sorted.size())
- throw std::invalid_argument("Exceed total number of nodes");
-
- return Sorted[ith];
+ return Sorted.at(ith);
}
void GraphCore::makeAdjacencyList(
/** make the connections */
for (auto &node : node_list) {
for (auto const &in_conn : node->getInputConnections()) {
- unsigned int to_node_id = getNode(in_conn)->getIndex();
+ unsigned int to_node_id = getNodeIdx(in_conn);
adj[to_node_id].push_back(node);
}
}
std::list<std::shared_ptr<GraphNode>>::iterator i;
for (i = adj[ith].begin(); i != adj[ith].end(); ++i) {
- auto index = (*i)->getIndex();
+ auto index = getNodeIdx((*i)->getName());
if (!visited[index])
topologicalSortUtil(adj, index, visited, dfs_stack);
}
const std::shared_ptr<GraphNode> &
GraphCore::getNode(const std::string &name) const {
- for (auto &lnode : node_list) {
- if (istrequal(lnode->getName(), name))
- return lnode;
- }
-
- std::stringstream ss;
- ss << "Cannot find graph node: " << name;
- throw std::invalid_argument(ss.str());
+ return node_list.at(node_map.at(name));
}
void GraphCore::addNode(std::shared_ptr<GraphNode> node, bool ensure_name) {
addGraphNode(node);
}
-void GraphCore::ensureName(GraphNode &node, const std::string &prefix,
- const std::string &postfix, bool force_rename) {
- std::string orig_name = node.getName();
+void GraphCore::ensureName(GraphNode &node, const std::string &prefix_,
+ const std::string &postfix_, bool force_rename) {
+ auto to_lower = [](const std::string &str) -> std::string {
+ std::string ret = str;
+ ;
+ std::transform(ret.begin(), ret.end(), ret.begin(),
+ [](unsigned char c) { return std::tolower(c); });
+ return ret;
+ };
+
+ std::string orig_name = to_lower(node.getName());
+ std::string prefix = to_lower(prefix_);
+ std::string postfix = to_lower(postfix_);
+
bool orig_name_empty = orig_name.empty();
/** If node already has name which is unique and valid, and force is
* disabled, then nothing to do.
*/
if (!orig_name_empty && !force_rename && !verifyNode(orig_name)) {
+ node.setName(orig_name);
node_names.insert(orig_name);
return;
}
void GraphCore::replaceNode(std::shared_ptr<GraphNode> from,
std::shared_ptr<GraphNode> to) {
- unsigned int idx = from->getIndex();
- to->setIndex(idx);
- node_list[idx] = to;
+ if (node_map.find(from->getName()) == node_map.end())
+ throw std::invalid_argument("Graph node to be replaced is missing");
+ if (node_map.find(to->getName()) != node_map.end())
+ throw std::invalid_argument("Nodes in the graph must be unique");
+
+ unsigned int from_idx = getNodeIdx(from->getName());
+ node_list[from_idx] = to;
+ node_map.erase(from->getName());
+ node_map[to->getName()] = from_idx;
}
void GraphCore::realizeInputOutputNode() {
}
}
+unsigned int GraphCore::getNodeIdx(const std::string &name) {
+ return node_map.at(name);
+}
+
} /* namespace nntrainer */
#include <map>
#include <memory>
#include <stack>
+#include <unordered_map>
#include <unordered_set>
#include <vector>
using std::swap;
swap(lhs.node_list, rhs.node_list);
+ swap(lhs.node_map, rhs.node_map);
swap(lhs.Sorted, rhs.Sorted);
swap(lhs.node_names, rhs.node_names);
swap(lhs.def_name_count, rhs.def_name_count);
*/
void reset() {
node_list.clear();
+ node_map.clear();
Sorted.clear();
node_names.clear();
def_name_count = 0;
std::vector<std::shared_ptr<GraphNode>> output_list;
std::vector<std::shared_ptr<GraphNode>>
node_list; /**< Unordered Node List */
+ std::unordered_map<std::string, int> node_map; /**< Unordered Node map */
std::vector<std::shared_ptr<GraphNode>> Sorted; /**< Ordered Node List */
bool sorted; /** if the node_list is sorted */
*/
void
makeAdjacencyList(std::vector<std::list<std::shared_ptr<GraphNode>>> &adj);
+
+ /**
+ * @brief Get index of the node with given name
+ *
+ * @param name Name of the node
+ * @return internal index of the node
+ */
+ unsigned int getNodeIdx(const std::string &name);
};
} // namespace nntrainer
*/
virtual ~GraphNode() = default;
- /**
- * @brief Get index of the node
- *
- */
- virtual size_t getIndex() const = 0;
-
- /**
- * @brief Set index of the node
- *
- */
- virtual void setIndex(size_t) = 0;
-
/**
* @brief Get the Name of the underlying object
*
}
void NetworkGraph::addLayerNode(std::unique_ptr<Layer> layer) {
- graph.addNode(std::make_unique<LayerNode>(std::move(layer), graph.size()));
+ graph.addNode(std::make_unique<LayerNode>(std::move(layer)));
}
void NetworkGraph::countNonTrainableLayersAtBegin() {
lnode->setInputLayers({second_to_last_layer_node->getName()});
if (is_cross_entropy_loss) {
- lnode->setIndex(output_layer_node->getIndex());
graph.replaceNode(output_layer_node, lnode);
} else {
graph.addNode(lnode, false);
for (unsigned int i = 0; i < input_layers.size(); ++i) {
auto in_layer_node = getLayerNode(input_layers[i]);
- unsigned int location = 0;
- for (unsigned int j = 0; j < in_layer_node->getNumOutputConnections();
- ++j) {
- if (in_layer_node->getOutputLayers()[j] == lnode->getName()) {
- location = j;
- break;
- }
- }
+ auto const &in_layer_out_connect = in_layer_node->getOutputLayers();
+ unsigned int location =
+ std::find(in_layer_out_connect.begin(), in_layer_out_connect.end(),
+ lnode->getName()) -
+ in_layer_out_connect.begin();
lnode->setInputDimension(in_layer_node->getOutputDimensions()[location],
i);
skip_non_trainable_layers = 0;
}
- /**
- * @brief getter of LayerNode with index number
- * @param[in] index
- * @ret LayerNode
- */
- std::shared_ptr<LayerNode> getLayerNode(unsigned int ith) const {
- return std::static_pointer_cast<LayerNode>(graph.getNode(ith));
- }
-
/**
* @brief getter of Sorted LayerNode with index number
* @param[in] index
return lnode;
}
-LayerNode::LayerNode(std::unique_ptr<nntrainer::Layer> &&l, size_t idx) :
+LayerNode::LayerNode(std::unique_ptr<nntrainer::Layer> &&l) :
layer(std::move(l)),
- index(idx),
finalized(false),
activation_type(ActivationType::ACT_NONE),
layer_node_props(new PropsType(props::Name(), props::Flatten(),
/**
* @brief Default constructor
*/
- LayerNode() : LayerNode(nullptr, 0) {}
+ LayerNode() : LayerNode(nullptr) {}
/**
* @brief Constructor of LayerNode class for v2
* @param l layer to wrap with, the ownership is transferred to layer node
*
*/
- LayerNode(std::unique_ptr<nntrainer::Layer> &&l, size_t idx = 0);
+ LayerNode(std::unique_ptr<nntrainer::Layer> &&l);
/**
* @brief Destructor of LayerNode Class
* Support all the interface requirements by nntrainer::GraphNode
*/
- /**
- * @brief Get index of the node
- *
- * @return Index of the node
- */
- size_t getIndex() const { return index; }
-
- /**
- * @brief Set the index for the node
- * @param idx Index for the node
- */
- void setIndex(size_t idx) { index = idx; }
-
/**
* @brief set name of layer
*
* @param layers Name of the layers
*/
void setInputLayers(const std::vector<std::string> &layers) {
- input_layers = layers;
+ auto to_lower = [](const std::string &str) -> std::string {
+ std::string ret = str;
+ ;
+ std::transform(ret.begin(), ret.end(), ret.begin(),
+ [](unsigned char c) { return std::tolower(c); });
+ return ret;
+ };
+
+ input_layers.reserve(layers.size());
+ for (auto const &name : layers)
+ input_layers.push_back(to_lower(name));
resizeInputDimensions(input_layers.size());
}
std::unique_ptr<nntrainer::Layer>
layer; /**< The actual object in the graph node */
- // TODO: possibly remove, two identifiers for the same node (name and
- // index) can lead to issues later
- size_t index; /**< index of each node */
bool finalized; /**< if the layer node has been finalized */
std::vector<std::string> input_layers; /**< input layer names */