97f8ebd7d282969ffc717c44220f8e6a7626a607
[platform/upstream/armnn.git] / src / backends / test / WorkloadTestUtils.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <armnn/Tensor.hpp>
8
9 namespace armnn
10 {
11 class ITensorHandle;
12 }
13
14 template <typename QueueDescriptor>
15 void AddInputToWorkload(QueueDescriptor& descriptor,
16     armnn::WorkloadInfo& info,
17     const armnn::TensorInfo& tensorInfo,
18     armnn::ITensorHandle* tensorHandle)
19 {
20     descriptor.m_Inputs.push_back(tensorHandle);
21     info.m_InputTensorInfos.push_back(tensorInfo);
22 }
23
24 template <typename QueueDescriptor>
25 void AddOutputToWorkload(QueueDescriptor& descriptor,
26     armnn::WorkloadInfo& info,
27     const armnn::TensorInfo& tensorInfo,
28     armnn::ITensorHandle* tensorHandle)
29 {
30     descriptor.m_Outputs.push_back(tensorHandle);
31     info.m_OutputTensorInfos.push_back(tensorInfo);
32 }
33
34 template <typename QueueDescriptor>
35 void SetWorkloadInput(QueueDescriptor& descriptor,
36     armnn::WorkloadInfo& info,
37     unsigned int index,
38     const armnn::TensorInfo& tensorInfo,
39     armnn::ITensorHandle* tensorHandle)
40 {
41     descriptor.m_Inputs[index] = tensorHandle;
42     info.m_InputTensorInfos[index] = tensorInfo;
43 }
44
45 template <typename QueueDescriptor>
46 void SetWorkloadOutput(QueueDescriptor& descriptor,
47     armnn::WorkloadInfo& info,
48     unsigned int index,
49     const armnn::TensorInfo& tensorInfo,
50     armnn::ITensorHandle* tensorHandle)
51 {
52     descriptor.m_Outputs[index] = tensorHandle;
53     info.m_OutputTensorInfos[index] = tensorInfo;
54 }