// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
#include "backends/CpuTensorHandle.hpp"

#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <boost/polymorphic_cast.hpp>

#include <cstddef>
#include <cstdint>
#include <vector>
////////////////////////////////////////////
/// Generic, float32 and Half helpers
////////////////////////////////////////////
23 inline const TensorInfo& GetTensorInfo(const ITensorHandle* tensorHandle)
25 // We know that reference workloads use CpuTensorHandles only, so this cast is legitimate.
26 const ConstCpuTensorHandle* cpuTensorHandle =
27 boost::polymorphic_downcast<const ConstCpuTensorHandle*>(tensorHandle);
28 return cpuTensorHandle->GetTensorInfo();
31 template <typename DataType>
32 inline const DataType* GetConstCpuData(const ITensorHandle* tensorHandle)
34 // We know that reference workloads use (Const)CpuTensorHandles only, so this cast is legitimate.
35 const ConstCpuTensorHandle* cpuTensorHandle =
36 boost::polymorphic_downcast<const ConstCpuTensorHandle*>(tensorHandle);
37 return cpuTensorHandle->GetConstTensor<DataType>();
40 template <typename DataType>
41 inline DataType* GetCpuData(const ITensorHandle* tensorHandle)
43 // We know that reference workloads use CpuTensorHandles only, so this cast is legitimate.
44 const CpuTensorHandle* cpuTensorHandle = boost::polymorphic_downcast<const CpuTensorHandle*>(tensorHandle);
45 return cpuTensorHandle->GetTensor<DataType>();
48 template <typename DataType, typename PayloadType>
49 const DataType* GetInputTensorData(unsigned int idx, const PayloadType& data)
51 const ITensorHandle* tensorHandle = data.m_Inputs[idx];
52 return GetConstCpuData<DataType>(tensorHandle);
55 template <typename DataType, typename PayloadType>
56 DataType* GetOutputTensorData(unsigned int idx, const PayloadType& data)
58 const ITensorHandle* tensorHandle = data.m_Outputs[idx];
59 return GetCpuData<DataType>(tensorHandle);
62 template <typename PayloadType>
63 const float* GetInputTensorDataFloat(unsigned int idx, const PayloadType& data)
65 return GetInputTensorData<float>(idx, data);
68 template <typename PayloadType>
69 float* GetOutputTensorDataFloat(unsigned int idx, const PayloadType& data)
71 return GetOutputTensorData<float>(idx, data);
74 template <typename PayloadType>
75 const Half* GetInputTensorDataHalf(unsigned int idx, const PayloadType& data)
77 return GetInputTensorData<Half>(idx, data);
80 template <typename PayloadType>
81 Half* GetOutputTensorDataHalf(unsigned int idx, const PayloadType& data)
83 return GetOutputTensorData<Half>(idx, data);
////////////////////////////////////////////
/// uint8_t (quantized) helpers
////////////////////////////////////////////
90 inline const uint8_t* GetConstCpuU8Data(const ITensorHandle* tensorHandle)
92 // We know that reference workloads use (Const)CpuTensorHandles only, so this cast is legitimate.
93 const ConstCpuTensorHandle* cpuTensorHandle =
94 boost::polymorphic_downcast<const ConstCpuTensorHandle*>(tensorHandle);
95 return cpuTensorHandle->GetConstTensor<uint8_t>();
98 inline uint8_t* GetCpuU8Data(const ITensorHandle* tensorHandle)
100 // We know that reference workloads use CpuTensorHandles only, so this cast is legitimate.
101 const CpuTensorHandle* cpuTensorHandle = boost::polymorphic_downcast<const CpuTensorHandle*>(tensorHandle);
102 return cpuTensorHandle->GetTensor<uint8_t>();
105 template <typename PayloadType>
106 const uint8_t* GetInputTensorDataU8(unsigned int idx, const PayloadType& data)
108 const ITensorHandle* tensorHandle = data.m_Inputs[idx];
109 return GetConstCpuU8Data(tensorHandle);
112 template <typename PayloadType>
113 uint8_t* GetOutputTensorDataU8(unsigned int idx, const PayloadType& data)
115 const ITensorHandle* tensorHandle = data.m_Outputs[idx];
116 return GetCpuU8Data(tensorHandle);
120 std::vector<float> Dequantize(const T* quant, const TensorInfo& info)
122 std::vector<float> ret(info.GetNumElements());
123 for (size_t i = 0; i < info.GetNumElements(); i++)
125 ret[i] = armnn::Dequantize(quant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
130 inline void Quantize(uint8_t* quant, const float* dequant, const TensorInfo& info)
132 for (size_t i = 0; i < info.GetNumElements(); i++)
134 quant[i] = armnn::Quantize<uint8_t>(dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());