$(NNTRAINER_ROOT)/nntrainer/compiler/flatten_realizer.cpp \
$(NNTRAINER_ROOT)/nntrainer/compiler/recurrent_realizer.cpp \
$(NNTRAINER_ROOT)/nntrainer/compiler/remap_realizer.cpp \
+ $(NNTRAINER_ROOT)/nntrainer/compiler/slice_realizer.cpp \
$(NNTRAINER_ROOT)/nntrainer/app_context.cpp
ifeq ($(ENABLE_TFLITE_INTERPRETER), 1)
'ini_interpreter.cpp',
'flatten_realizer.cpp',
'recurrent_realizer.cpp',
- 'remap_realizer.cpp'
+ 'remap_realizer.cpp',
+ 'slice_realizer.cpp'
]
compiler_headers = []
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
+ *
+ * @file slice_realizer.cpp
+ * @date 14 October 2021
+ * @brief NNTrainer graph realizer which slice the graph representation
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+#include <layer_node.h>
+#include <slice_realizer.h>
+
+#include <unordered_map>
+
+namespace nntrainer {
+
+SliceRealizer::SliceRealizer(const std::vector<std::string> &start_layers,
+ const std::vector<std::string> &end_layers) :
+ start_layers(start_layers),
+ end_layers(end_layers.begin(), end_layers.end()) {}
+
+SliceRealizer::~SliceRealizer() {}
+
+GraphRepresentation
+SliceRealizer::realize(const GraphRepresentation &reference) {
+ struct NodeInfo {
+ NodeInfo() : NodeInfo(nullptr) {}
+ NodeInfo(std::shared_ptr<LayerNode> node) :
+ node(node),
+ is_visited(false),
+ is_added(false) {}
+ std::shared_ptr<LayerNode> node; /**< set this if not visited */
+ bool is_visited; /**< set this if visited */
+ bool is_added; /**< set this if added */
+ std::vector<std::string> children;
+ std::vector<std::string> path;
+ /**< path is the tracing result from start to current node
+ eg) if traversal has started from a -> b -> c -> d.
+ The path has {"a", "b", "c", "d"} */
+
+ LayerNode *operator->() { return node.get(); }
+ };
+
+ std::unordered_map<std::string, NodeInfo> mp; /// map point
+ std::transform(
+ reference.begin(), reference.end(), std::inserter(mp, mp.end()),
+ [](std::shared_ptr<LayerNode> node) {
+ return std::pair<std::string, NodeInfo>(node->getName(), node);
+ });
+
+ auto cur_start_layers = start_layers;
+ auto cur_end_layers = end_layers;
+
+ if (start_layers.empty()) {
+ for (auto &node : reference) {
+ if (node->getNumInputConnections() == 0) {
+ cur_start_layers.push_back(node->getName());
+ }
+ }
+ }
+
+ if (end_layers.empty()) {
+ for (auto &node : mp) {
+ if (node.second.children.size() == 0) {
+ cur_end_layers.insert(node.first);
+ }
+ }
+ }
+
+ std::for_each(reference.begin(), reference.end(),
+ [&mp](std::shared_ptr<LayerNode> node) {
+ auto node_name = node->getName();
+ for (auto &parent : node->getInputLayers()) {
+ mp.at(parent).children.push_back(node_name);
+ };
+ });
+
+ GraphRepresentation processed;
+
+ auto update_processed = [&processed, &mp](const std::string &name) {
+ auto &node_info = mp.at(name);
+ if (!node_info.is_added) {
+ processed.push_back(node_info.node);
+ node_info.is_added = true;
+ }
+ };
+
+ std::vector<std::string> dfs_stack(cur_start_layers.rbegin(),
+ cur_start_layers.rend());
+
+ auto is_end_node = [&cur_end_layers](const std::string &name) {
+ auto iter = cur_end_layers.find(name);
+ return iter != cur_end_layers.end();
+ };
+ while (!dfs_stack.empty()) {
+ auto &node_info = mp.at(dfs_stack.back());
+ auto &path = node_info.path;
+ path.push_back(node_info->getName());
+ if (is_end_node(node_info->getName())) {
+ std::for_each(path.begin(), path.end(), update_processed);
+ }
+
+ dfs_stack.pop_back();
+ node_info.is_visited = true;
+
+ auto &children = node_info.children;
+ std::for_each(children.begin(), children.end(),
+ [&path, &mp](const auto &name) { mp.at(name).path = path; });
+
+ /// @todo: stop inserting to the dfs stack if children->isAdded == true
+ dfs_stack.insert(dfs_stack.end(), children.rbegin(), children.rend());
+ }
+
+ NNTR_THROW_IF(processed.empty(), std::invalid_argument)
+ << "After slice, there is no node left, please check if configuration is "
+ "correct";
+
+ return processed;
+}
+
+} // namespace nntrainer
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
+ *
+ * @file slice_realizer.h
+ * @date 14 October 2021
+ * @brief NNTrainer graph realizer which slice the graph representation
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+#ifndef __SLICE_REALIZER_H__
+#define __SLICE_REALIZER_H__
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include <realizer.h>
+
+namespace nntrainer {
+
+/**
+ * @brief Graph realizer class which slice graph representation
+ *
+ */
+class SliceRealizer final : public GraphRealizer {
+public:
+ /**
+ * @brief Construct a new Slice Realizer object
+ *
+ * @param start_layers start layers
+ * @param end_layers end layers
+ */
+ SliceRealizer(const std::vector<std::string> &start_layers,
+ const std::vector<std::string> &end_layers);
+
+ /**
+ * @brief Destroy the Graph Realizer object
+ *
+ */
+ ~SliceRealizer();
+
+ /**
+ * @brief graph realizer creates a new graph based on the reference
+ * @note for each layer in start_layers, start dfs, if traversal meets an end
+ * layers, node is added to an ordered set.
+ * @throw std::invalid_argument if created GraphRepresentation is empty
+ *
+ */
+ GraphRepresentation realize(const GraphRepresentation &reference) override;
+
+private:
+ std::vector<std::string> start_layers;
+ std::unordered_set<std::string> end_layers;
+};
+
+} // namespace nntrainer
+
+#endif // __SLICE_REALIZER_H__
#include <profiler.h>
#include <recurrent_realizer.h>
#include <remap_realizer.h>
+#include <slice_realizer.h>
#include <util_func.h>
/**
}
std::vector<std::unique_ptr<GraphRealizer>> realizers;
- if (!scope.empty()) {
- realizers.emplace_back(new RemapRealizer(
- [&scope](std::string &name) { name = scope + "/" + name; }));
- }
+ realizers.emplace_back(new SliceRealizer(start_layers, end_layers));
- if (!start_layers.empty() || !end_layers.empty()) {
- /// @todo add slice realizer
- /// this will extract part of layers from start to end
+ if (type == ml::train::ReferenceLayersType::RECURRENT) {
+ realizers.emplace_back(
+ new RecurrentRealizer(type_properties, input_layers));
}
if (input_layers.empty()) {
/// @todo add input setter realizer
}
- if (type == ml::train::ReferenceLayersType::RECURRENT) {
+ if (!scope.empty()) {
realizers.emplace_back(
- new RecurrentRealizer(type_properties, input_layers));
+ new RemapRealizer([&scope, &input_layers](std::string &name) {
+ for (auto &i : input_layers) {
+ if (istrequal(i, name)) {
+ return;
+ }
+ }
+ name = scope + "/" + name;
+ }));
}
for (auto &realizer : realizers) {
#include <realizer.h>
#include <recurrent_realizer.h>
#include <remap_realizer.h>
+#include <slice_realizer.h>
#include <compiler_test_util.h>
realizeAndEqual(r, {input1}, {expected1});
}
+
+TEST(SliceRealizer, slice_p) {
+ /**
+ * graph architecture
+ *
+ * a1 a2
+ * | |
+ * b1 b2 b3
+ * \ / \ /
+ * c1 c2
+ * / \
+ * d1 d2
+ */
+ std::vector<LayerRepresentation> before = {
+ {"fully_connected", {"name=a1"}},
+ {"fully_connected", {"name=a2"}},
+ {"fully_connected", {"name=b1", "input_layers=a1"}},
+ {"fully_connected", {"name=b2", "input_layers=a2"}},
+ {"fully_connected", {"name=b3"}},
+ {"fully_connected", {"name=c1", "input_layers=b1,b2"}},
+ {"fully_connected", {"name=c2", "input_layers=b2,b3"}},
+ {"fully_connected", {"name=d1", "input_layers=c1"}},
+ {"fully_connected", {"name=d2", "input_layers=c1"}},
+ };
+
+ /**
+ * graph architecture
+ * start_layer = a1, b1, b2
+ * end_layer = a1, d1, d2
+ *
+ * a1 (was input port)
+ * |
+ * b1 b2 (orphaned)
+ * \ /
+ * c1
+ * / \
+ * d1 d2
+ */
+ std::vector<LayerRepresentation> after = {
+ {"fully_connected", {"name=a1"}},
+ {"fully_connected", {"name=b1", "input_layers=a1"}},
+ {"fully_connected", {"name=c1", "input_layers=b1,b2"}},
+ {"fully_connected", {"name=d1", "input_layers=c1"}},
+ {"fully_connected", {"name=d2", "input_layers=c1"}},
+ {"fully_connected", {"name=b2", "input_layers=a2"}},
+ };
+
+ SliceRealizer r({"a1", "b1", "b2"}, {"a1", "d1", "d2"});
+
+ realizeAndEqual(r, before, after);
+}