Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / ClWorkloads / ClBatchNormalizationFloat32Workload.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include "ClBatchNormalizationFloat32Workload.hpp"
7 #include "backends/ClTensorHandle.hpp"
8 #include "backends/CpuTensorHandle.hpp"
9 #include "backends/ArmComputeTensorUtils.hpp"
10 #include "backends/ClLayerSupport.hpp"
11
12 namespace armnn
13 {
14 using namespace armcomputetensorutils;
15
16 arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
17                                                  const TensorInfo& output,
18                                                  const TensorInfo& mean,
19                                                  const TensorInfo& var,
20                                                  const TensorInfo& beta,
21                                                  const TensorInfo& gamma,
22                                                  const BatchNormalizationDescriptor &desc)
23 {
24     const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
25     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
26     const arm_compute::TensorInfo aclMeanInfo = BuildArmComputeTensorInfo(mean);
27     const arm_compute::TensorInfo aclVarInfo = BuildArmComputeTensorInfo(var);
28     const arm_compute::TensorInfo aclBetaInfo = BuildArmComputeTensorInfo(beta);
29     const arm_compute::TensorInfo aclGammaInfo = BuildArmComputeTensorInfo(gamma);
30
31     return arm_compute::CLBatchNormalizationLayer::validate(&aclInputInfo,
32                                                             &aclOutputInfo,
33                                                             &aclMeanInfo,
34                                                             &aclVarInfo,
35                                                             &aclBetaInfo,
36                                                             &aclGammaInfo,
37                                                             desc.m_Eps);
38 }
39
40 ClBatchNormalizationFloat32Workload::ClBatchNormalizationFloat32Workload(
41     const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info)
42     : FloatWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
43 {
44     m_Mean = std::make_unique<arm_compute::CLTensor>();
45     BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
46
47     m_Variance = std::make_unique<arm_compute::CLTensor>();
48     BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
49
50     m_Gamma = std::make_unique<arm_compute::CLTensor>();
51     BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
52
53     m_Beta = std::make_unique<arm_compute::CLTensor>();
54     BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
55
56     m_Data.ValidateInputsOutputs("ClBatchNormalizationFloat32Workload", 1, 1);
57
58     arm_compute::ICLTensor& input  = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
59     arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
60
61     m_Layer.configure(&input,
62                       &output,
63                       m_Mean.get(),
64                       m_Variance.get(),
65                       m_Beta.get(),
66                       m_Gamma.get(),
67                       m_Data.m_Parameters.m_Eps);
68
69     InitializeArmComputeClTensorDataForFloatTypes(*m_Mean, m_Data.m_Mean);
70     InitializeArmComputeClTensorDataForFloatTypes(*m_Variance, m_Data.m_Variance);
71     InitializeArmComputeClTensorDataForFloatTypes(*m_Beta, m_Data.m_Beta);
72     InitializeArmComputeClTensorDataForFloatTypes(*m_Gamma, m_Data.m_Gamma);
73
74     // Force Compute Library to perform the necessary copying and reshaping, after which
75     // delete all the input tensors that will no longer be needed
76     m_Layer.prepare();
77     FreeUnusedTensors();
78 }
79
80 void ClBatchNormalizationFloat32Workload::Execute() const
81 {
82     ARMNN_SCOPED_PROFILING_EVENT_CL("ClBatchNormalizationFloat32Workload_Execute");
83     m_Layer.run();
84 }
85
86 void ClBatchNormalizationFloat32Workload::FreeUnusedTensors()
87 {
88     FreeTensorIfUnused(m_Mean);
89     FreeTensorIfUnused(m_Variance);
90     FreeTensorIfUnused(m_Gamma);
91     FreeTensorIfUnused(m_Beta);
92 }
93
94 } //namespace armnn