namespace armnn
{
-class ClBatchNormalizationFloat32Workload : public Float32Workload<BatchNormalizationQueueDescriptor>
+arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const TensorInfo& mean,
+ const TensorInfo& var,
+ const TensorInfo& beta,
+ const TensorInfo& gamma,
+ const BatchNormalizationDescriptor& desc);
+
+class ClBatchNormalizationFloat32Workload : public FloatWorkload<BatchNormalizationQueueDescriptor>
{
public:
ClBatchNormalizationFloat32Workload(const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info);
- using Float32Workload<BatchNormalizationQueueDescriptor>::Float32Workload;
+ using FloatWorkload<BatchNormalizationQueueDescriptor>::FloatWorkload;
void Execute() const override;
private:
mutable arm_compute::CLBatchNormalizationLayer m_Layer;
- arm_compute::CLTensor m_Mean;
- arm_compute::CLTensor m_Variance;
- arm_compute::CLTensor m_Gamma;
- arm_compute::CLTensor m_Beta;
+ std::unique_ptr<arm_compute::CLTensor> m_Mean;
+ std::unique_ptr<arm_compute::CLTensor> m_Variance;
+ std::unique_ptr<arm_compute::CLTensor> m_Gamma;
+ std::unique_ptr<arm_compute::CLTensor> m_Beta;
+
+ void FreeUnusedTensors();
};
} //namespace armnn