Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / ClWorkloads / ClBatchNormalizationFloat32Workload.hpp
index ddbd0f0..a45614a 100644 (file)
 namespace armnn
 {
 
-class ClBatchNormalizationFloat32Workload : public Float32Workload<BatchNormalizationQueueDescriptor>
+arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
+                                                 const TensorInfo& output,
+                                                 const TensorInfo& mean,
+                                                 const TensorInfo& var,
+                                                 const TensorInfo& beta,
+                                                 const TensorInfo& gamma,
+                                                 const BatchNormalizationDescriptor& desc);
+
+class ClBatchNormalizationFloat32Workload : public FloatWorkload<BatchNormalizationQueueDescriptor>
 {
 public:
     ClBatchNormalizationFloat32Workload(const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info);
 
-    using Float32Workload<BatchNormalizationQueueDescriptor>::Float32Workload;
+    using FloatWorkload<BatchNormalizationQueueDescriptor>::FloatWorkload;
     void Execute() const override;
 
 private:
     mutable arm_compute::CLBatchNormalizationLayer m_Layer;
 
-    arm_compute::CLTensor m_Mean;
-    arm_compute::CLTensor m_Variance;
-    arm_compute::CLTensor m_Gamma;
-    arm_compute::CLTensor m_Beta;
+    std::unique_ptr<arm_compute::CLTensor> m_Mean;
+    std::unique_ptr<arm_compute::CLTensor> m_Variance;
+    std::unique_ptr<arm_compute::CLTensor> m_Gamma;
+    std::unique_ptr<arm_compute::CLTensor> m_Beta;
+
+    void FreeUnusedTensors();
 };
 
 } //namespace armnn