* @bug No known bugs except for NYI items
*/
#include <flatten_realizer.h>
+#include <remap_realizer.h>
+#include <unordered_map>
#include <flatten_layer.h>
#include <layer_node.h>
GraphRepresentation processed;
processed.reserve(reference.size());
+ std::unordered_map<std::string /**< layer_name */,
+ std::string /**< flatten_layer_name */>
+ remap_table;
+ std::vector<LayerNode *> flatten_nodes;
+ std::unordered_map<std::string /**< temp_layer_name */,
+ std::string /**< layer_name */>
+ recovery_table;
+
for (auto &node : reference) {
/// @note: [node] type=flatten; flatten=true; is awkward but allowed.
/// There is no reason to prohibit this.
processed.push_back(node);
if (node->getFlatten() && !node->getDistribute()) {
+ node->setProperty({"flatten=false"});
+
auto layer_name = node->getName();
+
+ auto flatten_name = layer_name + "/flatten_realized";
+ auto temp_name = flatten_name + "/temp";
+
+ remap_table.insert({layer_name, flatten_name});
+ recovery_table.insert({temp_name, layer_name});
+
auto flatten_node =
- createLayerNode(FlattenLayer::type, {"name=" + layer_name});
- node->setProperty({"flatten=false"});
- node->setProperty({"name=" + layer_name + "/flatten_realized"});
- flatten_node->setProperty({"input_layers=" + node->getName()});
+ createLayerNode(FlattenLayer::type, {"name=" + flatten_name});
+ flatten_node->setProperty({"input_layers=" + temp_name});
processed.push_back(std::move(flatten_node));
}
}
+ RemapRealizer remap_others([&remap_table](std::string &name, unsigned &idx) {
+ if (auto iter = remap_table.find(name); iter != remap_table.end()) {
+ name = iter->second;
+ }
+ });
+
+ RemapRealizer recover_temp(
+ [&recovery_table](std::string &name, unsigned &idx) {
+ if (auto iter = recovery_table.find(name); iter != recovery_table.end()) {
+ name = iter->second;
+ }
+ });
+
+ processed = remap_others.realize(processed);
+ processed = recover_temp.realize(processed);
return processed;
}
TEST(FlattenRealizer, flatten_p) {
FlattenRealizer fr;
- LayerRepresentation input1 = {"fully_connected",
- {"name=layer1", "flatten=true"}};
- LayerRepresentation expected1 = {"fully_connected",
- {"name=layer1/flatten_realized"}};
+ LayerRepresentation input1 = {
+ "fully_connected",
+ {"name=layer1", "flatten=true"},
+ };
+ LayerRepresentation expected1 = {"fully_connected", {"name=layer1"}};
LayerRepresentation expected2 = {
- "flatten", {"name=layer1", "input_layers=layer1/flatten_realized"}};
+ "flatten",
+ {"name=layer1/flatten_realized", "input_layers=layer1"},
+ };
realizeAndEqual(fr, {input1}, {expected1, expected2});
}