2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
8 #include <backends/NeonWorkloadUtils.hpp>
13 arm_compute::Status NeonBatchNormalizationValidate(const TensorInfo& input,
14 const TensorInfo& output,
15 const TensorInfo& mean,
16 const TensorInfo& var,
17 const TensorInfo& beta,
18 const TensorInfo& gamma,
19 const BatchNormalizationDescriptor& descriptor);
21 class NeonBatchNormalizationFloat32Workload : public FloatWorkload<BatchNormalizationQueueDescriptor>
24 NeonBatchNormalizationFloat32Workload(const BatchNormalizationQueueDescriptor& descriptor,
25 const WorkloadInfo& info);
26 virtual void Execute() const override;
29 mutable arm_compute::NEBatchNormalizationLayer m_Layer;
31 std::unique_ptr<arm_compute::Tensor> m_Mean;
32 std::unique_ptr<arm_compute::Tensor> m_Variance;
33 std::unique_ptr<arm_compute::Tensor> m_Gamma;
34 std::unique_ptr<arm_compute::Tensor> m_Beta;
36 void FreeUnusedTensors();