#include "OperationFactory.h"
-#include "model/operation/Conv2DNode.h"
+#include "model/operation/Node.Include.h"
OperationFactory &OperationFactory::instance()
{
OperationFactory::OperationFactory()
{
+ using namespace neurun::model;
+
_map[ANEURALNETWORKS_CONV_2D] = [](const OperationFactory::Param &init_param) {
- using namespace neurun::model;
using neurun::model::operation::Conv2DNode;
// inputCount is either 7 or 10 acccording to NN API specification.
return new Conv2DNode{inputs, outputs, param};
};
+
+ _map[ANEURALNETWORKS_ADD] = [](const OperationFactory::Param &init_param) {
+ assert(init_param.input_count == 3);
+ assert(init_param.output_count == 1);
+
+ // Each input should be interpreted as follows:
+ //
+ // 0 -> Lefthand side operand
+ // 1 -> Righthand side operand
+
+ operand::IndexSet inputs{init_param.inputs[0], init_param.inputs[1]};
+ operand::IndexSet outputs{init_param.outputs[0]};
+
+ operation::AddNode::Param param;
+
+ param.activation_index = operand::Index{init_param.inputs[2]};
+
+ return new operation::AddNode{inputs, outputs, param};
+ };
}
neurun::model::operation::Node *OperationFactory::create(ANeuralNetworksOperationType type,
switch (type)
{
case ANEURALNETWORKS_ADD:
- {
- assert(inputCount == 3);
- assert(outputCount == 1);
-
- using GraphNode = neurun::model::operation::AddNode;
-
- _model->addOperation(nnfw::cpp14::make_unique<GraphNode>(node_param));
-
- break;
- }
case ANEURALNETWORKS_CONV_2D:
{
auto node = factory.create(type, param);
void AddNode::accept(NodeVisitor &&v) const { v.visit(*this); }
-AddNode::AddNode(const model::operation::Node::InitParam &init_param)
- : model::operation::Node{OperandConstraint::createExact(2u)}
+AddNode::AddNode(const operand::IndexSet &inputs, const operand::IndexSet &outputs,
+ const Param ¶m)
+ : model::operation::Node{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
{
- assert(init_param.input_count == 3);
- assert(init_param.output_count == 1);
-
- // Each input should be interpreted as follows:
- //
- // 0 -> Lefthand side operand
- // 1 -> Righthand side operand
-
- setInputs({init_param.inputs[0], init_param.inputs[1]});
- setOutputs({init_param.outputs[0]});
-
- _param.activation_index = operand::Index{init_param.inputs[2]};
}
} // namespace operation
class AddNode : public model::operation::Node
{
public:
- AddNode(const model::operation::Node::InitParam &init_param);
-
enum Input
{
LHS = 0,
};
public:
+ AddNode(const operand::IndexSet &inputs, const operand::IndexSet &outputs, const Param ¶m);
+
+public:
virtual void accept(NodeVisitor &&) const override;
virtual std::string getName() const override { return "Add"; }