#include "internal/layers/SimpleNeg.h"
#include "internal/layers/SimpleUnpackLayer.h"
#include "internal/layers/SimpleSQRT.h"
+#include "internal/layers/PReLULayer.h"
#include "internal/layers/SimpleArgMinMax.h"
#include "util/matrix/IndexIterator.h"
{
VERBOSE(PReLU) << "Configure PReLU operation" << std::endl;
- throw std::runtime_error("Not supported, yet");
+ const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
+ const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
+ const ::internal::tflite::operand::Index alpha_index{node.param().alpha_index};
+
+ // Set shape constraints
+ _builder.addShapeConstr(
+ ofm_index, asTensorInfo(asTensorShape(_ctx.at(ofm_index).shape()), _ctx.at(ofm_index).type(),
+ _ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint()));
+ _builder.addShapeConstr(
+ ifm_index, asTensorInfo(asTensorShape(_ctx.at(ifm_index).shape()), _ctx.at(ifm_index).type(),
+ _ctx.at(ifm_index).scale(), _ctx.at(ifm_index).zeroPoint()));
+
+ _builder.addShapeConstr(alpha_index,
+ asTensorInfo(asTensorShape(_ctx.at(alpha_index).shape()),
+ _ctx.at(alpha_index).type(), _ctx.at(alpha_index).scale(),
+ _ctx.at(alpha_index).zeroPoint()));
+
+ struct Param
+ {
+ int ofm_index;
+ int ifm_index;
+ int alpha_index;
+ };
+
+ Param param;
+
+ param.ofm_index = ofm_index.asInt();
+ param.ifm_index = ifm_index.asInt();
+ param.alpha_index = alpha_index.asInt();
+
+ auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
+ auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index});
+ auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index});
+ auto alpha_alloc = ctx.at(::internal::tflite::operand::Index{param.alpha_index});
+
+ if (::internal::arm_compute::isGpuMode())
+ {
+ auto fn = nnfw::cpp14::make_unique<PReLULayer>();
+ fn->configure(CAST_CL(ifm_alloc), CAST_CL(alpha_alloc), CAST_CL(ofm_alloc));
+ builder.append("PReLU", std::move(fn));
+ }
+ else
+ {
+ // TODO Add NEON support
+
+ throw std::runtime_error("Not supported, yet");
+ }
+ };
+
+ _builder.addStage(stage);
}
void Planner::visit(const ::internal::tflite::op::ReLU::Node &node)