IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / cl / workloads / ClBatchNormalizationFloatWorkload.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ClBatchNormalizationFloatWorkload.hpp"
7 #include <cl/ClTensorHandle.hpp>
8 #include <backendsCommon/CpuTensorHandle.hpp>
9 #include <aclCommon/ArmComputeTensorUtils.hpp>
10 #include <cl/ClLayerSupport.hpp>
11
12 #include "ClWorkloadUtils.hpp"
13
14 namespace armnn
15 {
16 using namespace armcomputetensorutils;
17
18 arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
19                                                  const TensorInfo& output,
20                                                  const TensorInfo& mean,
21                                                  const TensorInfo& var,
22                                                  const TensorInfo& beta,
23                                                  const TensorInfo& gamma,
24                                                  const BatchNormalizationDescriptor &desc)
25 {
26     const DataLayout dataLayout = desc.m_DataLayout.GetDataLayout();
27
28     const arm_compute::TensorInfo aclInputInfo =
29           armcomputetensorutils::BuildArmComputeTensorInfo(input, dataLayout);
30     const arm_compute::TensorInfo aclOutputInfo =
31           armcomputetensorutils::BuildArmComputeTensorInfo(output, dataLayout);
32     const arm_compute::TensorInfo aclMeanInfo =
33           armcomputetensorutils::BuildArmComputeTensorInfo(mean, dataLayout);
34     const arm_compute::TensorInfo aclVarInfo =
35           armcomputetensorutils::BuildArmComputeTensorInfo(var, dataLayout);
36     const arm_compute::TensorInfo aclBetaInfo =
37           armcomputetensorutils::BuildArmComputeTensorInfo(beta, dataLayout);
38     const arm_compute::TensorInfo aclGammaInfo =
39           armcomputetensorutils::BuildArmComputeTensorInfo(gamma, dataLayout);
40
41     return arm_compute::CLBatchNormalizationLayer::validate(&aclInputInfo,
42                                                             &aclOutputInfo,
43                                                             &aclMeanInfo,
44                                                             &aclVarInfo,
45                                                             &aclBetaInfo,
46                                                             &aclGammaInfo,
47                                                             desc.m_Eps);
48 }
49
50 ClBatchNormalizationFloatWorkload::ClBatchNormalizationFloatWorkload(
51     const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info)
52     : FloatWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
53 {
54     m_Mean = std::make_unique<arm_compute::CLTensor>();
55     BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
56
57     m_Variance = std::make_unique<arm_compute::CLTensor>();
58     BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
59
60     m_Gamma = std::make_unique<arm_compute::CLTensor>();
61     BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
62
63     m_Beta = std::make_unique<arm_compute::CLTensor>();
64     BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
65
66     m_Data.ValidateInputsOutputs("ClBatchNormalizationFloatWorkload", 1, 1);
67
68     arm_compute::ICLTensor& input  = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
69     arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
70
71     arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout.GetDataLayout());
72     input.info()->set_data_layout(aclDataLayout);
73     output.info()->set_data_layout(aclDataLayout);
74
75     m_Layer.configure(&input,
76                       &output,
77                       m_Mean.get(),
78                       m_Variance.get(),
79                       m_Beta.get(),
80                       m_Gamma.get(),
81                       m_Data.m_Parameters.m_Eps);
82
83     InitializeArmComputeClTensorData(*m_Mean, m_Data.m_Mean);
84     InitializeArmComputeClTensorData(*m_Variance, m_Data.m_Variance);
85     InitializeArmComputeClTensorData(*m_Beta, m_Data.m_Beta);
86     InitializeArmComputeClTensorData(*m_Gamma, m_Data.m_Gamma);
87
88     // Force Compute Library to perform the necessary copying and reshaping, after which
89     // delete all the input tensors that will no longer be needed
90     m_Layer.prepare();
91     FreeUnusedTensors();
92 }
93
94 void ClBatchNormalizationFloatWorkload::Execute() const
95 {
96     ARMNN_SCOPED_PROFILING_EVENT_CL("ClBatchNormalizationFloatWorkload_Execute");
97     RunClFunction(m_Layer, CHECK_LOCATION());
98 }
99
100 void ClBatchNormalizationFloatWorkload::FreeUnusedTensors()
101 {
102     FreeTensorIfUnused(m_Mean);
103     FreeTensorIfUnused(m_Variance);
104     FreeTensorIfUnused(m_Gamma);
105     FreeTensorIfUnused(m_Beta);
106 }
107
108 } //namespace armnn