Introduce a func that initializes inputs by default (#5743)
author장지섭/On-Device Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Tue, 23 Jul 2019 03:26:10 +0000 (12:26 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Tue, 23 Jul 2019 03:26:10 +0000 (12:26 +0900)
This commit introduces defaultInit func that initializes inputs by default.
  - Introduce defaultInit func
  - Remove registering initializer that initializes inputs from ConstantInitializer

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/neurun/backend/acl_cl/ConstantInitializer.cc
runtimes/neurun/backend/acl_cl/ConstantInitializer.h
runtimes/neurun/backend/acl_neon/ConstantInitializer.cc
runtimes/neurun/backend/acl_neon/ConstantInitializer.h
runtimes/neurun/backend/cpu/ConstantInitializer.cc
runtimes/neurun/backend/cpu/ConstantInitializer.h
runtimes/neurun/core/include/backend/IConstantInitializer.h
runtimes/neurun/core/src/compiler/ExecutorFactory.cc
runtimes/neurun/core/src/compiler/PlanBuilder.cc
runtimes/neurun/core/src/linear/Linear.cc
runtimes/neurun/core/src/linear/Linear.h

index ffaf51a..926ec09 100644 (file)
@@ -41,578 +41,172 @@ void ConstantInitializer::run()
     auto tensor_obj = _tensor_builder->wrapTensor(ind);
     fn(model_obj, *tensor_obj);
   }
-  _init_map.clear();
-}
-
-void ConstantInitializer::visit(const model::operation::AbsNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::AbsNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
 
-void ConstantInitializer::visit(const model::operation::AddNode &node)
-{
-  const auto &lhs_index = node.getInputs().at(model::operation::AddNode::LHS);
-  const auto &lhs_obj = _operands.at(lhs_index);
-  registerPermuteInitializer(lhs_index, lhs_obj);
-
-  const auto &rhs_index = node.getInputs().at(model::operation::AddNode::RHS);
-  const auto &rhs_obj = _operands.at(rhs_index);
-  registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ArgMaxNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ArgMaxNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::AvgPool2DNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::AvgPool2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::CastNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::CastNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ComparisonNode &node)
-{
-  const auto &input0_index = node.getInputs().at(model::operation::ComparisonNode::INPUT0);
-  const auto &input0_obj = _operands.at(input0_index);
-  registerPermuteInitializer(input0_index, input0_obj);
-
-  const auto &input1_index = node.getInputs().at(model::operation::ComparisonNode::INPUT1);
-  const auto &input1_obj = _operands.at(input1_index);
-  registerPermuteInitializer(input1_index, input1_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ConcatNode &node)
-{
-  const auto inputs = node.getInputs();
-  for (const auto &input_index : inputs)
-  {
-    const auto &input_obj = _operands.at(input_index);
-    registerPermuteInitializer(input_index, input_obj);
-  }
+  _init_map.clear();
 }
 
 void ConstantInitializer::visit(const model::operation::Conv2DNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::Conv2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &kernel_index = node.getInputs().at(model::operation::Conv2DNode::KERNEL);
   const auto &kernel_obj = _operands.at(kernel_index);
   registerPermuteInitializer(kernel_index, kernel_obj);
 
   const auto &bias_index = node.getInputs().at(model::operation::Conv2DNode::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerDefaultInitializer(bias_index, bias_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::DepthToSpaceNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::DepthToSpaceNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
+  registerCopyInitializer(bias_index, bias_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::DepthwiseConv2DNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &kernel_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::KERNEL);
   const auto &kernel_obj = _operands.at(kernel_index);
   registerPermuteInitializer(kernel_index, kernel_obj);
 
   const auto &bias_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerDefaultInitializer(bias_index, bias_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::DequantizeNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::DequantizeNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::DivNode &node)
-{
-  const auto &lhs_index = node.getInputs().at(model::operation::DivNode::LHS);
-  const auto &lhs_obj = _operands.at(lhs_index);
-  registerPermuteInitializer(lhs_index, lhs_obj);
-
-  const auto &rhs_index = node.getInputs().at(model::operation::DivNode::RHS);
-  const auto &rhs_obj = _operands.at(rhs_index);
-  registerPermuteInitializer(rhs_index, rhs_obj);
+  registerCopyInitializer(bias_index, bias_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::EmbeddingLookupNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::EmbeddingLookupNode::VALUES);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &lookups_index = node.getInputs().at(model::operation::EmbeddingLookupNode::LOOKUPS);
   const auto &lookups_obj = _operands.at(lookups_index);
-  registerDefaultInitializer(lookups_index, lookups_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ExpNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ExpNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::FloorNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::FloorNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
+  registerCopyInitializer(lookups_index, lookups_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::FullyConnectedNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::FullyConnectedNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &weight_index = node.getInputs().at(model::operation::FullyConnectedNode::WEIGHT);
   const auto &weight_obj = _operands.at(weight_index);
-  registerDefaultInitializer(weight_index, weight_obj);
+  registerCopyInitializer(weight_index, weight_obj);
 
   const auto &bias_index = node.getInputs().at(model::operation::FullyConnectedNode::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerDefaultInitializer(bias_index, bias_obj);
+  registerCopyInitializer(bias_index, bias_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::GatherNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::GatherNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &indices_index = node.getInputs().at(model::operation::GatherNode::INDICES);
   const auto &indices_obj = _operands.at(indices_index);
-  registerDefaultInitializer(indices_index, indices_obj);
+  registerCopyInitializer(indices_index, indices_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::HashtableLookupNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::HashtableLookupNode::VALUES);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &lookups_index = node.getInputs().at(model::operation::HashtableLookupNode::LOOKUPS);
   const auto &lookups_obj = _operands.at(lookups_index);
-  registerDefaultInitializer(lookups_index, lookups_obj);
+  registerCopyInitializer(lookups_index, lookups_obj);
 
   const auto &keys_index = node.getInputs().at(model::operation::HashtableLookupNode::KEYS);
   const auto &keys_obj = _operands.at(keys_index);
-  registerDefaultInitializer(keys_index, keys_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::L2NormalizationNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::L2NormalizationNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::L2Pool2DNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::L2Pool2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::LocalResponseNormalizationNode &node)
-{
-  const auto &input_index =
-      node.getInputs().at(model::operation::LocalResponseNormalizationNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::LogicalAndNode &node)
-{
-  const auto &input0_index = node.getInputs().at(model::operation::LogicalAndNode::INPUT0);
-  const auto &input0_obj = _operands.at(input0_index);
-  registerPermuteInitializer(input0_index, input0_obj);
-
-  const auto &input1_index = node.getInputs().at(model::operation::LogicalAndNode::INPUT1);
-  const auto &input1_obj = _operands.at(input1_index);
-  registerPermuteInitializer(input1_index, input1_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::LogicalNotNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::LogicalNotNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::LogicalOrNode &node)
-{
-  const auto &input0_index = node.getInputs().at(model::operation::LogicalOrNode::INPUT0);
-  const auto &input0_obj = _operands.at(input0_index);
-  registerPermuteInitializer(input0_index, input0_obj);
-
-  const auto &input1_index = node.getInputs().at(model::operation::LogicalOrNode::INPUT1);
-  const auto &input1_obj = _operands.at(input1_index);
-  registerPermuteInitializer(input1_index, input1_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::LogisticNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::LogisticNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
+  registerCopyInitializer(keys_index, keys_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::LSTMNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::LSTMNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
-  const auto &output_state_in_index =
-      node.getInputs().at(model::operation::LSTMNode::OUTPUT_STATE_IN);
-  const auto &output_state_in_obj = _operands.at(output_state_in_index);
-  registerPermuteInitializer(output_state_in_index, output_state_in_obj);
-
-  const auto &cell_state_in_index = node.getInputs().at(model::operation::LSTMNode::CELL_STATE_IN);
-  const auto &cell_state_in_obj = _operands.at(cell_state_in_index);
-  registerPermuteInitializer(cell_state_in_index, cell_state_in_obj);
-
   const auto &input_to_input_weights_index =
       node.getInputs().at(model::operation::LSTMNode::INPUT_TO_INPUT_WEIGHTS);
   const auto &input_to_input_weights_obj = _operands.at(input_to_input_weights_index);
-  registerDefaultInitializer(input_to_input_weights_index, input_to_input_weights_obj);
+  registerCopyInitializer(input_to_input_weights_index, input_to_input_weights_obj);
 
   const auto &input_to_forget_weights_index =
       node.getInputs().at(model::operation::LSTMNode::INPUT_TO_FORGET_WEIGHTS);
   const auto &input_to_forget_weights_obj = _operands.at(input_to_forget_weights_index);
-  registerDefaultInitializer(input_to_forget_weights_index, input_to_forget_weights_obj);
+  registerCopyInitializer(input_to_forget_weights_index, input_to_forget_weights_obj);
 
   const auto &input_to_cell_weights_index =
       node.getInputs().at(model::operation::LSTMNode::INPUT_TO_CELL_WEIGHTS);
   const auto &input_to_cell_weights_obj = _operands.at(input_to_cell_weights_index);
-  registerDefaultInitializer(input_to_cell_weights_index, input_to_cell_weights_obj);
+  registerCopyInitializer(input_to_cell_weights_index, input_to_cell_weights_obj);
 
   const auto &input_to_output_weights_index =
       node.getInputs().at(model::operation::LSTMNode::INPUT_TO_OUTPUT_WEIGHTS);
   const auto &input_to_output_weights_obj = _operands.at(input_to_output_weights_index);
-  registerDefaultInitializer(input_to_output_weights_index, input_to_output_weights_obj);
+  registerCopyInitializer(input_to_output_weights_index, input_to_output_weights_obj);
 
   const auto &recurrent_to_input_weights_index =
       node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_INPUT_WEIGHTS);
   const auto &recurrent_to_input_weights_obj = _operands.at(recurrent_to_input_weights_index);
-  registerDefaultInitializer(recurrent_to_input_weights_index, recurrent_to_input_weights_obj);
+  registerCopyInitializer(recurrent_to_input_weights_index, recurrent_to_input_weights_obj);
 
   const auto &recurrent_to_forget_weights_index =
       node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_FORGET_WEIGHTS);
   const auto &recurrent_to_forget_weights_obj = _operands.at(recurrent_to_forget_weights_index);
-  registerDefaultInitializer(recurrent_to_forget_weights_index, recurrent_to_forget_weights_obj);
+  registerCopyInitializer(recurrent_to_forget_weights_index, recurrent_to_forget_weights_obj);
 
   const auto &recurrent_to_cell_weights_index =
       node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_CELL_WEIGHTS);
   const auto &recurrent_to_cell_weights_obj = _operands.at(recurrent_to_cell_weights_index);
-  registerDefaultInitializer(recurrent_to_cell_weights_index, recurrent_to_cell_weights_obj);
+  registerCopyInitializer(recurrent_to_cell_weights_index, recurrent_to_cell_weights_obj);
 
   const auto &recurrent_to_output_weights_index =
       node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_OUTPUT_WEIGHTS);
   const auto &recurrent_to_output_weights_obj = _operands.at(recurrent_to_output_weights_index);
-  registerDefaultInitializer(recurrent_to_output_weights_index, recurrent_to_output_weights_obj);
+  registerCopyInitializer(recurrent_to_output_weights_index, recurrent_to_output_weights_obj);
 
   const auto &cell_to_input_weights_index =
       node.getInputs().at(model::operation::LSTMNode::CELL_TO_INPUT_WEIGHTS);
   const auto &cell_to_input_weights_obj = _operands.at(cell_to_input_weights_index);
-  registerDefaultInitializer(cell_to_input_weights_index, cell_to_input_weights_obj);
+  registerCopyInitializer(cell_to_input_weights_index, cell_to_input_weights_obj);
 
   const auto &cell_to_forget_weights_index =
       node.getInputs().at(model::operation::LSTMNode::CELL_TO_FORGET_WEIGHTS);
   const auto &cell_to_forget_weights_obj = _operands.at(cell_to_forget_weights_index);
-  registerDefaultInitializer(cell_to_forget_weights_index, cell_to_forget_weights_obj);
+  registerCopyInitializer(cell_to_forget_weights_index, cell_to_forget_weights_obj);
 
   const auto &cell_to_output_weights_index =
       node.getInputs().at(model::operation::LSTMNode::CELL_TO_OUTPUT_WEIGHTS);
   const auto &cell_to_output_weights_obj = _operands.at(cell_to_output_weights_index);
-  registerDefaultInitializer(cell_to_output_weights_index, cell_to_output_weights_obj);
+  registerCopyInitializer(cell_to_output_weights_index, cell_to_output_weights_obj);
 
   const auto &input_gate_bias_index =
       node.getInputs().at(model::operation::LSTMNode::INPUT_GATE_BIAS);
   const auto &input_gate_bias_obj = _operands.at(input_gate_bias_index);
-  registerDefaultInitializer(input_gate_bias_index, input_gate_bias_obj);
+  registerCopyInitializer(input_gate_bias_index, input_gate_bias_obj);
 
   const auto &forget_gate_bias_index =
       node.getInputs().at(model::operation::LSTMNode::FORGET_GATE_BIAS);
   const auto &forget_gate_bias_obj = _operands.at(forget_gate_bias_index);
-  registerDefaultInitializer(forget_gate_bias_index, forget_gate_bias_obj);
+  registerCopyInitializer(forget_gate_bias_index, forget_gate_bias_obj);
 
   const auto &output_gate_bias_index =
       node.getInputs().at(model::operation::LSTMNode::OUTPUT_GATE_BIAS);
   const auto &output_gate_bias_obj = _operands.at(output_gate_bias_index);
-  registerDefaultInitializer(output_gate_bias_index, output_gate_bias_obj);
+  registerCopyInitializer(output_gate_bias_index, output_gate_bias_obj);
 
   const auto &projection_weights_index =
       node.getInputs().at(model::operation::LSTMNode::PROJECTION_WEIGHTS);
   const auto &projection_weights_obj = _operands.at(projection_weights_index);
-  registerDefaultInitializer(projection_weights_index, projection_weights_obj);
+  registerCopyInitializer(projection_weights_index, projection_weights_obj);
 
   const auto &projection_bias_index =
       node.getInputs().at(model::operation::LSTMNode::PROJECTION_BIAS);
   const auto &projection_bias_obj = _operands.at(projection_bias_index);
-  registerDefaultInitializer(projection_bias_index, projection_bias_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MaxPool2DNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::MaxPool2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MeanNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::MeanNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MulNode &node)
-{
-  const auto &lhs_index = node.getInputs().at(model::operation::MulNode::LHS);
-  const auto &lhs_obj = _operands.at(lhs_index);
-  registerPermuteInitializer(lhs_index, lhs_obj);
-
-  const auto &rhs_index = node.getInputs().at(model::operation::MulNode::RHS);
-  const auto &rhs_obj = _operands.at(rhs_index);
-  registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::NegNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::NegNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::PadNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::PadNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::PReLUNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::PReLUNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
-  const auto &alpha_index = node.getInputs().at(model::operation::PReLUNode::ALPHA);
-  const auto &alpha_obj = _operands.at(alpha_index);
-  registerPermuteInitializer(alpha_index, alpha_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReduceMaxNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ReduceMaxNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReduceMinNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ReduceMinNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReduceSumNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ReduceSumNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReLUNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ReLUNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReLU1Node &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ReLU1Node::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReLU6Node &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ReLU6Node::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReshapeNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ReshapeNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ResizeBilinearNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ResizeBilinearNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
+  registerCopyInitializer(projection_bias_index, projection_bias_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::RNNNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::RNNNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &weights_index = node.getInputs().at(model::operation::RNNNode::WEIGHTS);
   const auto &weights_obj = _operands.at(weights_index);
-  registerDefaultInitializer(weights_index, weights_obj);
+  registerCopyInitializer(weights_index, weights_obj);
 
   const auto &recurrent_weights_index =
       node.getInputs().at(model::operation::RNNNode::RECURRENT_WEIGHTS);
   const auto &recurrent_weights_obj = _operands.at(recurrent_weights_index);
-  registerDefaultInitializer(recurrent_weights_index, recurrent_weights_obj);
+  registerCopyInitializer(recurrent_weights_index, recurrent_weights_obj);
 
   const auto &bias_index = node.getInputs().at(model::operation::RNNNode::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerDefaultInitializer(bias_index, bias_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::RSQRTNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::RSQRTNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SoftmaxNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::SoftmaxNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SpaceToDepthNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::SpaceToDepthNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SplitNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::SplitNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SQRTNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::SQRTNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SquaredDifferenceNode &node)
-{
-  const auto &lhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::LHS);
-  const auto &lhs_obj = _operands.at(lhs_index);
-  registerPermuteInitializer(lhs_index, lhs_obj);
-
-  const auto &rhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::RHS);
-  const auto &rhs_obj = _operands.at(rhs_index);
-  registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SqueezeNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::SqueezeNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::StridedSliceNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::StridedSliceNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SubNode &node)
-{
-  const auto &lhs_index = node.getInputs().at(model::operation::SubNode::LHS);
-  const auto &lhs_obj = _operands.at(lhs_index);
-  registerPermuteInitializer(lhs_index, lhs_obj);
-
-  const auto &rhs_index = node.getInputs().at(model::operation::SubNode::RHS);
-  const auto &rhs_obj = _operands.at(rhs_index);
-  registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::TanhNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::TanhNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::TopKV2Node &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::TopKV2Node::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
+  registerCopyInitializer(bias_index, bias_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::TransposeConvNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::TransposeConvNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &kernel_index = node.getInputs().at(model::operation::TransposeConvNode::KERNEL);
   const auto &kernel_obj = _operands.at(kernel_index);
-  registerDefaultInitializer(kernel_index, kernel_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::TransposeNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::TransposeNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::UnpackNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::UnpackNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
+  registerCopyInitializer(kernel_index, kernel_obj);
 }
 
 } // namespace acl_cl
index f1d243d..59772e0 100644 (file)
@@ -38,61 +38,15 @@ public:
   void run() override;
 
 public:
-  void visit(const model::operation::AbsNode &) override;
-  void visit(const model::operation::AddNode &) override;
-  void visit(const model::operation::ArgMaxNode &) override;
-  void visit(const model::operation::AvgPool2DNode &) override;
-  void visit(const model::operation::CastNode &) override;
-  void visit(const model::operation::ComparisonNode &) override;
-  void visit(const model::operation::ConcatNode &) override;
   void visit(const model::operation::Conv2DNode &) override;
-  void visit(const model::operation::DepthToSpaceNode &) override;
   void visit(const model::operation::DepthwiseConv2DNode &) override;
-  void visit(const model::operation::DequantizeNode &) override;
-  void visit(const model::operation::DivNode &) override;
   void visit(const model::operation::EmbeddingLookupNode &) override;
-  void visit(const model::operation::ExpNode &) override;
-  void visit(const model::operation::FloorNode &) override;
   void visit(const model::operation::FullyConnectedNode &) override;
   void visit(const model::operation::GatherNode &) override;
   void visit(const model::operation::HashtableLookupNode &) override;
-  void visit(const model::operation::L2NormalizationNode &) override;
-  void visit(const model::operation::L2Pool2DNode &) override;
-  void visit(const model::operation::LocalResponseNormalizationNode &) override;
-  void visit(const model::operation::LogicalAndNode &) override;
-  void visit(const model::operation::LogicalNotNode &) override;
-  void visit(const model::operation::LogicalOrNode &) override;
-  void visit(const model::operation::LogisticNode &) override;
   void visit(const model::operation::LSTMNode &) override;
-  void visit(const model::operation::MaxPool2DNode &) override;
-  void visit(const model::operation::MeanNode &) override;
-  void visit(const model::operation::MulNode &) override;
-  void visit(const model::operation::NegNode &) override;
-  void visit(const model::operation::PadNode &) override;
-  void visit(const model::operation::PReLUNode &) override;
-  void visit(const model::operation::ReduceMaxNode &) override;
-  void visit(const model::operation::ReduceMinNode &) override;
-  void visit(const model::operation::ReduceSumNode &) override;
-  void visit(const model::operation::ReLUNode &) override;
-  void visit(const model::operation::ReLU1Node &) override;
-  void visit(const model::operation::ReLU6Node &) override;
-  void visit(const model::operation::ReshapeNode &node) override;
-  void visit(const model::operation::ResizeBilinearNode &) override;
   void visit(const model::operation::RNNNode &) override;
-  void visit(const model::operation::RSQRTNode &) override;
-  void visit(const model::operation::SoftmaxNode &node) override;
-  void visit(const model::operation::SpaceToDepthNode &) override;
-  void visit(const model::operation::SplitNode &) override;
-  void visit(const model::operation::SQRTNode &) override;
-  void visit(const model::operation::SquaredDifferenceNode &) override;
-  void visit(const model::operation::SqueezeNode &) override;
-  void visit(const model::operation::StridedSliceNode &) override;
-  void visit(const model::operation::SubNode &) override;
-  void visit(const model::operation::TanhNode &) override;
-  void visit(const model::operation::TopKV2Node &) override;
   void visit(const model::operation::TransposeConvNode &) override;
-  void visit(const model::operation::TransposeNode &) override;
-  void visit(const model::operation::UnpackNode &) override;
 
 private:
   const model::Operands &_operands;
index 4ed578f..98be80b 100644 (file)
@@ -41,155 +41,41 @@ void ConstantInitializer::run()
     auto tensor_obj = _tensor_builder->wrapTensor(ind);
     fn(model_obj, *tensor_obj);
   }
-  _init_map.clear();
-}
-
-void ConstantInitializer::visit(const model::operation::AddNode &node)
-{
-  const auto &lhs_index = node.getInputs().at(model::operation::AddNode::LHS);
-  const auto &lhs_obj = _operands.at(lhs_index);
-  registerPermuteInitializer(lhs_index, lhs_obj);
-
-  const auto &rhs_index = node.getInputs().at(model::operation::AddNode::RHS);
-  const auto &rhs_obj = _operands.at(rhs_index);
-  registerPermuteInitializer(rhs_index, rhs_obj);
-}
 
-void ConstantInitializer::visit(const model::operation::AvgPool2DNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::AvgPool2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ConcatNode &node)
-{
-  const auto inputs = node.getInputs();
-  for (const auto &input_index : inputs)
-  {
-    const auto &input_obj = _operands.at(input_index);
-    registerPermuteInitializer(input_index, input_obj);
-  }
+  _init_map.clear();
 }
 
 void ConstantInitializer::visit(const model::operation::Conv2DNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::Conv2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &kernel_index = node.getInputs().at(model::operation::Conv2DNode::KERNEL);
   const auto &kernel_obj = _operands.at(kernel_index);
   registerPermuteInitializer(kernel_index, kernel_obj);
 
   const auto &bias_index = node.getInputs().at(model::operation::Conv2DNode::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerDefaultInitializer(bias_index, bias_obj);
+  registerCopyInitializer(bias_index, bias_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::DepthwiseConv2DNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &kernel_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::KERNEL);
   const auto &kernel_obj = _operands.at(kernel_index);
   registerPermuteInitializer(kernel_index, kernel_obj);
 
   const auto &bias_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerDefaultInitializer(bias_index, bias_obj);
+  registerCopyInitializer(bias_index, bias_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::FullyConnectedNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::FullyConnectedNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &weight_index = node.getInputs().at(model::operation::FullyConnectedNode::WEIGHT);
   const auto &weight_obj = _operands.at(weight_index);
-  registerDefaultInitializer(weight_index, weight_obj);
+  registerCopyInitializer(weight_index, weight_obj);
 
   const auto &bias_index = node.getInputs().at(model::operation::FullyConnectedNode::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerDefaultInitializer(bias_index, bias_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MaxPool2DNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::MaxPool2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MeanNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::MeanNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MulNode &node)
-{
-  const auto &lhs_index = node.getInputs().at(model::operation::MulNode::LHS);
-  const auto &lhs_obj = _operands.at(lhs_index);
-  registerPermuteInitializer(lhs_index, lhs_obj);
-
-  const auto &rhs_index = node.getInputs().at(model::operation::MulNode::RHS);
-  const auto &rhs_obj = _operands.at(rhs_index);
-  registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReshapeNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ReshapeNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::RSQRTNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::RSQRTNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SoftmaxNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::SoftmaxNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SquaredDifferenceNode &node)
-{
-  const auto &lhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::LHS);
-  const auto &lhs_obj = _operands.at(lhs_index);
-  registerPermuteInitializer(lhs_index, lhs_obj);
-
-  const auto &rhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::RHS);
-  const auto &rhs_obj = _operands.at(rhs_index);
-  registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SubNode &node)
-{
-  const auto &lhs_index = node.getInputs().at(model::operation::SubNode::LHS);
-  const auto &lhs_obj = _operands.at(lhs_index);
-  registerPermuteInitializer(lhs_index, lhs_obj);
-
-  const auto &rhs_index = node.getInputs().at(model::operation::SubNode::RHS);
-  const auto &rhs_obj = _operands.at(rhs_index);
-  registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::TanhNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::TanhNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
+  registerCopyInitializer(bias_index, bias_obj);
 }
 
 } // namespace acl_neon
index 91c818f..62e889c 100644 (file)
@@ -38,21 +38,9 @@ public:
   void run() override;
 
 public:
-  void visit(const model::operation::AddNode &) override;
-  void visit(const model::operation::AvgPool2DNode &) override;
-  void visit(const model::operation::ConcatNode &) override;
   void visit(const model::operation::Conv2DNode &) override;
   void visit(const model::operation::DepthwiseConv2DNode &) override;
   void visit(const model::operation::FullyConnectedNode &) override;
-  void visit(const model::operation::MaxPool2DNode &) override;
-  void visit(const model::operation::MeanNode &) override;
-  void visit(const model::operation::MulNode &) override;
-  void visit(const model::operation::ReshapeNode &) override;
-  void visit(const model::operation::RSQRTNode &) override;
-  void visit(const model::operation::SoftmaxNode &) override;
-  void visit(const model::operation::SquaredDifferenceNode &) override;
-  void visit(const model::operation::SubNode &) override;
-  void visit(const model::operation::TanhNode &) override;
 
 private:
   const model::Operands &_operands;
index 8997933..cff9fcf 100644 (file)
@@ -41,118 +41,41 @@ void ConstantInitializer::run()
     auto tensor_obj = _tensor_builder->wrapTensor(ind);
     fn(model_obj, *tensor_obj);
   }
-  _init_map.clear();
-}
-
-void ConstantInitializer::visit(const model::operation::AddNode &node)
-{
-  const auto &lhs_index = node.getInputs().at(model::operation::AddNode::LHS);
-  const auto &lhs_obj = _operands.at(lhs_index);
-  registerPermuteInitializer(lhs_index, lhs_obj);
-
-  const auto &rhs_index = node.getInputs().at(model::operation::AddNode::RHS);
-  const auto &rhs_obj = _operands.at(rhs_index);
-  registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::AvgPool2DNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::AvgPool2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
 
-void ConstantInitializer::visit(const model::operation::ConcatNode &node)
-{
-  const auto inputs = node.getInputs();
-  for (const auto &input_index : inputs)
-  {
-    const auto &input_obj = _operands.at(input_index);
-    registerPermuteInitializer(input_index, input_obj);
-  }
+  _init_map.clear();
 }
 
 void ConstantInitializer::visit(const model::operation::Conv2DNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::Conv2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &kernel_index = node.getInputs().at(model::operation::Conv2DNode::KERNEL);
   const auto &kernel_obj = _operands.at(kernel_index);
-  registerDefaultInitializer(kernel_index, kernel_obj);
+  registerCopyInitializer(kernel_index, kernel_obj);
 
   const auto &bias_index = node.getInputs().at(model::operation::Conv2DNode::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerDefaultInitializer(bias_index, bias_obj);
+  registerCopyInitializer(bias_index, bias_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::DepthwiseConv2DNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &kernel_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::KERNEL);
   const auto &kernel_obj = _operands.at(kernel_index);
-  registerDefaultInitializer(kernel_index, kernel_obj);
+  registerCopyInitializer(kernel_index, kernel_obj);
 
   const auto &bias_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerDefaultInitializer(bias_index, bias_obj);
+  registerCopyInitializer(bias_index, bias_obj);
 }
 
 void ConstantInitializer::visit(const model::operation::FullyConnectedNode &node)
 {
-  const auto &input_index = node.getInputs().at(model::operation::FullyConnectedNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-
   const auto &weight_index = node.getInputs().at(model::operation::FullyConnectedNode::WEIGHT);
   const auto &weight_obj = _operands.at(weight_index);
-  registerDefaultInitializer(weight_index, weight_obj);
+  registerCopyInitializer(weight_index, weight_obj);
 
   const auto &bias_index = node.getInputs().at(model::operation::FullyConnectedNode::BIAS);
   const auto &bias_obj = _operands.at(bias_index);
-  registerDefaultInitializer(bias_index, bias_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MaxPool2DNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::MaxPool2DNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MulNode &node)
-{
-  const auto &lhs_index = node.getInputs().at(model::operation::MulNode::LHS);
-  const auto &lhs_obj = _operands.at(lhs_index);
-  registerPermuteInitializer(lhs_index, lhs_obj);
-
-  const auto &rhs_index = node.getInputs().at(model::operation::MulNode::RHS);
-  const auto &rhs_obj = _operands.at(rhs_index);
-  registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReshapeNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::ReshapeNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SoftmaxNode &node)
-{
-  const auto &input_index = node.getInputs().at(model::operation::SoftmaxNode::INPUT);
-  const auto &input_obj = _operands.at(input_index);
-  registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::PermuteNode &)
-{
-  // DO NOTHING
-  // This node's constant doesn't exist.
+  registerCopyInitializer(bias_index, bias_obj);
 }
 
 } // namespace cpu
index 4f9794c..91f1d50 100644 (file)
@@ -38,17 +38,9 @@ public:
   void run() override;
 
 public:
-  void visit(const model::operation::AddNode &) override;
-  void visit(const model::operation::AvgPool2DNode &) override;
-  void visit(const model::operation::ConcatNode &) override;
   void visit(const model::operation::Conv2DNode &) override;
   void visit(const model::operation::DepthwiseConv2DNode &) override;
   void visit(const model::operation::FullyConnectedNode &) override;
-  void visit(const model::operation::MaxPool2DNode &) override;
-  void visit(const model::operation::MulNode &) override;
-  void visit(const model::operation::ReshapeNode &) override;
-  void visit(const model::operation::SoftmaxNode &) override;
-  void visit(const model::operation::PermuteNode &) override;
 
 private:
   const model::Operands &_operands;
index 0545391..0c767be 100644 (file)
@@ -22,6 +22,7 @@
 
 #include "ITensorBuilder.h"
 #include "model/Operand.h"
+#include "model/Operands.h"
 #include "model/OperationVisitor.h"
 #include "model/Subgraph.h"
 #include "util/logging.h"
@@ -30,7 +31,7 @@
 namespace
 {
 template <typename T>
-void defaultInit(const neurun::model::Operand &model_obj, neurun::backend::operand::IObject &obj)
+void copyInit(const neurun::model::Operand &model_obj, neurun::backend::operand::IObject &obj)
 {
   const auto shape = model_obj.shape();
   auto base = reinterpret_cast<const T *>(model_obj.data().base());
@@ -207,19 +208,30 @@ public:
 public:
   using Initializer = std::function<void(const model::Operand &, backend::operand::IObject &)>;
 
-  void generate(const model::Subgraph &subg) { subg.accept(*this); }
+  void generate(const model::Subgraph &subg, const model::Operands &operands)
+  {
+    subg.accept(*this);
+    for (const auto &e : subg.operations())
+    {
+      for (const auto &ind : e.node->getInputs())
+      {
+        const auto &obj = operands.at(ind);
+        if (obj.isConstant() && !exist(ind))
+        {
+          registerPermuteInitializer(ind, obj);
+        }
+      }
+    }
+  }
 
 protected:
-#define OP(InternalName, IsNnApi)                                     \
-  virtual void visit(const model::operation::InternalName &) override \
-  {                                                                   \
-    throw std::runtime_error("NYI");                                  \
-  }
+#define OP(InternalName, IsNnApi) \
+  virtual void visit(const model::operation::InternalName &) override { /* DO NOTHING */}
 #include "model/Operations.lst"
 #undef OP
 
 protected:
-  void registerDefaultInitializer(const model::OperandIndex &index, const model::Operand &obj)
+  void registerCopyInitializer(const model::OperandIndex &index, const model::Operand &obj)
   {
     // For only CONSTANTS
     if (!obj.isConstant())
@@ -233,17 +245,17 @@ protected:
     switch (type)
     {
       case DataType::FLOAT32:
-        _init_map[index] = defaultInit<float>;
+        _init_map[index] = copyInit<float>;
         break;
       case DataType::INT32:
-        _init_map[index] = defaultInit<int32_t>;
+        _init_map[index] = copyInit<int32_t>;
         break;
       case DataType::UINT32:
-        _init_map[index] = defaultInit<uint32_t>;
+        _init_map[index] = copyInit<uint32_t>;
         break;
       case DataType::BOOL8:
       case DataType::QUANT8_ASYMM:
-        _init_map[index] = defaultInit<uint8_t>;
+        _init_map[index] = copyInit<uint8_t>;
         break;
       default:
         throw std::runtime_error("Not supported, yet");
@@ -284,6 +296,9 @@ protected:
     }
   }
 
+private:
+  bool exist(const model::OperandIndex &ind) { return _init_map.find(ind) != _init_map.end(); }
+
 protected:
   std::unordered_map<model::OperandIndex, Initializer> _init_map;
 };
index 4328c00..efe3fd3 100644 (file)
@@ -226,7 +226,7 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor(graph::Graph &graph, bo
       [&](const model::SubgraphIndex &subg_index, const model::Subgraph &subg) {
         auto backend = graph.getLowerInfo(subg_index)->backend();
         auto constant_initializer = backend->constant_initializer();
-        constant_initializer->generate(subg);
+        constant_initializer->generate(subg, graph.operands());
         // TODO This approach is temporal. See declaration of `setNextIndex`.
         execution_builder->setNextIndex(subg_index);
         auto kernel_gen = backend->kernel_gen();
index fbef2cc..9895c86 100644 (file)
@@ -18,7 +18,6 @@
 
 #include "backend/operand/IObject.h"
 #include "backend/Backend.h"
-#include "backend/IConstantInitializer.h"
 #include "backend/IKernelGenerator.h"
 #include "linear/Linear.h"
 
@@ -42,12 +41,13 @@ void PlanBuilder::finalize(const linear::Linear *linear,
     });
   }
 
+  // Generate initializers
+  linear->generateConstantInitializers();
+
   // Generate kernels
   auto execution_builder = nnfw::cpp14::make_unique<ExecutionBuilder>(_functions);
   linear->iterate([&](const linear::Element &element) {
     auto backend = element.lower_info->backend();
-    auto constant_initializer = backend->constant_initializer();
-    constant_initializer->generate(*element.subgraph);
     auto kernel_gen = backend->kernel_gen();
     kernel_gen->generate(*element.subgraph, execution_builder.get());
   });
index 94d1089..0e28d74 100644 (file)
@@ -22,6 +22,7 @@
 #include "graph/operand/LowerInfo.h"
 #include "backend/IShapeFixer.h"
 #include "backend/IConfig.h"
+#include "backend/IConstantInitializer.h"
 #include "backend/Backend.h"
 #include "compiler/SubTensorInfo.h"
 #include "model/OperandInfo.h"
@@ -309,6 +310,15 @@ void Linear::iterate(const std::function<void(const Element &element)> &fn) cons
   }
 }
 
+void Linear::generateConstantInitializers(void) const
+{
+  iterate([&](const linear::Element &element) {
+    auto backend = element.lower_info->backend();
+    auto constant_initializer = backend->constant_initializer();
+    constant_initializer->generate(*element.subgraph, _model->operands);
+  });
+}
+
 const graph::operation::LowerInfo *Linear::getLowerInfo(const model::SubgraphIndex &index) const
 {
   if (!_lower_info_map)
index ed82c48..45c8489 100644 (file)
@@ -60,6 +60,8 @@ public:
 
   void iterate(const std::function<void(const Element &element)> &fn) const;
 
+  void generateConstantInitializers(void) const;
+
   std::unique_ptr<graph::LowerInfoMap> releaseLowerInfo() { return std::move(_lower_info_map); }
   graph::LowerInfoMap *getLowerInfo() { return _lower_info_map.get(); }