/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
#include "Builders.h"
-#include "kernels/PRelu.h"
+#include "kernels/Gelu.h"
namespace luci_interpreter
{
-std::unique_ptr<Kernel> build_kernel_CirclePRelu(std::vector<const Tensor *> &&inputs,
- std::vector<Tensor *> &&outputs,
- const uint32_t op_index, KernelBuilder &builder)
+std::unique_ptr<Kernel> build_kernel_CircleGelu(const luci::CircleNode *circle_node,
+ KernelBuilderHelper &helper)
{
- assert(inputs.size() == 2);
+ const auto *node = loco::must_cast<const luci::CircleGelu *>(circle_node);
+ assert(node->arity() == 1);
+ const Tensor *input = helper.getInputTensor(node->features());
+ Tensor *output = helper.getOutputTensor(node);
- const Tensor *input = inputs.at(0);
- const Tensor *alpha = inputs.at(1);
- Tensor *output = outputs.at(0);
+ GeluParams params{};
+ params.approximate = node->approximate();
- return std::make_unique<kernels::PRelu>(input, alpha, output);
+ return std::make_unique<kernels::Gelu>(input, output, params);
}
} // namespace luci_interpreter