std::set<locoex::TFLNode *> candidates;
};
+void set_activation_fusion(loco::Node *node, locoex::FusedActFunc f)
+{
+ using namespace locoex;
+
+ if (auto fusable_node = dynamic_cast<TFLNodeMixin<TFLNodeTrait::FusedActFunc> *>(node))
+ fusable_node->fusedActivationFunction(f);
+ else
+ assert(false);
+}
+
+struct Performer final : public locoex::TFLNodeMutableVisitor<void>
+{
+ void visit(locoex::TFLRelu *the_relu) final
+ {
+ set_activation_fusion(the_relu->features(), locoex::FusedActFunc::RELU);
+
+ loco::replace(the_relu).with(the_relu->features());
+ the_relu->features(nullptr);
+ }
+
+ void visit(locoex::TFLRelu6 *the_relu6) final
+ {
+ set_activation_fusion(the_relu6->features(), locoex::FusedActFunc::RELU6);
+
+ loco::replace(the_relu6).with(the_relu6->features());
+ the_relu6->features(nullptr);
+ }
+
+ void visit(locoex::TFLNode *) final { assert(false && "should not be called"); }
+};
+
} // namespace
namespace exo
}
}
- // TODO write code for fusing
+ Performer performer;
+
+ for (auto node : collector.candidates)
+ {
+ node->accept(&performer);
+ }
- return false;
+ return collector.candidates.size() > 0;
}
} // namespace exo