2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include "NeonBatchNormalizationFloatWorkload.hpp"
7 #include <backendsCommon/CpuTensorHandle.hpp>
8 #include <aclCommon/ArmComputeTensorUtils.hpp>
9 #include <armnn/ArmNN.hpp>
13 using namespace armcomputetensorutils;
// Validates whether the Neon (Arm Compute Library) backend can execute a batch
// normalization with the given tensor shapes/types and descriptor, without
// actually configuring a layer. Each armnn::TensorInfo is converted to an
// arm_compute::TensorInfo and the decision is delegated to
// arm_compute::NEBatchNormalizationLayer::validate().
// Returns the arm_compute::Status produced by that validate() call.
16 arm_compute::Status NeonBatchNormalizationValidate(const TensorInfo& input,
17 const TensorInfo& output,
18 const TensorInfo& mean,
19 const TensorInfo& var,
20 const TensorInfo& beta,
21 const TensorInfo& gamma,
22 const BatchNormalizationDescriptor& descriptor)
// The descriptor's data layout (e.g. NCHW/NHWC) is forwarded into every
// converted TensorInfo so ACL validates against the correct layout.
24 const DataLayout dataLayout = descriptor.m_DataLayout.GetDataLayout();
// Convert every armnn TensorInfo into its Arm Compute Library equivalent.
26 const arm_compute::TensorInfo aclInputInfo =
27 armcomputetensorutils::BuildArmComputeTensorInfo(input, dataLayout);
28 const arm_compute::TensorInfo aclOutputInfo =
29 armcomputetensorutils::BuildArmComputeTensorInfo(output, dataLayout);
30 const arm_compute::TensorInfo aclMeanInfo =
31 armcomputetensorutils::BuildArmComputeTensorInfo(mean, dataLayout);
32 const arm_compute::TensorInfo aclVarInfo =
33 armcomputetensorutils::BuildArmComputeTensorInfo(var, dataLayout);
34 const arm_compute::TensorInfo aclBetaInfo =
35 armcomputetensorutils::BuildArmComputeTensorInfo(beta, dataLayout);
36 const arm_compute::TensorInfo aclGammaInfo =
37 armcomputetensorutils::BuildArmComputeTensorInfo(gamma, dataLayout);
// Delegate the actual support decision to ACL's own validation routine.
39 return arm_compute::NEBatchNormalizationLayer::validate(&aclInputInfo,
// Constructs the Neon float batch-normalization workload: wires the single
// input/output tensor handles into an arm_compute::NEBatchNormalizationLayer,
// allocates ACL-side tensors for the mean/variance/gamma/beta parameters and
// copies the constant parameter data into them.
48 NeonBatchNormalizationFloatWorkload::NeonBatchNormalizationFloatWorkload(
49 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info)
50 : FloatWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
// This workload expects exactly one input and one output tensor.
52 m_Data.ValidateInputsOutputs("NeonBatchNormalizationFloatWorkload", 1, 1);
// Unwrap the backend tensor handles to get at the underlying ACL ITensors.
54 arm_compute::ITensor& input = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
55 arm_compute::ITensor& output = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
// Propagate the requested data layout (e.g. NCHW/NHWC) onto the ACL tensors.
57 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout.GetDataLayout());
58 input.info()->set_data_layout(aclDataLayout);
59 output.info()->set_data_layout(aclDataLayout);
// Create ACL tensors sized from the constant parameter tensors held in the
// queue descriptor. They are filled with data further below, after configure().
61 m_Mean = std::make_unique<arm_compute::Tensor>();
62 BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
64 m_Variance = std::make_unique<arm_compute::Tensor>();
65 BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
67 m_Gamma = std::make_unique<arm_compute::Tensor>();
68 BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
70 m_Beta = std::make_unique<arm_compute::Tensor>();
71 BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
// Configure the ACL layer with the input/output tensors, the parameter
// tensors and the epsilon value from the descriptor.
73 m_Layer.configure(&input,
79 m_Data.m_Parameters.m_Eps);
// Copy the constant mean/variance/gamma/beta data into the ACL tensors.
81 InitializeArmComputeTensorData(*m_Mean, m_Data.m_Mean);
82 InitializeArmComputeTensorData(*m_Variance, m_Data.m_Variance);
83 InitializeArmComputeTensorData(*m_Gamma, m_Data.m_Gamma);
84 InitializeArmComputeTensorData(*m_Beta, m_Data.m_Beta);
86 // Force Compute Library to perform the necessary copying and reshaping, after which
87 // delete all the input tensors that will no longer be needed
// Runs the configured batch-normalization layer, wrapped in a profiling scope
// so execution time shows up under the Neon backend's profiling events.
92 void NeonBatchNormalizationFloatWorkload::Execute() const
94 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonBatchNormalizationFloatWorkload_Execute");
// Releases the constant parameter tensors once the Compute Library no longer
// references their memory; FreeTensorIfUnused only frees a tensor whose
// buffer is not still in use by the configured layer.
98 void NeonBatchNormalizationFloatWorkload::FreeUnusedTensors()
100 FreeTensorIfUnused(m_Mean);
101 FreeTensorIfUnused(m_Variance);
102 FreeTensorIfUnused(m_Gamma);
103 FreeTensorIfUnused(m_Beta);