#undef TFL_NODE
/// @brief Default fallback
- virtual T visit(const TFLNode *node) { assert(false); }
+ virtual T visit(const TFLNode *) { assert(false); }
};
/**
#undef TFL_NODE
/// @brief Default fallback
- virtual T visit(TFLNode *node) { assert(false); }
+ virtual T visit(TFLNode *) { assert(false); }
};
} // namespace locoex
#include "ExporterUtils.h"
#include "ShapeInference.h"
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
#include <loco/IR/CanonicalNode.h>
#include <loco/IR/CanonicalNodeVisitor.h>
#include <locoex/COpCall.h>
namespace
{
-class OperationExporter final : public loco::CanonicalNodeMutableVisitor<void>
+class OperationExporter final : public locoex::TFLNodeMutableVisitor<void>,
+ public loco::CanonicalNodeMutableVisitor<void>
{
public:
OperationExporter(FlatBufferBuilder &fbb, SerializedModelData &ctx) : builder{fbb}, gd{ctx}
}
public:
+ // FOR TFLNodes
+ void visit(locoex::TFLRelu *) final;
+
+ // FOR canonical nodes. These will be removed later
void visit(loco::ReLU *) final;
void visit(loco::ReLU6 *) final;
void visit(loco::Tanh *) final;
SerializedModelData &gd;
};
+void OperationExporter::visit(locoex::TFLRelu *node)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+ gd._operators.push_back(op_offset);
+}
+
void OperationExporter::visit(loco::ReLU *node)
{
uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU);