IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / neon / workloads / NeonBatchNormalizationFloatWorkload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
#include <neon/workloads/NeonWorkloadUtils.hpp>

#include <memory>
9
10 namespace armnn
11 {
12
13 arm_compute::Status NeonBatchNormalizationValidate(const TensorInfo& input,
14                                                    const TensorInfo& output,
15                                                    const TensorInfo& mean,
16                                                    const TensorInfo& var,
17                                                    const TensorInfo& beta,
18                                                    const TensorInfo& gamma,
19                                                    const BatchNormalizationDescriptor& descriptor);
20
21 class NeonBatchNormalizationFloatWorkload : public FloatWorkload<BatchNormalizationQueueDescriptor>
22 {
23 public:
24     NeonBatchNormalizationFloatWorkload(const BatchNormalizationQueueDescriptor& descriptor,
25                                         const WorkloadInfo& info);
26     virtual void Execute() const override;
27
28 private:
29     mutable arm_compute::NEBatchNormalizationLayer m_Layer;
30
31     std::unique_ptr<arm_compute::Tensor> m_Mean;
32     std::unique_ptr<arm_compute::Tensor> m_Variance;
33     std::unique_ptr<arm_compute::Tensor> m_Gamma;
34     std::unique_ptr<arm_compute::Tensor> m_Beta;
35
36     void FreeUnusedTensors();
37 };
38
39 } //namespace armnn
40
41
42