#include "Dialect/IR/TFLNodes.h"
#include "Dialect/IR/TFLNodeVisitor.h"
+#include "Check.h"
+
#include <loco/IR/CanonicalNode.h>
#include <loco/IR/CanonicalNodeVisitor.h>
#include <locoex/COpCall.h>
public:
// FOR TFLNodes
void visit(locoex::TFLAdd *) final;
- // TODO TFLAveragePool2D
+ void visit(locoex::TFLAveragePool2D *) final;
// TODO TFLConcatenation
// TODO TFLConv2D
// TODO TFLDepthwiseConv2D
gd._operators.push_back(op_offset);
}
-// TODO TFLAveragePool2D
+void OperationExporter::visit(locoex::TFLAveragePool2D *node)
+{
+ EXO_ASSERT(node->padding() != locoex::Padding::UNDEFINED, "Padding is not set");
+ EXO_ASSERT(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED,
+ "fused activation function is not set");
+
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_AVERAGE_POOL_2D);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+
+ tflite::Padding padding =
+ node->padding() == locoex::Padding::VALID ? tflite::Padding_VALID : tflite::Padding_SAME;
+
+ auto options = CreatePool2DOptions(builder, padding, node->stride()->w(), node->stride()->h(),
+ node->filter()->w(), node->filter()->h());
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_Pool2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
// TODO TFLConcatenation