#include "TypeInference.h"
#include "ShapeInference.h"
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+
using namespace flatbuffers;
using namespace tflite;
namespace
{
-void exportRelu(loco::ReLU *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+class OperationExporter final : public loco::CanonicalNodeMutableVisitor<void>
+{
+public:
+ OperationExporter(FlatBufferBuilder &fbb, SerializedModelData &ctx) : builder{fbb}, gd{ctx}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void visit(loco::ReLU *) final;
+ void visit(loco::Push *) final { /* DO NOTHING */}
+ void visit(loco::Pull *) final { /* DO NOTHING */}
+ void visit(loco::FeatureEncode *) final;
+ void visit(loco::FeatureDecode *) final;
+ void visit(loco::FilterEncode *) final;
+ void visit(loco::ConstGen *) final { /* skip, everything is done in exportOpDefinedTensors */}
+ void visit(loco::MaxPool2D *) final;
+ void visit(loco::AvgPool2D *) final;
+ void visit(loco::Conv2D *) final;
+ void visit(loco::TensorConcat *) final;
+ void visit(loco::BiasEncode *) final;
+ void visit(loco::TensorBiasAdd *) final;
+ void visit(loco::FeatureBiasAdd *) final;
+
+private:
+ FlatBufferBuilder &builder;
+ SerializedModelData &gd;
+};
+
+void OperationExporter::visit(loco::ReLU *node)
{
uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU);
std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
gd._operators.push_back(op_offset);
}
-void exportMaxPool2D(loco::MaxPool2D *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+void OperationExporter::visit(loco::MaxPool2D *node)
{
uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MAX_POOL_2D);
std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm())};
gd._operators.push_back(op_offset);
}
-void exportAvgPool2D(loco::AvgPool2D *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+void OperationExporter::visit(loco::AvgPool2D *node)
{
// TFlite only support Valid convention of average pooling
assert(node->convention() == loco::AvgPool2D::Convention::Valid);
gd._operators.push_back(op_offset);
}
-void exportConv2D(loco::Conv2D *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+void OperationExporter::visit(loco::Conv2D *node)
{
uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONV_2D);
gd._operators.push_back(transpose_offset);
}
-void exportFeatureEncode(loco::FeatureEncode *node, FlatBufferBuilder &builder,
- SerializedModelData &gd)
+void OperationExporter::visit(loco::FeatureEncode *node)
{
auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Feature> *>(node->encoder());
auto perm = encoder->perm();
}
}
-void exportFeatureDecode(loco::FeatureDecode *node, FlatBufferBuilder &builder,
- SerializedModelData &gd)
+void OperationExporter::visit(loco::FeatureDecode *node)
{
auto decoder = dynamic_cast<loco::PermutingDecoder<loco::Domain::Feature> *>(node->decoder());
auto perm = decoder->perm();
}
}
-void exportFilterEncode(loco::FilterEncode *node, FlatBufferBuilder &builder,
- SerializedModelData &gd)
+void OperationExporter::visit(loco::FilterEncode *node)
{
auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Filter> *>(node->encoder());
auto perm = encoder->perm();
}
}
-void exportBiasAdd(loco::BiasAdd<loco::Domain::Tensor> *node, FlatBufferBuilder &builder,
- SerializedModelData &gd)
+void OperationExporter::visit(loco::BiasAdd<loco::Domain::Tensor> *node)
{
uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD);
std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())};
gd._operators.push_back(op_offset);
}
-void exportBiasAdd(loco::FeatureBiasAdd *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+void OperationExporter::visit(loco::FeatureBiasAdd *node)
{
uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD);
std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())};
}
/// @brief Export CONCATENATION of **TWO** tensors only
-void exportConcat(loco::TensorConcat *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+void OperationExporter::visit(loco::TensorConcat *node)
{
uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONCATENATION);
std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
gd._operators.push_back(op_offset);
}
+void OperationExporter::visit(loco::BiasEncode *encode) { exportIdentity(encode, builder, gd); }
+
void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
SerializedModelData &data)
{
return;
}
- if (auto *relu = dynamic_cast<loco::ReLU *>(node))
- {
- exportRelu(relu, builder, data);
- }
- else if (dynamic_cast<loco::Pull *>(node))
- {
- // DO NOTHING
- }
- else if (dynamic_cast<loco::Push *>(node))
- {
- // DO NOTHING
- }
- else if (auto *encode = dynamic_cast<loco::FeatureEncode *>(node))
- {
- exportFeatureEncode(encode, builder, data);
- }
- else if (auto *decode = dynamic_cast<loco::FeatureDecode *>(node))
- {
- exportFeatureDecode(decode, builder, data);
- }
- else if (auto *encode = dynamic_cast<loco::FilterEncode *>(node))
- {
- exportFilterEncode(encode, builder, data);
- }
- else if (dynamic_cast<loco::ConstGen *>(node))
- {
- // skip, everything is done in exportOpDefinedTensors
- }
- else if (auto *max_pool = dynamic_cast<loco::MaxPool2D *>(node))
- {
- exportMaxPool2D(max_pool, builder, data);
- }
- else if (auto *avg_pool = dynamic_cast<loco::AvgPool2D *>(node))
- {
- exportAvgPool2D(avg_pool, builder, data);
- }
- else if (auto *conv2d = dynamic_cast<loco::Conv2D *>(node))
- {
- exportConv2D(conv2d, builder, data);
- }
- else if (auto *tconcat = dynamic_cast<loco::TensorConcat *>(node))
- {
- exportConcat(tconcat, builder, data);
- }
- else if (auto *encode = dynamic_cast<loco::BiasEncode *>(node))
- {
- exportIdentity(encode, builder, data);
- }
- else if (auto *biasadd = dynamic_cast<loco::BiasAdd<loco::Domain::Tensor> *>(node))
- {
- exportBiasAdd(biasadd, builder, data);
- }
- else if (auto *biasadd = dynamic_cast<loco::FeatureBiasAdd *>(node))
+ if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
{
- exportBiasAdd(biasadd, builder, data);
+ OperationExporter exporter{builder, data};
+ canonical_node->accept(&exporter);
}
else
{