Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / WorkloadData.hpp
index 7f87135..db266e6 100644 (file)
@@ -17,7 +17,7 @@
 namespace armnn
 {
 
-//a helper function that returns the bias data type required for given input data type.
+//A helper function that returns the bias data type required for given input data type.
 DataType GetBiasDataType(DataType inputDataType);
 
 struct WorkloadInfo;
@@ -38,7 +38,7 @@ protected:
     QueueDescriptor& operator=(QueueDescriptor const&) = default;
 };
 
-// Base class for queue descriptors which contain parameters
+// Base class for queue descriptors which contain parameters.
 template <typename LayerDescriptor>
 struct QueueDescriptorWithParameters : public QueueDescriptor
 {
@@ -59,13 +59,13 @@ struct MemCopyQueueDescriptor : QueueDescriptor
 using InputQueueDescriptor = MemCopyQueueDescriptor;
 using OutputQueueDescriptor = MemCopyQueueDescriptor;
 
-// Softmax layer workload data
+// Softmax layer workload data.
 struct SoftmaxQueueDescriptor : QueueDescriptorWithParameters<SoftmaxDescriptor>
 {
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Splitter layer workload data
+// Splitter layer workload data.
 struct SplitterQueueDescriptor : QueueDescriptorWithParameters<ViewsDescriptor>
 {
     struct ViewOrigin
@@ -73,18 +73,18 @@ struct SplitterQueueDescriptor : QueueDescriptorWithParameters<ViewsDescriptor>
         ViewOrigin() {}
         ViewOrigin(std::vector<unsigned int> const& origin) : m_Origin(origin) {}
 
-        //view origin (size of the vector is the same as number of dimensions of the view)
+        //View origin (size of the vector is the same as number of dimensions of the view).
         std::vector<unsigned int> m_Origin;
     };
 
-    //view defines a tensor that will be carved from the input tensor.
-    //view origins are stored here, the extents are defined by sizes of the output tensors.
+    //View defines a tensor that will be carved from the input tensor.
+    //View origins are stored here, the extents are defined by sizes of the output tensors.
     std::vector<ViewOrigin> m_ViewOrigins;
 
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Merger layer workload data
+// Merger layer workload data.
 struct MergerQueueDescriptor : QueueDescriptorWithParameters<OriginsDescriptor>
 {
     struct ViewOrigin
@@ -92,24 +92,24 @@ struct MergerQueueDescriptor : QueueDescriptorWithParameters<OriginsDescriptor>
         ViewOrigin() {}
         ViewOrigin(const std::vector<unsigned int>& origin) : m_Origin(origin) {}
 
-        //view origin (size of the vector is the same as number of dimensions of the view)
+        //View origin (size of the vector is the same as number of dimensions of the view).
         std::vector<unsigned int> m_Origin;
     };
 
-    //view defines a sub-area of the output tensor that will be filled with the corresponding input tensor.
-    //view origins are stored here, the extents are defined by sizes of the input tensors.
+    //View defines a sub-area of the output tensor that will be filled with the corresponding input tensor.
+    //View origins are stored here, the extents are defined by sizes of the input tensors.
     std::vector<ViewOrigin> m_ViewOrigins;
 
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Activation layer workload data
+// Activation layer workload data.
 struct ActivationQueueDescriptor : QueueDescriptorWithParameters<ActivationDescriptor>
 {
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Fully connected layer workload data
+// Fully connected layer workload data.
 struct FullyConnectedQueueDescriptor : QueueDescriptorWithParameters<FullyConnectedDescriptor>
 {
     FullyConnectedQueueDescriptor()
@@ -124,19 +124,19 @@ struct FullyConnectedQueueDescriptor : QueueDescriptorWithParameters<FullyConnec
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Permute layer workload data
+// Permute layer workload data.
 struct PermuteQueueDescriptor : QueueDescriptorWithParameters<PermuteDescriptor>
 {
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Pooling 2D layer workload data
+// Pooling 2D layer workload data.
 struct Pooling2dQueueDescriptor : QueueDescriptorWithParameters<Pooling2dDescriptor>
 {
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Convolution 2D layer workload data
+// Convolution 2D layer workload data.
 struct Convolution2dQueueDescriptor : QueueDescriptorWithParameters<Convolution2dDescriptor>
 {
     Convolution2dQueueDescriptor()
@@ -151,7 +151,7 @@ struct Convolution2dQueueDescriptor : QueueDescriptorWithParameters<Convolution2
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Depthwise Convolution 2D layer workload data
+// Depthwise Convolution 2D layer workload data.
 struct DepthwiseConvolution2dQueueDescriptor : QueueDescriptorWithParameters<DepthwiseConvolution2dDescriptor>
 {
     DepthwiseConvolution2dQueueDescriptor()
@@ -166,25 +166,25 @@ struct DepthwiseConvolution2dQueueDescriptor : QueueDescriptorWithParameters<Dep
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Normalization layer workload data
+// Normalization layer workload data.
 struct NormalizationQueueDescriptor : QueueDescriptorWithParameters<NormalizationDescriptor>
 {
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Add layer workload data
+// Add layer workload data.
 struct AdditionQueueDescriptor : QueueDescriptor
 {
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Multiplication layer workload data
+// Multiplication layer workload data.
 struct MultiplicationQueueDescriptor : QueueDescriptor
 {
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Batch norm layer workload data
+// Batch norm layer workload data.
 struct BatchNormalizationQueueDescriptor : QueueDescriptorWithParameters<BatchNormalizationDescriptor>
 {
     BatchNormalizationQueueDescriptor()
@@ -249,4 +249,58 @@ struct FloorQueueDescriptor : QueueDescriptor
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
+struct LstmQueueDescriptor : QueueDescriptorWithParameters<LstmDescriptor>
+{
+    LstmQueueDescriptor()
+        : m_InputToInputWeights(nullptr)
+        , m_InputToForgetWeights(nullptr)
+        , m_InputToCellWeights(nullptr)
+        , m_InputToOutputWeights(nullptr)
+        , m_RecurrentToInputWeights(nullptr)
+        , m_RecurrentToForgetWeights(nullptr)
+        , m_RecurrentToCellWeights(nullptr)
+        , m_RecurrentToOutputWeights(nullptr)
+        , m_CellToInputWeights(nullptr)
+        , m_CellToForgetWeights(nullptr)
+        , m_CellToOutputWeights(nullptr)
+        , m_InputGateBias(nullptr)
+        , m_ForgetGateBias(nullptr)
+        , m_CellBias(nullptr)
+        , m_OutputGateBias(nullptr)
+        , m_ProjectionWeights(nullptr)
+        , m_ProjectionBias(nullptr)
+    {
+    }
+
+    const ConstCpuTensorHandle* m_InputToInputWeights;
+    const ConstCpuTensorHandle* m_InputToForgetWeights;
+    const ConstCpuTensorHandle* m_InputToCellWeights;
+    const ConstCpuTensorHandle* m_InputToOutputWeights;
+    const ConstCpuTensorHandle* m_RecurrentToInputWeights;
+    const ConstCpuTensorHandle* m_RecurrentToForgetWeights;
+    const ConstCpuTensorHandle* m_RecurrentToCellWeights;
+    const ConstCpuTensorHandle* m_RecurrentToOutputWeights;
+    const ConstCpuTensorHandle* m_CellToInputWeights;
+    const ConstCpuTensorHandle* m_CellToForgetWeights;
+    const ConstCpuTensorHandle* m_CellToOutputWeights;
+    const ConstCpuTensorHandle* m_InputGateBias;
+    const ConstCpuTensorHandle* m_ForgetGateBias;
+    const ConstCpuTensorHandle* m_CellBias;
+    const ConstCpuTensorHandle* m_OutputGateBias;
+    const ConstCpuTensorHandle* m_ProjectionWeights;
+    const ConstCpuTensorHandle* m_ProjectionBias;
+
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ConvertFp16ToFp32QueueDescriptor : QueueDescriptor
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ConvertFp32ToFp16QueueDescriptor : QueueDescriptor
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
 } //namespace armnn