From ec6eab1905adea921bae1c84aff4ccbf8f127814 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9E=A5=EC=A7=80=EC=84=AD/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 23 Jul 2019 12:26:10 +0900 Subject: [PATCH] Introduce a func that initializes inputs by default (#5743) This commit introduces defaultInit func that initializes inputs by default. - Introduce defaultInit func - Remove registering initializer that initializes inputs from ConstantInitializer Signed-off-by: jiseob.jang --- .../neurun/backend/acl_cl/ConstantInitializer.cc | 464 ++------------------- .../neurun/backend/acl_cl/ConstantInitializer.h | 46 -- .../neurun/backend/acl_neon/ConstantInitializer.cc | 124 +----- .../neurun/backend/acl_neon/ConstantInitializer.h | 12 - runtimes/neurun/backend/cpu/ConstantInitializer.cc | 91 +--- runtimes/neurun/backend/cpu/ConstantInitializer.h | 8 - .../core/include/backend/IConstantInitializer.h | 39 +- .../neurun/core/src/compiler/ExecutorFactory.cc | 2 +- runtimes/neurun/core/src/compiler/PlanBuilder.cc | 6 +- runtimes/neurun/core/src/linear/Linear.cc | 10 + runtimes/neurun/core/src/linear/Linear.h | 2 + 11 files changed, 84 insertions(+), 720 deletions(-) diff --git a/runtimes/neurun/backend/acl_cl/ConstantInitializer.cc b/runtimes/neurun/backend/acl_cl/ConstantInitializer.cc index ffaf51a..926ec09 100644 --- a/runtimes/neurun/backend/acl_cl/ConstantInitializer.cc +++ b/runtimes/neurun/backend/acl_cl/ConstantInitializer.cc @@ -41,578 +41,172 @@ void ConstantInitializer::run() auto tensor_obj = _tensor_builder->wrapTensor(ind); fn(model_obj, *tensor_obj); } - _init_map.clear(); -} - -void ConstantInitializer::visit(const model::operation::AbsNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::AbsNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} -void ConstantInitializer::visit(const model::operation::AddNode &node) -{ - const auto &lhs_index = node.getInputs().at(model::operation::AddNode::LHS); - const auto &lhs_obj = _operands.at(lhs_index); - registerPermuteInitializer(lhs_index, lhs_obj); - - const auto &rhs_index = node.getInputs().at(model::operation::AddNode::RHS); - const auto &rhs_obj = _operands.at(rhs_index); - registerPermuteInitializer(rhs_index, rhs_obj); -} - -void ConstantInitializer::visit(const model::operation::ArgMaxNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ArgMaxNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::AvgPool2DNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::AvgPool2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::CastNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::CastNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::ComparisonNode &node) -{ - const auto &input0_index = node.getInputs().at(model::operation::ComparisonNode::INPUT0); - const auto &input0_obj = _operands.at(input0_index); - registerPermuteInitializer(input0_index, input0_obj); - - const auto &input1_index = node.getInputs().at(model::operation::ComparisonNode::INPUT1); - const auto &input1_obj = _operands.at(input1_index); - registerPermuteInitializer(input1_index, input1_obj); -} - -void ConstantInitializer::visit(const model::operation::ConcatNode &node) -{ - const auto inputs = node.getInputs(); - for (const auto &input_index : inputs) - { - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - } + _init_map.clear(); } void ConstantInitializer::visit(const model::operation::Conv2DNode &node) { - const auto &input_index = node.getInputs().at(model::operation::Conv2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &kernel_index = node.getInputs().at(model::operation::Conv2DNode::KERNEL); const auto &kernel_obj = _operands.at(kernel_index); registerPermuteInitializer(kernel_index, kernel_obj); const auto &bias_index = node.getInputs().at(model::operation::Conv2DNode::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerDefaultInitializer(bias_index, bias_obj); -} - -void ConstantInitializer::visit(const model::operation::DepthToSpaceNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::DepthToSpaceNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); + registerCopyInitializer(bias_index, bias_obj); } void ConstantInitializer::visit(const model::operation::DepthwiseConv2DNode &node) { - const auto &input_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &kernel_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::KERNEL); const auto &kernel_obj = _operands.at(kernel_index); registerPermuteInitializer(kernel_index, kernel_obj); const auto &bias_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerDefaultInitializer(bias_index, bias_obj); -} - -void ConstantInitializer::visit(const model::operation::DequantizeNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::DequantizeNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::DivNode &node) -{ - const auto &lhs_index = node.getInputs().at(model::operation::DivNode::LHS); - const auto &lhs_obj = _operands.at(lhs_index); - registerPermuteInitializer(lhs_index, lhs_obj); - - const auto &rhs_index = node.getInputs().at(model::operation::DivNode::RHS); - const auto &rhs_obj = _operands.at(rhs_index); - registerPermuteInitializer(rhs_index, rhs_obj); + registerCopyInitializer(bias_index, bias_obj); } void ConstantInitializer::visit(const model::operation::EmbeddingLookupNode &node) { - const auto &input_index = node.getInputs().at(model::operation::EmbeddingLookupNode::VALUES); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &lookups_index = node.getInputs().at(model::operation::EmbeddingLookupNode::LOOKUPS); const auto &lookups_obj = _operands.at(lookups_index); - registerDefaultInitializer(lookups_index, lookups_obj); -} - -void ConstantInitializer::visit(const model::operation::ExpNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ExpNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::FloorNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::FloorNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); + registerCopyInitializer(lookups_index, lookups_obj); } void ConstantInitializer::visit(const model::operation::FullyConnectedNode &node) { - const auto &input_index = node.getInputs().at(model::operation::FullyConnectedNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &weight_index = node.getInputs().at(model::operation::FullyConnectedNode::WEIGHT); const auto &weight_obj = _operands.at(weight_index); - registerDefaultInitializer(weight_index, weight_obj); + registerCopyInitializer(weight_index, weight_obj); const auto &bias_index = node.getInputs().at(model::operation::FullyConnectedNode::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerDefaultInitializer(bias_index, bias_obj); + registerCopyInitializer(bias_index, bias_obj); } void ConstantInitializer::visit(const model::operation::GatherNode &node) { - const auto &input_index = node.getInputs().at(model::operation::GatherNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &indices_index = node.getInputs().at(model::operation::GatherNode::INDICES); const auto &indices_obj = _operands.at(indices_index); - registerDefaultInitializer(indices_index, indices_obj); + registerCopyInitializer(indices_index, indices_obj); } void ConstantInitializer::visit(const model::operation::HashtableLookupNode &node) { - const auto &input_index = node.getInputs().at(model::operation::HashtableLookupNode::VALUES); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &lookups_index = node.getInputs().at(model::operation::HashtableLookupNode::LOOKUPS); const auto &lookups_obj = _operands.at(lookups_index); - registerDefaultInitializer(lookups_index, lookups_obj); + registerCopyInitializer(lookups_index, lookups_obj); const auto &keys_index = node.getInputs().at(model::operation::HashtableLookupNode::KEYS); const auto &keys_obj = _operands.at(keys_index); - registerDefaultInitializer(keys_index, keys_obj); -} - -void ConstantInitializer::visit(const model::operation::L2NormalizationNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::L2NormalizationNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::L2Pool2DNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::L2Pool2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::LocalResponseNormalizationNode &node) -{ - const auto &input_index = - node.getInputs().at(model::operation::LocalResponseNormalizationNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::LogicalAndNode &node) -{ - const auto &input0_index = node.getInputs().at(model::operation::LogicalAndNode::INPUT0); - const auto &input0_obj = _operands.at(input0_index); - registerPermuteInitializer(input0_index, input0_obj); - - const auto &input1_index = node.getInputs().at(model::operation::LogicalAndNode::INPUT1); - const auto &input1_obj = _operands.at(input1_index); - registerPermuteInitializer(input1_index, input1_obj); -} - -void ConstantInitializer::visit(const model::operation::LogicalNotNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::LogicalNotNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::LogicalOrNode &node) -{ - const auto &input0_index = node.getInputs().at(model::operation::LogicalOrNode::INPUT0); - const auto &input0_obj = _operands.at(input0_index); - registerPermuteInitializer(input0_index, input0_obj); - - const auto &input1_index = node.getInputs().at(model::operation::LogicalOrNode::INPUT1); - const auto &input1_obj = _operands.at(input1_index); - registerPermuteInitializer(input1_index, input1_obj); -} - -void ConstantInitializer::visit(const model::operation::LogisticNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::LogisticNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); + registerCopyInitializer(keys_index, keys_obj); } void ConstantInitializer::visit(const model::operation::LSTMNode &node) { - const auto &input_index = node.getInputs().at(model::operation::LSTMNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - - const auto &output_state_in_index = - node.getInputs().at(model::operation::LSTMNode::OUTPUT_STATE_IN); - const auto &output_state_in_obj = _operands.at(output_state_in_index); - registerPermuteInitializer(output_state_in_index, output_state_in_obj); - - const auto &cell_state_in_index = node.getInputs().at(model::operation::LSTMNode::CELL_STATE_IN); - const auto &cell_state_in_obj = _operands.at(cell_state_in_index); - registerPermuteInitializer(cell_state_in_index, cell_state_in_obj); - const auto &input_to_input_weights_index = node.getInputs().at(model::operation::LSTMNode::INPUT_TO_INPUT_WEIGHTS); const auto &input_to_input_weights_obj = _operands.at(input_to_input_weights_index); - registerDefaultInitializer(input_to_input_weights_index, input_to_input_weights_obj); + registerCopyInitializer(input_to_input_weights_index, input_to_input_weights_obj); const auto &input_to_forget_weights_index = node.getInputs().at(model::operation::LSTMNode::INPUT_TO_FORGET_WEIGHTS); const auto &input_to_forget_weights_obj = _operands.at(input_to_forget_weights_index); - registerDefaultInitializer(input_to_forget_weights_index, input_to_forget_weights_obj); + registerCopyInitializer(input_to_forget_weights_index, input_to_forget_weights_obj); const auto &input_to_cell_weights_index = node.getInputs().at(model::operation::LSTMNode::INPUT_TO_CELL_WEIGHTS); const auto &input_to_cell_weights_obj = _operands.at(input_to_cell_weights_index); - registerDefaultInitializer(input_to_cell_weights_index, input_to_cell_weights_obj); + registerCopyInitializer(input_to_cell_weights_index, input_to_cell_weights_obj); const auto &input_to_output_weights_index = node.getInputs().at(model::operation::LSTMNode::INPUT_TO_OUTPUT_WEIGHTS); const auto &input_to_output_weights_obj = _operands.at(input_to_output_weights_index); - registerDefaultInitializer(input_to_output_weights_index, input_to_output_weights_obj); + registerCopyInitializer(input_to_output_weights_index, input_to_output_weights_obj); const auto &recurrent_to_input_weights_index = node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_INPUT_WEIGHTS); const auto &recurrent_to_input_weights_obj = _operands.at(recurrent_to_input_weights_index); - registerDefaultInitializer(recurrent_to_input_weights_index, recurrent_to_input_weights_obj); + registerCopyInitializer(recurrent_to_input_weights_index, recurrent_to_input_weights_obj); const auto &recurrent_to_forget_weights_index = node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_FORGET_WEIGHTS); const auto &recurrent_to_forget_weights_obj = _operands.at(recurrent_to_forget_weights_index); - registerDefaultInitializer(recurrent_to_forget_weights_index, recurrent_to_forget_weights_obj); + registerCopyInitializer(recurrent_to_forget_weights_index, recurrent_to_forget_weights_obj); const auto &recurrent_to_cell_weights_index = node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_CELL_WEIGHTS); const auto &recurrent_to_cell_weights_obj = _operands.at(recurrent_to_cell_weights_index); - registerDefaultInitializer(recurrent_to_cell_weights_index, recurrent_to_cell_weights_obj); + registerCopyInitializer(recurrent_to_cell_weights_index, recurrent_to_cell_weights_obj); const auto &recurrent_to_output_weights_index = node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_OUTPUT_WEIGHTS); const auto &recurrent_to_output_weights_obj = _operands.at(recurrent_to_output_weights_index); - registerDefaultInitializer(recurrent_to_output_weights_index, recurrent_to_output_weights_obj); + registerCopyInitializer(recurrent_to_output_weights_index, recurrent_to_output_weights_obj); const auto &cell_to_input_weights_index = node.getInputs().at(model::operation::LSTMNode::CELL_TO_INPUT_WEIGHTS); const auto &cell_to_input_weights_obj = _operands.at(cell_to_input_weights_index); - registerDefaultInitializer(cell_to_input_weights_index, cell_to_input_weights_obj); + registerCopyInitializer(cell_to_input_weights_index, cell_to_input_weights_obj); const auto &cell_to_forget_weights_index = node.getInputs().at(model::operation::LSTMNode::CELL_TO_FORGET_WEIGHTS); const auto &cell_to_forget_weights_obj = _operands.at(cell_to_forget_weights_index); - registerDefaultInitializer(cell_to_forget_weights_index, cell_to_forget_weights_obj); + registerCopyInitializer(cell_to_forget_weights_index, cell_to_forget_weights_obj); const auto &cell_to_output_weights_index = node.getInputs().at(model::operation::LSTMNode::CELL_TO_OUTPUT_WEIGHTS); const auto &cell_to_output_weights_obj = _operands.at(cell_to_output_weights_index); - registerDefaultInitializer(cell_to_output_weights_index, cell_to_output_weights_obj); + registerCopyInitializer(cell_to_output_weights_index, cell_to_output_weights_obj); const auto &input_gate_bias_index = node.getInputs().at(model::operation::LSTMNode::INPUT_GATE_BIAS); const auto &input_gate_bias_obj = _operands.at(input_gate_bias_index); - registerDefaultInitializer(input_gate_bias_index, input_gate_bias_obj); + registerCopyInitializer(input_gate_bias_index, input_gate_bias_obj); const auto &forget_gate_bias_index = node.getInputs().at(model::operation::LSTMNode::FORGET_GATE_BIAS); const auto &forget_gate_bias_obj = _operands.at(forget_gate_bias_index); - registerDefaultInitializer(forget_gate_bias_index, forget_gate_bias_obj); + registerCopyInitializer(forget_gate_bias_index, forget_gate_bias_obj); const auto &output_gate_bias_index = node.getInputs().at(model::operation::LSTMNode::OUTPUT_GATE_BIAS); const auto &output_gate_bias_obj = _operands.at(output_gate_bias_index); - registerDefaultInitializer(output_gate_bias_index, output_gate_bias_obj); + registerCopyInitializer(output_gate_bias_index, output_gate_bias_obj); const auto &projection_weights_index = node.getInputs().at(model::operation::LSTMNode::PROJECTION_WEIGHTS); const auto &projection_weights_obj = _operands.at(projection_weights_index); - registerDefaultInitializer(projection_weights_index, projection_weights_obj); + registerCopyInitializer(projection_weights_index, projection_weights_obj); const auto &projection_bias_index = node.getInputs().at(model::operation::LSTMNode::PROJECTION_BIAS); const auto &projection_bias_obj = _operands.at(projection_bias_index); - registerDefaultInitializer(projection_bias_index, projection_bias_obj); -} - -void ConstantInitializer::visit(const model::operation::MaxPool2DNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::MaxPool2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::MeanNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::MeanNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::MulNode &node) -{ - const auto &lhs_index = node.getInputs().at(model::operation::MulNode::LHS); - const auto &lhs_obj = _operands.at(lhs_index); - registerPermuteInitializer(lhs_index, lhs_obj); - - const auto &rhs_index = node.getInputs().at(model::operation::MulNode::RHS); - const auto &rhs_obj = _operands.at(rhs_index); - registerPermuteInitializer(rhs_index, rhs_obj); -} - -void ConstantInitializer::visit(const model::operation::NegNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::NegNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::PadNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::PadNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::PReLUNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::PReLUNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - - const auto &alpha_index = node.getInputs().at(model::operation::PReLUNode::ALPHA); - const auto &alpha_obj = _operands.at(alpha_index); - registerPermuteInitializer(alpha_index, alpha_obj); -} - -void ConstantInitializer::visit(const model::operation::ReduceMaxNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ReduceMaxNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::ReduceMinNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ReduceMinNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::ReduceSumNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ReduceSumNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::ReLUNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ReLUNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::ReLU1Node &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ReLU1Node::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::ReLU6Node &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ReLU6Node::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::ReshapeNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ReshapeNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::ResizeBilinearNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ResizeBilinearNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); + registerCopyInitializer(projection_bias_index, projection_bias_obj); } void ConstantInitializer::visit(const model::operation::RNNNode &node) { - const auto &input_index = node.getInputs().at(model::operation::RNNNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &weights_index = node.getInputs().at(model::operation::RNNNode::WEIGHTS); const auto &weights_obj = _operands.at(weights_index); - registerDefaultInitializer(weights_index, weights_obj); + registerCopyInitializer(weights_index, weights_obj); const auto &recurrent_weights_index = node.getInputs().at(model::operation::RNNNode::RECURRENT_WEIGHTS); const auto &recurrent_weights_obj = _operands.at(recurrent_weights_index); - registerDefaultInitializer(recurrent_weights_index, recurrent_weights_obj); + registerCopyInitializer(recurrent_weights_index, recurrent_weights_obj); const auto &bias_index = node.getInputs().at(model::operation::RNNNode::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerDefaultInitializer(bias_index, bias_obj); -} - -void ConstantInitializer::visit(const model::operation::RSQRTNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::RSQRTNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::SoftmaxNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::SoftmaxNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::SpaceToDepthNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::SpaceToDepthNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::SplitNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::SplitNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::SQRTNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::SQRTNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::SquaredDifferenceNode &node) -{ - const auto &lhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::LHS); - const auto &lhs_obj = _operands.at(lhs_index); - registerPermuteInitializer(lhs_index, lhs_obj); - - const auto &rhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::RHS); - const auto &rhs_obj = _operands.at(rhs_index); - registerPermuteInitializer(rhs_index, rhs_obj); -} - -void ConstantInitializer::visit(const model::operation::SqueezeNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::SqueezeNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::StridedSliceNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::StridedSliceNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::SubNode &node) -{ - const auto &lhs_index = node.getInputs().at(model::operation::SubNode::LHS); - const auto &lhs_obj = _operands.at(lhs_index); - registerPermuteInitializer(lhs_index, lhs_obj); - - const auto &rhs_index = node.getInputs().at(model::operation::SubNode::RHS); - const auto &rhs_obj = _operands.at(rhs_index); - registerPermuteInitializer(rhs_index, rhs_obj); -} - -void ConstantInitializer::visit(const model::operation::TanhNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::TanhNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::TopKV2Node &node) -{ - const auto &input_index = node.getInputs().at(model::operation::TopKV2Node::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); + registerCopyInitializer(bias_index, bias_obj); } void ConstantInitializer::visit(const model::operation::TransposeConvNode &node) { - const auto &input_index = node.getInputs().at(model::operation::TransposeConvNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &kernel_index = node.getInputs().at(model::operation::TransposeConvNode::KERNEL); const auto &kernel_obj = _operands.at(kernel_index); - registerDefaultInitializer(kernel_index, kernel_obj); -} - -void ConstantInitializer::visit(const model::operation::TransposeNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::TransposeNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::UnpackNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::UnpackNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); + registerCopyInitializer(kernel_index, kernel_obj); } } // namespace acl_cl diff --git a/runtimes/neurun/backend/acl_cl/ConstantInitializer.h b/runtimes/neurun/backend/acl_cl/ConstantInitializer.h index f1d243d..59772e0 100644 --- a/runtimes/neurun/backend/acl_cl/ConstantInitializer.h +++ b/runtimes/neurun/backend/acl_cl/ConstantInitializer.h @@ -38,61 +38,15 @@ public: void run() override; public: - void visit(const model::operation::AbsNode &) override; - void visit(const model::operation::AddNode &) override; - void visit(const model::operation::ArgMaxNode &) override; - void visit(const model::operation::AvgPool2DNode &) override; - void visit(const model::operation::CastNode &) override; - void visit(const model::operation::ComparisonNode &) override; - void visit(const model::operation::ConcatNode &) override; void visit(const model::operation::Conv2DNode &) override; - void visit(const model::operation::DepthToSpaceNode &) override; void visit(const model::operation::DepthwiseConv2DNode &) override; - void visit(const model::operation::DequantizeNode &) override; - void visit(const model::operation::DivNode &) override; void visit(const model::operation::EmbeddingLookupNode &) override; - void visit(const model::operation::ExpNode &) override; - void visit(const model::operation::FloorNode &) override; void visit(const model::operation::FullyConnectedNode &) override; void visit(const model::operation::GatherNode &) override; void visit(const model::operation::HashtableLookupNode &) override; - void visit(const model::operation::L2NormalizationNode &) override; - void visit(const model::operation::L2Pool2DNode &) override; - void visit(const model::operation::LocalResponseNormalizationNode &) override; - void visit(const model::operation::LogicalAndNode &) override; - void visit(const model::operation::LogicalNotNode &) override; - void visit(const model::operation::LogicalOrNode &) override; - void visit(const model::operation::LogisticNode &) override; void visit(const model::operation::LSTMNode &) override; - void visit(const model::operation::MaxPool2DNode &) override; - void visit(const model::operation::MeanNode &) override; - void visit(const model::operation::MulNode &) override; - void visit(const model::operation::NegNode &) override; - void visit(const model::operation::PadNode &) override; - void visit(const model::operation::PReLUNode &) override; - void visit(const model::operation::ReduceMaxNode &) override; - void visit(const model::operation::ReduceMinNode &) override; - void visit(const model::operation::ReduceSumNode &) override; - void visit(const model::operation::ReLUNode &) override; - void visit(const model::operation::ReLU1Node &) override; - void visit(const model::operation::ReLU6Node &) override; - void visit(const model::operation::ReshapeNode &node) override; - void visit(const model::operation::ResizeBilinearNode &) override; void visit(const model::operation::RNNNode &) override; - void visit(const model::operation::RSQRTNode &) override; - void visit(const model::operation::SoftmaxNode &node) override; - void visit(const model::operation::SpaceToDepthNode &) override; - void visit(const model::operation::SplitNode &) override; - void visit(const model::operation::SQRTNode &) override; - void visit(const model::operation::SquaredDifferenceNode &) override; - void visit(const model::operation::SqueezeNode &) override; - void visit(const model::operation::StridedSliceNode &) override; - void visit(const model::operation::SubNode &) override; - void visit(const model::operation::TanhNode &) override; - void visit(const model::operation::TopKV2Node &) override; void visit(const model::operation::TransposeConvNode &) override; - void visit(const model::operation::TransposeNode &) override; - void visit(const model::operation::UnpackNode &) override; private: const model::Operands &_operands; diff --git a/runtimes/neurun/backend/acl_neon/ConstantInitializer.cc b/runtimes/neurun/backend/acl_neon/ConstantInitializer.cc index 4ed578f..98be80b 100644 --- a/runtimes/neurun/backend/acl_neon/ConstantInitializer.cc +++ b/runtimes/neurun/backend/acl_neon/ConstantInitializer.cc @@ -41,155 +41,41 @@ void ConstantInitializer::run() auto tensor_obj = _tensor_builder->wrapTensor(ind); fn(model_obj, *tensor_obj); } - _init_map.clear(); -} - -void ConstantInitializer::visit(const model::operation::AddNode &node) -{ - const auto &lhs_index = node.getInputs().at(model::operation::AddNode::LHS); - const auto &lhs_obj = _operands.at(lhs_index); - registerPermuteInitializer(lhs_index, lhs_obj); - - const auto &rhs_index = node.getInputs().at(model::operation::AddNode::RHS); - const auto &rhs_obj = _operands.at(rhs_index); - registerPermuteInitializer(rhs_index, rhs_obj); -} -void ConstantInitializer::visit(const model::operation::AvgPool2DNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::AvgPool2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::ConcatNode &node) -{ - const auto inputs = node.getInputs(); - for (const auto &input_index : inputs) - { - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - } + _init_map.clear(); } void ConstantInitializer::visit(const model::operation::Conv2DNode &node) { - const auto &input_index = node.getInputs().at(model::operation::Conv2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &kernel_index = node.getInputs().at(model::operation::Conv2DNode::KERNEL); const auto &kernel_obj = _operands.at(kernel_index); registerPermuteInitializer(kernel_index, kernel_obj); const auto &bias_index = node.getInputs().at(model::operation::Conv2DNode::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerDefaultInitializer(bias_index, bias_obj); + registerCopyInitializer(bias_index, bias_obj); } void ConstantInitializer::visit(const model::operation::DepthwiseConv2DNode &node) { - const auto &input_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &kernel_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::KERNEL); const auto &kernel_obj = _operands.at(kernel_index); registerPermuteInitializer(kernel_index, kernel_obj); const auto &bias_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerDefaultInitializer(bias_index, bias_obj); + registerCopyInitializer(bias_index, bias_obj); } void ConstantInitializer::visit(const model::operation::FullyConnectedNode &node) { - const auto &input_index = node.getInputs().at(model::operation::FullyConnectedNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &weight_index = node.getInputs().at(model::operation::FullyConnectedNode::WEIGHT); const auto &weight_obj = _operands.at(weight_index); - registerDefaultInitializer(weight_index, weight_obj); + registerCopyInitializer(weight_index, weight_obj); const auto &bias_index = node.getInputs().at(model::operation::FullyConnectedNode::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerDefaultInitializer(bias_index, bias_obj); -} - -void ConstantInitializer::visit(const model::operation::MaxPool2DNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::MaxPool2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::MeanNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::MeanNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::MulNode &node) -{ - const auto &lhs_index = node.getInputs().at(model::operation::MulNode::LHS); - const auto &lhs_obj = _operands.at(lhs_index); - registerPermuteInitializer(lhs_index, lhs_obj); - - const auto &rhs_index = node.getInputs().at(model::operation::MulNode::RHS); - const auto &rhs_obj = _operands.at(rhs_index); - registerPermuteInitializer(rhs_index, rhs_obj); -} - -void ConstantInitializer::visit(const model::operation::ReshapeNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ReshapeNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::RSQRTNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::RSQRTNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::SoftmaxNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::SoftmaxNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::SquaredDifferenceNode &node) -{ - const auto &lhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::LHS); - const auto &lhs_obj = _operands.at(lhs_index); - registerPermuteInitializer(lhs_index, lhs_obj); - - const auto &rhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::RHS); - const auto &rhs_obj = _operands.at(rhs_index); - registerPermuteInitializer(rhs_index, rhs_obj); -} - -void ConstantInitializer::visit(const model::operation::SubNode &node) -{ - const auto &lhs_index = node.getInputs().at(model::operation::SubNode::LHS); - const auto &lhs_obj = _operands.at(lhs_index); - registerPermuteInitializer(lhs_index, lhs_obj); - - const auto &rhs_index = node.getInputs().at(model::operation::SubNode::RHS); - const auto &rhs_obj = _operands.at(rhs_index); - registerPermuteInitializer(rhs_index, rhs_obj); -} - -void ConstantInitializer::visit(const model::operation::TanhNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::TanhNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); + registerCopyInitializer(bias_index, bias_obj); } } // namespace acl_neon diff --git a/runtimes/neurun/backend/acl_neon/ConstantInitializer.h b/runtimes/neurun/backend/acl_neon/ConstantInitializer.h index 91c818f..62e889c 100644 --- a/runtimes/neurun/backend/acl_neon/ConstantInitializer.h +++ b/runtimes/neurun/backend/acl_neon/ConstantInitializer.h @@ -38,21 +38,9 @@ public: void run() override; public: - void visit(const model::operation::AddNode &) override; - void visit(const model::operation::AvgPool2DNode &) override; - void visit(const model::operation::ConcatNode &) override; void visit(const model::operation::Conv2DNode &) override; void visit(const model::operation::DepthwiseConv2DNode &) override; void visit(const model::operation::FullyConnectedNode &) override; - void visit(const model::operation::MaxPool2DNode &) override; - void visit(const model::operation::MeanNode &) override; - void visit(const model::operation::MulNode &) override; - void visit(const model::operation::ReshapeNode &) override; - void visit(const model::operation::RSQRTNode &) override; - void visit(const model::operation::SoftmaxNode &) override; - void visit(const model::operation::SquaredDifferenceNode &) override; - void visit(const model::operation::SubNode &) override; - void visit(const model::operation::TanhNode &) override; private: const model::Operands &_operands; diff --git a/runtimes/neurun/backend/cpu/ConstantInitializer.cc b/runtimes/neurun/backend/cpu/ConstantInitializer.cc index 8997933..cff9fcf 100644 --- a/runtimes/neurun/backend/cpu/ConstantInitializer.cc +++ b/runtimes/neurun/backend/cpu/ConstantInitializer.cc @@ -41,118 +41,41 @@ void ConstantInitializer::run() auto tensor_obj = _tensor_builder->wrapTensor(ind); fn(model_obj, *tensor_obj); } - _init_map.clear(); -} - -void ConstantInitializer::visit(const model::operation::AddNode &node) -{ - const auto &lhs_index = node.getInputs().at(model::operation::AddNode::LHS); - const auto &lhs_obj = _operands.at(lhs_index); - registerPermuteInitializer(lhs_index, lhs_obj); - - const auto &rhs_index = node.getInputs().at(model::operation::AddNode::RHS); - const auto &rhs_obj = _operands.at(rhs_index); - registerPermuteInitializer(rhs_index, rhs_obj); -} - -void ConstantInitializer::visit(const model::operation::AvgPool2DNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::AvgPool2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} -void ConstantInitializer::visit(const model::operation::ConcatNode &node) -{ - const auto inputs = node.getInputs(); - for (const auto &input_index : inputs) - { - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - } + _init_map.clear(); } void ConstantInitializer::visit(const model::operation::Conv2DNode &node) { - const auto &input_index = node.getInputs().at(model::operation::Conv2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &kernel_index = node.getInputs().at(model::operation::Conv2DNode::KERNEL); const auto &kernel_obj = _operands.at(kernel_index); - registerDefaultInitializer(kernel_index, kernel_obj); + registerCopyInitializer(kernel_index, kernel_obj); const auto &bias_index = node.getInputs().at(model::operation::Conv2DNode::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerDefaultInitializer(bias_index, bias_obj); + registerCopyInitializer(bias_index, bias_obj); } void ConstantInitializer::visit(const model::operation::DepthwiseConv2DNode &node) { - const auto &input_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &kernel_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::KERNEL); const auto &kernel_obj = _operands.at(kernel_index); - registerDefaultInitializer(kernel_index, kernel_obj); + registerCopyInitializer(kernel_index, kernel_obj); const auto &bias_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerDefaultInitializer(bias_index, bias_obj); + registerCopyInitializer(bias_index, bias_obj); } void ConstantInitializer::visit(const model::operation::FullyConnectedNode &node) { - const auto &input_index = node.getInputs().at(model::operation::FullyConnectedNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); - const auto &weight_index = node.getInputs().at(model::operation::FullyConnectedNode::WEIGHT); const auto &weight_obj = _operands.at(weight_index); - registerDefaultInitializer(weight_index, weight_obj); + registerCopyInitializer(weight_index, weight_obj); const auto &bias_index = node.getInputs().at(model::operation::FullyConnectedNode::BIAS); const auto &bias_obj = _operands.at(bias_index); - registerDefaultInitializer(bias_index, bias_obj); -} - -void ConstantInitializer::visit(const model::operation::MaxPool2DNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::MaxPool2DNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::MulNode &node) -{ - const auto &lhs_index = node.getInputs().at(model::operation::MulNode::LHS); - const auto &lhs_obj = _operands.at(lhs_index); - registerPermuteInitializer(lhs_index, lhs_obj); - - const auto &rhs_index = node.getInputs().at(model::operation::MulNode::RHS); - const auto &rhs_obj = _operands.at(rhs_index); - registerPermuteInitializer(rhs_index, rhs_obj); -} - -void ConstantInitializer::visit(const model::operation::ReshapeNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::ReshapeNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::SoftmaxNode &node) -{ - const auto &input_index = node.getInputs().at(model::operation::SoftmaxNode::INPUT); - const auto &input_obj = _operands.at(input_index); - registerPermuteInitializer(input_index, input_obj); -} - -void ConstantInitializer::visit(const model::operation::PermuteNode &) -{ - // DO NOTHING - // This node's constant doesn't exist. + registerCopyInitializer(bias_index, bias_obj); } } // namespace cpu diff --git a/runtimes/neurun/backend/cpu/ConstantInitializer.h b/runtimes/neurun/backend/cpu/ConstantInitializer.h index 4f9794c..91f1d50 100644 --- a/runtimes/neurun/backend/cpu/ConstantInitializer.h +++ b/runtimes/neurun/backend/cpu/ConstantInitializer.h @@ -38,17 +38,9 @@ public: void run() override; public: - void visit(const model::operation::AddNode &) override; - void visit(const model::operation::AvgPool2DNode &) override; - void visit(const model::operation::ConcatNode &) override; void visit(const model::operation::Conv2DNode &) override; void visit(const model::operation::DepthwiseConv2DNode &) override; void visit(const model::operation::FullyConnectedNode &) override; - void visit(const model::operation::MaxPool2DNode &) override; - void visit(const model::operation::MulNode &) override; - void visit(const model::operation::ReshapeNode &) override; - void visit(const model::operation::SoftmaxNode &) override; - void visit(const model::operation::PermuteNode &) override; private: const model::Operands &_operands; diff --git a/runtimes/neurun/core/include/backend/IConstantInitializer.h b/runtimes/neurun/core/include/backend/IConstantInitializer.h index 0545391..0c767be 100644 --- a/runtimes/neurun/core/include/backend/IConstantInitializer.h +++ b/runtimes/neurun/core/include/backend/IConstantInitializer.h @@ -22,6 +22,7 @@ #include "ITensorBuilder.h" #include "model/Operand.h" +#include "model/Operands.h" #include "model/OperationVisitor.h" #include "model/Subgraph.h" #include "util/logging.h" @@ -30,7 +31,7 @@ namespace { template -void defaultInit(const neurun::model::Operand &model_obj, neurun::backend::operand::IObject &obj) +void copyInit(const neurun::model::Operand &model_obj, neurun::backend::operand::IObject &obj) { const auto shape = model_obj.shape(); auto base = reinterpret_cast(model_obj.data().base()); @@ -207,19 +208,30 @@ public: public: using Initializer = std::function; - void generate(const model::Subgraph &subg) { subg.accept(*this); } + void generate(const model::Subgraph &subg, const model::Operands &operands) + { + subg.accept(*this); + for (const auto &e : subg.operations()) + { + for (const auto &ind : e.node->getInputs()) + { + const auto &obj = operands.at(ind); + if (obj.isConstant() && !exist(ind)) + { + registerPermuteInitializer(ind, obj); + } + } + } + } protected: -#define OP(InternalName, IsNnApi) \ - virtual void visit(const model::operation::InternalName &) override \ - { \ - throw std::runtime_error("NYI"); \ - } +#define OP(InternalName, IsNnApi) \ + virtual void visit(const model::operation::InternalName &) override { /* DO NOTHING */} #include "model/Operations.lst" #undef OP protected: - void registerDefaultInitializer(const model::OperandIndex &index, const model::Operand &obj) + void registerCopyInitializer(const model::OperandIndex &index, const model::Operand &obj) { // For only CONSTANTS if (!obj.isConstant()) @@ -233,17 +245,17 @@ protected: switch (type) { case DataType::FLOAT32: - _init_map[index] = defaultInit; + _init_map[index] = copyInit; break; case DataType::INT32: - _init_map[index] = defaultInit; + _init_map[index] = copyInit; break; case DataType::UINT32: - _init_map[index] = defaultInit; + _init_map[index] = copyInit; break; case DataType::BOOL8: case DataType::QUANT8_ASYMM: - _init_map[index] = defaultInit; + _init_map[index] = copyInit; break; default: throw std::runtime_error("Not supported, yet"); @@ -284,6 +296,9 @@ protected: } } +private: + bool exist(const model::OperandIndex &ind) { return _init_map.find(ind) != _init_map.end(); } + protected: std::unordered_map _init_map; }; diff --git a/runtimes/neurun/core/src/compiler/ExecutorFactory.cc b/runtimes/neurun/core/src/compiler/ExecutorFactory.cc index 4328c00..efe3fd3 100644 --- a/runtimes/neurun/core/src/compiler/ExecutorFactory.cc +++ b/runtimes/neurun/core/src/compiler/ExecutorFactory.cc @@ -226,7 +226,7 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor(graph::Graph &graph, bo [&](const model::SubgraphIndex &subg_index, const model::Subgraph &subg) { auto backend = graph.getLowerInfo(subg_index)->backend(); auto constant_initializer = backend->constant_initializer(); - constant_initializer->generate(subg); + constant_initializer->generate(subg, graph.operands()); // TODO This approach is temporal. See declaration of `setNextIndex`. execution_builder->setNextIndex(subg_index); auto kernel_gen = backend->kernel_gen(); diff --git a/runtimes/neurun/core/src/compiler/PlanBuilder.cc b/runtimes/neurun/core/src/compiler/PlanBuilder.cc index fbef2cc..9895c86 100644 --- a/runtimes/neurun/core/src/compiler/PlanBuilder.cc +++ b/runtimes/neurun/core/src/compiler/PlanBuilder.cc @@ -18,7 +18,6 @@ #include "backend/operand/IObject.h" #include "backend/Backend.h" -#include "backend/IConstantInitializer.h" #include "backend/IKernelGenerator.h" #include "linear/Linear.h" @@ -42,12 +41,13 @@ void PlanBuilder::finalize(const linear::Linear *linear, }); } + // Generate initializers + linear->generateConstantInitializers(); + // Generate kernels auto execution_builder = nnfw::cpp14::make_unique(_functions); linear->iterate([&](const linear::Element &element) { auto backend = element.lower_info->backend(); - auto constant_initializer = backend->constant_initializer(); - constant_initializer->generate(*element.subgraph); auto kernel_gen = backend->kernel_gen(); kernel_gen->generate(*element.subgraph, execution_builder.get()); }); diff --git a/runtimes/neurun/core/src/linear/Linear.cc b/runtimes/neurun/core/src/linear/Linear.cc index 94d1089..0e28d74 100644 --- a/runtimes/neurun/core/src/linear/Linear.cc +++ b/runtimes/neurun/core/src/linear/Linear.cc @@ -22,6 +22,7 @@ #include "graph/operand/LowerInfo.h" #include "backend/IShapeFixer.h" #include "backend/IConfig.h" +#include "backend/IConstantInitializer.h" #include "backend/Backend.h" #include "compiler/SubTensorInfo.h" #include "model/OperandInfo.h" @@ -309,6 +310,15 @@ void Linear::iterate(const std::function &fn) cons } } +void Linear::generateConstantInitializers(void) const +{ + iterate([&](const linear::Element &element) { + auto backend = element.lower_info->backend(); + auto constant_initializer = backend->constant_initializer(); + constant_initializer->generate(*element.subgraph, _model->operands); + }); +} + const graph::operation::LowerInfo *Linear::getLowerInfo(const model::SubgraphIndex &index) const { if (!_lower_info_map) diff --git a/runtimes/neurun/core/src/linear/Linear.h b/runtimes/neurun/core/src/linear/Linear.h index ed82c48..45c8489 100644 --- a/runtimes/neurun/core/src/linear/Linear.h +++ b/runtimes/neurun/core/src/linear/Linear.h @@ -60,6 +60,8 @@ public: void iterate(const std::function &fn) const; + void generateConstantInitializers(void) const; + std::unique_ptr releaseLowerInfo() { return std::move(_lower_info_map); } graph::LowerInfoMap *getLowerInfo() { return _lower_info_map.get(); } -- 2.7.4