Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / NeonWorkloads / NeonBatchNormalizationFloat32Workload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #pragma once
7
8 #include <backends/NeonWorkloadUtils.hpp>
9
10 namespace armnn
11 {
12
13 arm_compute::Status NeonBatchNormalizationValidate(const TensorInfo& input,
14                                                    const TensorInfo& output,
15                                                    const TensorInfo& mean,
16                                                    const TensorInfo& var,
17                                                    const TensorInfo& beta,
18                                                    const TensorInfo& gamma,
19                                                    const BatchNormalizationDescriptor& descriptor);
20
21 class NeonBatchNormalizationFloat32Workload : public FloatWorkload<BatchNormalizationQueueDescriptor>
22 {
23 public:
24     NeonBatchNormalizationFloat32Workload(const BatchNormalizationQueueDescriptor& descriptor,
25                                           const WorkloadInfo& info);
26     virtual void Execute() const override;
27
28 private:
29     mutable arm_compute::NEBatchNormalizationLayer m_Layer;
30
31     std::unique_ptr<arm_compute::Tensor> m_Mean;
32     std::unique_ptr<arm_compute::Tensor> m_Variance;
33     std::unique_ptr<arm_compute::Tensor> m_Gamma;
34     std::unique_ptr<arm_compute::Tensor> m_Beta;
35
36     void FreeUnusedTensors();
37 };
38
39 } //namespace armnn
40
41
42