#include <set>
+/*
+ Note: Terms for variables in this implementation is as follows:
+
+ ex) subgraph handled: TFLConv2D -------- TFLAdd
+ (or TFLDepthwiseConv2D) (or TFLSub)
+ | |
+ \|/ \|/
+ variable name : former latter
+ Type : FormerT LatterT
+ (shortened name from Mixin) (template type)
+*/
namespace
{
+using FormerT = locoex::TFLNodeMixin<locoex::TFLNodeTrait::Bias>;
+
locoex::TFLConst *get_const(loco::Node *x, loco::Node *y)
{
if (auto const_node = dynamic_cast<locoex::TFLConst *>(x))
return nullptr;
}
+FormerT *get_former(loco::Node *x, loco::Node *y)
+{
+ if (auto node = dynamic_cast<FormerT *>(x))
+ return node;
+ else if (auto node = dynamic_cast<FormerT *>(y))
+ return node;
+
+ return nullptr;
+}
+
+// TODO replace this with get_former
locoex::TFLConv2D *get_conv2d(loco::Node *x, loco::Node *y)
{
if (auto conv2d_node = dynamic_cast<locoex::TFLConv2D *>(x))
template <> float calc<locoex::TFLAdd>(float x, float y) { return x + y; }
template <> float calc<locoex::TFLSub>(float x, float y) { return x - y; }
+// TODO rewrite this by using FormerT and LatterT (Remove Conv2D dependency)
// TFLType is either TFLAdd or TFLSub
template <typename TFLType> class Fuser
{
_fusable_node->y(nullptr);
}
+// TODO rewrite this by using FormerT and LatterT (Remove Conv2D dependency)
struct Collector final : public locoex::TFLNodeMutableVisitor<void>
{
void setCandidate(locoex::TFLNode *node, loco::Node *x, loco::Node *y)
struct Performer final : public locoex::TFLNodeMutableVisitor<void>
{
- void visit(locoex::TFLAdd *node) final
+ void visit(locoex::TFLAdd *latter) final
{
- Fuser<locoex::TFLAdd> fuser(node);
+ assert(get_former(latter->x(), latter->y()));
+
+ Fuser<locoex::TFLAdd> fuser(latter);
fuser.fuse();
}
- void visit(locoex::TFLSub *node) final
+ void visit(locoex::TFLSub *latter) final
{
- Fuser<locoex::TFLSub> fuser(node);
+ assert(get_former(latter->x(), latter->y()));
+
+ Fuser<locoex::TFLSub> fuser(latter);
fuser.fuse();
}