auto tensor_obj = _tensor_builder->wrapTensor(ind);
fn(model_obj, *tensor_obj);
}
- _init_map.clear();
-}
-
-void ConstantInitializer::visit(const model::operation::AbsNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::AbsNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-void ConstantInitializer::visit(const model::operation::AddNode &node)
-{
- const auto &lhs_index = node.getInputs().at(model::operation::AddNode::LHS);
- const auto &lhs_obj = _operands.at(lhs_index);
- registerPermuteInitializer(lhs_index, lhs_obj);
-
- const auto &rhs_index = node.getInputs().at(model::operation::AddNode::RHS);
- const auto &rhs_obj = _operands.at(rhs_index);
- registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ArgMaxNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::ArgMaxNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::AvgPool2DNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::AvgPool2DNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::CastNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::CastNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ComparisonNode &node)
-{
- const auto &input0_index = node.getInputs().at(model::operation::ComparisonNode::INPUT0);
- const auto &input0_obj = _operands.at(input0_index);
- registerPermuteInitializer(input0_index, input0_obj);
-
- const auto &input1_index = node.getInputs().at(model::operation::ComparisonNode::INPUT1);
- const auto &input1_obj = _operands.at(input1_index);
- registerPermuteInitializer(input1_index, input1_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ConcatNode &node)
-{
- const auto inputs = node.getInputs();
- for (const auto &input_index : inputs)
- {
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
- }
+ _init_map.clear();
}
void ConstantInitializer::visit(const model::operation::Conv2DNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::Conv2DNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
const auto &kernel_index = node.getInputs().at(model::operation::Conv2DNode::KERNEL);
const auto &kernel_obj = _operands.at(kernel_index);
registerPermuteInitializer(kernel_index, kernel_obj);
const auto &bias_index = node.getInputs().at(model::operation::Conv2DNode::BIAS);
const auto &bias_obj = _operands.at(bias_index);
- registerDefaultInitializer(bias_index, bias_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::DepthToSpaceNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::DepthToSpaceNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
+ registerCopyInitializer(bias_index, bias_obj);
}
void ConstantInitializer::visit(const model::operation::DepthwiseConv2DNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
const auto &kernel_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::KERNEL);
const auto &kernel_obj = _operands.at(kernel_index);
registerPermuteInitializer(kernel_index, kernel_obj);
const auto &bias_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::BIAS);
const auto &bias_obj = _operands.at(bias_index);
- registerDefaultInitializer(bias_index, bias_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::DequantizeNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::DequantizeNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::DivNode &node)
-{
- const auto &lhs_index = node.getInputs().at(model::operation::DivNode::LHS);
- const auto &lhs_obj = _operands.at(lhs_index);
- registerPermuteInitializer(lhs_index, lhs_obj);
-
- const auto &rhs_index = node.getInputs().at(model::operation::DivNode::RHS);
- const auto &rhs_obj = _operands.at(rhs_index);
- registerPermuteInitializer(rhs_index, rhs_obj);
+ registerCopyInitializer(bias_index, bias_obj);
}
void ConstantInitializer::visit(const model::operation::EmbeddingLookupNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::EmbeddingLookupNode::VALUES);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
const auto &lookups_index = node.getInputs().at(model::operation::EmbeddingLookupNode::LOOKUPS);
const auto &lookups_obj = _operands.at(lookups_index);
- registerDefaultInitializer(lookups_index, lookups_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ExpNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::ExpNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::FloorNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::FloorNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
+ registerCopyInitializer(lookups_index, lookups_obj);
}
void ConstantInitializer::visit(const model::operation::FullyConnectedNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::FullyConnectedNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
const auto &weight_index = node.getInputs().at(model::operation::FullyConnectedNode::WEIGHT);
const auto &weight_obj = _operands.at(weight_index);
- registerDefaultInitializer(weight_index, weight_obj);
+ registerCopyInitializer(weight_index, weight_obj);
const auto &bias_index = node.getInputs().at(model::operation::FullyConnectedNode::BIAS);
const auto &bias_obj = _operands.at(bias_index);
- registerDefaultInitializer(bias_index, bias_obj);
+ registerCopyInitializer(bias_index, bias_obj);
}
void ConstantInitializer::visit(const model::operation::GatherNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::GatherNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
const auto &indices_index = node.getInputs().at(model::operation::GatherNode::INDICES);
const auto &indices_obj = _operands.at(indices_index);
- registerDefaultInitializer(indices_index, indices_obj);
+ registerCopyInitializer(indices_index, indices_obj);
}
void ConstantInitializer::visit(const model::operation::HashtableLookupNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::HashtableLookupNode::VALUES);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
const auto &lookups_index = node.getInputs().at(model::operation::HashtableLookupNode::LOOKUPS);
const auto &lookups_obj = _operands.at(lookups_index);
- registerDefaultInitializer(lookups_index, lookups_obj);
+ registerCopyInitializer(lookups_index, lookups_obj);
const auto &keys_index = node.getInputs().at(model::operation::HashtableLookupNode::KEYS);
const auto &keys_obj = _operands.at(keys_index);
- registerDefaultInitializer(keys_index, keys_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::L2NormalizationNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::L2NormalizationNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::L2Pool2DNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::L2Pool2DNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::LocalResponseNormalizationNode &node)
-{
- const auto &input_index =
- node.getInputs().at(model::operation::LocalResponseNormalizationNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::LogicalAndNode &node)
-{
- const auto &input0_index = node.getInputs().at(model::operation::LogicalAndNode::INPUT0);
- const auto &input0_obj = _operands.at(input0_index);
- registerPermuteInitializer(input0_index, input0_obj);
-
- const auto &input1_index = node.getInputs().at(model::operation::LogicalAndNode::INPUT1);
- const auto &input1_obj = _operands.at(input1_index);
- registerPermuteInitializer(input1_index, input1_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::LogicalNotNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::LogicalNotNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::LogicalOrNode &node)
-{
- const auto &input0_index = node.getInputs().at(model::operation::LogicalOrNode::INPUT0);
- const auto &input0_obj = _operands.at(input0_index);
- registerPermuteInitializer(input0_index, input0_obj);
-
- const auto &input1_index = node.getInputs().at(model::operation::LogicalOrNode::INPUT1);
- const auto &input1_obj = _operands.at(input1_index);
- registerPermuteInitializer(input1_index, input1_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::LogisticNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::LogisticNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
+ registerCopyInitializer(keys_index, keys_obj);
}
void ConstantInitializer::visit(const model::operation::LSTMNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::LSTMNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
- const auto &output_state_in_index =
- node.getInputs().at(model::operation::LSTMNode::OUTPUT_STATE_IN);
- const auto &output_state_in_obj = _operands.at(output_state_in_index);
- registerPermuteInitializer(output_state_in_index, output_state_in_obj);
-
- const auto &cell_state_in_index = node.getInputs().at(model::operation::LSTMNode::CELL_STATE_IN);
- const auto &cell_state_in_obj = _operands.at(cell_state_in_index);
- registerPermuteInitializer(cell_state_in_index, cell_state_in_obj);
-
const auto &input_to_input_weights_index =
node.getInputs().at(model::operation::LSTMNode::INPUT_TO_INPUT_WEIGHTS);
const auto &input_to_input_weights_obj = _operands.at(input_to_input_weights_index);
- registerDefaultInitializer(input_to_input_weights_index, input_to_input_weights_obj);
+ registerCopyInitializer(input_to_input_weights_index, input_to_input_weights_obj);
const auto &input_to_forget_weights_index =
node.getInputs().at(model::operation::LSTMNode::INPUT_TO_FORGET_WEIGHTS);
const auto &input_to_forget_weights_obj = _operands.at(input_to_forget_weights_index);
- registerDefaultInitializer(input_to_forget_weights_index, input_to_forget_weights_obj);
+ registerCopyInitializer(input_to_forget_weights_index, input_to_forget_weights_obj);
const auto &input_to_cell_weights_index =
node.getInputs().at(model::operation::LSTMNode::INPUT_TO_CELL_WEIGHTS);
const auto &input_to_cell_weights_obj = _operands.at(input_to_cell_weights_index);
- registerDefaultInitializer(input_to_cell_weights_index, input_to_cell_weights_obj);
+ registerCopyInitializer(input_to_cell_weights_index, input_to_cell_weights_obj);
const auto &input_to_output_weights_index =
node.getInputs().at(model::operation::LSTMNode::INPUT_TO_OUTPUT_WEIGHTS);
const auto &input_to_output_weights_obj = _operands.at(input_to_output_weights_index);
- registerDefaultInitializer(input_to_output_weights_index, input_to_output_weights_obj);
+ registerCopyInitializer(input_to_output_weights_index, input_to_output_weights_obj);
const auto &recurrent_to_input_weights_index =
node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_INPUT_WEIGHTS);
const auto &recurrent_to_input_weights_obj = _operands.at(recurrent_to_input_weights_index);
- registerDefaultInitializer(recurrent_to_input_weights_index, recurrent_to_input_weights_obj);
+ registerCopyInitializer(recurrent_to_input_weights_index, recurrent_to_input_weights_obj);
const auto &recurrent_to_forget_weights_index =
node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_FORGET_WEIGHTS);
const auto &recurrent_to_forget_weights_obj = _operands.at(recurrent_to_forget_weights_index);
- registerDefaultInitializer(recurrent_to_forget_weights_index, recurrent_to_forget_weights_obj);
+ registerCopyInitializer(recurrent_to_forget_weights_index, recurrent_to_forget_weights_obj);
const auto &recurrent_to_cell_weights_index =
node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_CELL_WEIGHTS);
const auto &recurrent_to_cell_weights_obj = _operands.at(recurrent_to_cell_weights_index);
- registerDefaultInitializer(recurrent_to_cell_weights_index, recurrent_to_cell_weights_obj);
+ registerCopyInitializer(recurrent_to_cell_weights_index, recurrent_to_cell_weights_obj);
const auto &recurrent_to_output_weights_index =
node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_OUTPUT_WEIGHTS);
const auto &recurrent_to_output_weights_obj = _operands.at(recurrent_to_output_weights_index);
- registerDefaultInitializer(recurrent_to_output_weights_index, recurrent_to_output_weights_obj);
+ registerCopyInitializer(recurrent_to_output_weights_index, recurrent_to_output_weights_obj);
const auto &cell_to_input_weights_index =
node.getInputs().at(model::operation::LSTMNode::CELL_TO_INPUT_WEIGHTS);
const auto &cell_to_input_weights_obj = _operands.at(cell_to_input_weights_index);
- registerDefaultInitializer(cell_to_input_weights_index, cell_to_input_weights_obj);
+ registerCopyInitializer(cell_to_input_weights_index, cell_to_input_weights_obj);
const auto &cell_to_forget_weights_index =
node.getInputs().at(model::operation::LSTMNode::CELL_TO_FORGET_WEIGHTS);
const auto &cell_to_forget_weights_obj = _operands.at(cell_to_forget_weights_index);
- registerDefaultInitializer(cell_to_forget_weights_index, cell_to_forget_weights_obj);
+ registerCopyInitializer(cell_to_forget_weights_index, cell_to_forget_weights_obj);
const auto &cell_to_output_weights_index =
node.getInputs().at(model::operation::LSTMNode::CELL_TO_OUTPUT_WEIGHTS);
const auto &cell_to_output_weights_obj = _operands.at(cell_to_output_weights_index);
- registerDefaultInitializer(cell_to_output_weights_index, cell_to_output_weights_obj);
+ registerCopyInitializer(cell_to_output_weights_index, cell_to_output_weights_obj);
const auto &input_gate_bias_index =
node.getInputs().at(model::operation::LSTMNode::INPUT_GATE_BIAS);
const auto &input_gate_bias_obj = _operands.at(input_gate_bias_index);
- registerDefaultInitializer(input_gate_bias_index, input_gate_bias_obj);
+ registerCopyInitializer(input_gate_bias_index, input_gate_bias_obj);
const auto &forget_gate_bias_index =
node.getInputs().at(model::operation::LSTMNode::FORGET_GATE_BIAS);
const auto &forget_gate_bias_obj = _operands.at(forget_gate_bias_index);
- registerDefaultInitializer(forget_gate_bias_index, forget_gate_bias_obj);
+ registerCopyInitializer(forget_gate_bias_index, forget_gate_bias_obj);
const auto &output_gate_bias_index =
node.getInputs().at(model::operation::LSTMNode::OUTPUT_GATE_BIAS);
const auto &output_gate_bias_obj = _operands.at(output_gate_bias_index);
- registerDefaultInitializer(output_gate_bias_index, output_gate_bias_obj);
+ registerCopyInitializer(output_gate_bias_index, output_gate_bias_obj);
const auto &projection_weights_index =
node.getInputs().at(model::operation::LSTMNode::PROJECTION_WEIGHTS);
const auto &projection_weights_obj = _operands.at(projection_weights_index);
- registerDefaultInitializer(projection_weights_index, projection_weights_obj);
+ registerCopyInitializer(projection_weights_index, projection_weights_obj);
const auto &projection_bias_index =
node.getInputs().at(model::operation::LSTMNode::PROJECTION_BIAS);
const auto &projection_bias_obj = _operands.at(projection_bias_index);
- registerDefaultInitializer(projection_bias_index, projection_bias_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MaxPool2DNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::MaxPool2DNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MeanNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::MeanNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MulNode &node)
-{
- const auto &lhs_index = node.getInputs().at(model::operation::MulNode::LHS);
- const auto &lhs_obj = _operands.at(lhs_index);
- registerPermuteInitializer(lhs_index, lhs_obj);
-
- const auto &rhs_index = node.getInputs().at(model::operation::MulNode::RHS);
- const auto &rhs_obj = _operands.at(rhs_index);
- registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::NegNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::NegNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::PadNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::PadNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::PReLUNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::PReLUNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
- const auto &alpha_index = node.getInputs().at(model::operation::PReLUNode::ALPHA);
- const auto &alpha_obj = _operands.at(alpha_index);
- registerPermuteInitializer(alpha_index, alpha_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReduceMaxNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::ReduceMaxNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReduceMinNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::ReduceMinNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReduceSumNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::ReduceSumNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReLUNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::ReLUNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReLU1Node &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::ReLU1Node::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReLU6Node &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::ReLU6Node::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReshapeNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::ReshapeNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ResizeBilinearNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::ResizeBilinearNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
+ registerCopyInitializer(projection_bias_index, projection_bias_obj);
}
void ConstantInitializer::visit(const model::operation::RNNNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::RNNNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
const auto &weights_index = node.getInputs().at(model::operation::RNNNode::WEIGHTS);
const auto &weights_obj = _operands.at(weights_index);
- registerDefaultInitializer(weights_index, weights_obj);
+ registerCopyInitializer(weights_index, weights_obj);
const auto &recurrent_weights_index =
node.getInputs().at(model::operation::RNNNode::RECURRENT_WEIGHTS);
const auto &recurrent_weights_obj = _operands.at(recurrent_weights_index);
- registerDefaultInitializer(recurrent_weights_index, recurrent_weights_obj);
+ registerCopyInitializer(recurrent_weights_index, recurrent_weights_obj);
const auto &bias_index = node.getInputs().at(model::operation::RNNNode::BIAS);
const auto &bias_obj = _operands.at(bias_index);
- registerDefaultInitializer(bias_index, bias_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::RSQRTNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::RSQRTNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SoftmaxNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::SoftmaxNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SpaceToDepthNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::SpaceToDepthNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SplitNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::SplitNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SQRTNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::SQRTNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SquaredDifferenceNode &node)
-{
- const auto &lhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::LHS);
- const auto &lhs_obj = _operands.at(lhs_index);
- registerPermuteInitializer(lhs_index, lhs_obj);
-
- const auto &rhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::RHS);
- const auto &rhs_obj = _operands.at(rhs_index);
- registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SqueezeNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::SqueezeNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::StridedSliceNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::StridedSliceNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SubNode &node)
-{
- const auto &lhs_index = node.getInputs().at(model::operation::SubNode::LHS);
- const auto &lhs_obj = _operands.at(lhs_index);
- registerPermuteInitializer(lhs_index, lhs_obj);
-
- const auto &rhs_index = node.getInputs().at(model::operation::SubNode::RHS);
- const auto &rhs_obj = _operands.at(rhs_index);
- registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::TanhNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::TanhNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::TopKV2Node &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::TopKV2Node::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
+ registerCopyInitializer(bias_index, bias_obj);
}
void ConstantInitializer::visit(const model::operation::TransposeConvNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::TransposeConvNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
const auto &kernel_index = node.getInputs().at(model::operation::TransposeConvNode::KERNEL);
const auto &kernel_obj = _operands.at(kernel_index);
- registerDefaultInitializer(kernel_index, kernel_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::TransposeNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::TransposeNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::UnpackNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::UnpackNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
+ registerCopyInitializer(kernel_index, kernel_obj);
}
} // namespace acl_cl
auto tensor_obj = _tensor_builder->wrapTensor(ind);
fn(model_obj, *tensor_obj);
}
- _init_map.clear();
-}
-
-void ConstantInitializer::visit(const model::operation::AddNode &node)
-{
- const auto &lhs_index = node.getInputs().at(model::operation::AddNode::LHS);
- const auto &lhs_obj = _operands.at(lhs_index);
- registerPermuteInitializer(lhs_index, lhs_obj);
-
- const auto &rhs_index = node.getInputs().at(model::operation::AddNode::RHS);
- const auto &rhs_obj = _operands.at(rhs_index);
- registerPermuteInitializer(rhs_index, rhs_obj);
-}
-void ConstantInitializer::visit(const model::operation::AvgPool2DNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::AvgPool2DNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ConcatNode &node)
-{
- const auto inputs = node.getInputs();
- for (const auto &input_index : inputs)
- {
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
- }
+ _init_map.clear();
}
void ConstantInitializer::visit(const model::operation::Conv2DNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::Conv2DNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
const auto &kernel_index = node.getInputs().at(model::operation::Conv2DNode::KERNEL);
const auto &kernel_obj = _operands.at(kernel_index);
registerPermuteInitializer(kernel_index, kernel_obj);
const auto &bias_index = node.getInputs().at(model::operation::Conv2DNode::BIAS);
const auto &bias_obj = _operands.at(bias_index);
- registerDefaultInitializer(bias_index, bias_obj);
+ registerCopyInitializer(bias_index, bias_obj);
}
void ConstantInitializer::visit(const model::operation::DepthwiseConv2DNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
const auto &kernel_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::KERNEL);
const auto &kernel_obj = _operands.at(kernel_index);
registerPermuteInitializer(kernel_index, kernel_obj);
const auto &bias_index = node.getInputs().at(model::operation::DepthwiseConv2DNode::BIAS);
const auto &bias_obj = _operands.at(bias_index);
- registerDefaultInitializer(bias_index, bias_obj);
+ registerCopyInitializer(bias_index, bias_obj);
}
void ConstantInitializer::visit(const model::operation::FullyConnectedNode &node)
{
- const auto &input_index = node.getInputs().at(model::operation::FullyConnectedNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-
const auto &weight_index = node.getInputs().at(model::operation::FullyConnectedNode::WEIGHT);
const auto &weight_obj = _operands.at(weight_index);
- registerDefaultInitializer(weight_index, weight_obj);
+ registerCopyInitializer(weight_index, weight_obj);
const auto &bias_index = node.getInputs().at(model::operation::FullyConnectedNode::BIAS);
const auto &bias_obj = _operands.at(bias_index);
- registerDefaultInitializer(bias_index, bias_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MaxPool2DNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::MaxPool2DNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MeanNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::MeanNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::MulNode &node)
-{
- const auto &lhs_index = node.getInputs().at(model::operation::MulNode::LHS);
- const auto &lhs_obj = _operands.at(lhs_index);
- registerPermuteInitializer(lhs_index, lhs_obj);
-
- const auto &rhs_index = node.getInputs().at(model::operation::MulNode::RHS);
- const auto &rhs_obj = _operands.at(rhs_index);
- registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::ReshapeNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::ReshapeNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::RSQRTNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::RSQRTNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SoftmaxNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::SoftmaxNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SquaredDifferenceNode &node)
-{
- const auto &lhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::LHS);
- const auto &lhs_obj = _operands.at(lhs_index);
- registerPermuteInitializer(lhs_index, lhs_obj);
-
- const auto &rhs_index = node.getInputs().at(model::operation::SquaredDifferenceNode::RHS);
- const auto &rhs_obj = _operands.at(rhs_index);
- registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::SubNode &node)
-{
- const auto &lhs_index = node.getInputs().at(model::operation::SubNode::LHS);
- const auto &lhs_obj = _operands.at(lhs_index);
- registerPermuteInitializer(lhs_index, lhs_obj);
-
- const auto &rhs_index = node.getInputs().at(model::operation::SubNode::RHS);
- const auto &rhs_obj = _operands.at(rhs_index);
- registerPermuteInitializer(rhs_index, rhs_obj);
-}
-
-void ConstantInitializer::visit(const model::operation::TanhNode &node)
-{
- const auto &input_index = node.getInputs().at(model::operation::TanhNode::INPUT);
- const auto &input_obj = _operands.at(input_index);
- registerPermuteInitializer(input_index, input_obj);
+ registerCopyInitializer(bias_index, bias_obj);
}
} // namespace acl_neon