2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
8 #include <neon/workloads/NeonWorkloadUtils.hpp>
/// @brief Validation entry point for the NEON batch-normalization workload.
///
/// Only this declaration is visible here; the definition lives elsewhere.
/// NOTE(review): presumably this checks whether the Arm Compute Library NEON
/// batch-normalization function supports the given tensor infos and descriptor
/// (as ACL `validate()` functions do) — confirm against the implementation.
///
/// @param input      Tensor info describing the tensor to be normalized.
/// @param output     Tensor info describing the normalized result.
/// @param mean       Tensor info for the mean statistics.
/// @param var        Tensor info for the variance statistics.
/// @param beta       Tensor info for beta (conventionally the additive offset — TODO confirm).
/// @param gamma      Tensor info for gamma (conventionally the multiplicative scale — TODO confirm).
/// @param descriptor Batch-normalization parameters; its fields are not visible in this fragment.
/// @return An arm_compute::Status carrying the validation result.
13 arm_compute::Status NeonBatchNormalizationValidate(const TensorInfo& input,
14 const TensorInfo& output,
15 const TensorInfo& mean,
16 const TensorInfo& var,
17 const TensorInfo& beta,
18 const TensorInfo& gamma,
19 const BatchNormalizationDescriptor& descriptor);
21 class NeonBatchNormalizationFloatWorkload : public FloatWorkload<BatchNormalizationQueueDescriptor>
24 NeonBatchNormalizationFloatWorkload(const BatchNormalizationQueueDescriptor& descriptor,
25 const WorkloadInfo& info);
26 virtual void Execute() const override;
29 mutable arm_compute::NEBatchNormalizationLayer m_Layer;
31 std::unique_ptr<arm_compute::Tensor> m_Mean;
32 std::unique_ptr<arm_compute::Tensor> m_Variance;
33 std::unique_ptr<arm_compute::Tensor> m_Gamma;
34 std::unique_ptr<arm_compute::Tensor> m_Beta;
36 void FreeUnusedTensors();