Release 18.08
[platform/upstream/armnn.git] / src/armnn/backends/ClWorkloads/ClConvertFp32ToFp16Workload.cpp
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#include "ClConvertFp32ToFp16Workload.hpp"
#include "backends/ClTensorHandle.hpp"

namespace armnn
{
using namespace armcomputetensorutils;

// Use the Arm Compute Library's saturating policy for the FP32 -> FP16 cast.
static constexpr arm_compute::ConvertPolicy g_AclConvertPolicy = arm_compute::ConvertPolicy::SATURATE;

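// The constructor extracts the underlying CL tensors from the ArmNN tensor handles and
// configures the ACL conversion layer (m_Layer) to cast the single FP32 input to the FP16 output.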
ClConvertFp32ToFp16Workload::ClConvertFp32ToFp16Workload(
    const ConvertFp32ToFp16QueueDescriptor& descriptor, const WorkloadInfo& info) :
    Float32ToFloat16Workload<ConvertFp32ToFp16QueueDescriptor>(descriptor, info)
{
    this->m_Data.ValidateInputsOutputs("ClConvertFp32ToFp16Workload", 1, 1);

    arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[0])->GetTensor();
    arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(this->m_Data.m_Outputs[0])->GetTensor();

    m_Layer.configure(&input, &output, g_AclConvertPolicy, 0);
}

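// Execute() runs the previously configured ACL layer; the scoped profiling macro records the
// event under the name "ClConvertFp32ToFp16Workload_Execute".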
void ClConvertFp32ToFp16Workload::Execute() const
{
    ARMNN_SCOPED_PROFILING_EVENT_CL("ClConvertFp32ToFp16Workload_Execute");
    m_Layer.run();
}

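// Standalone validation: check the ArmNN data types first, then defer to the ACL layer's own
// validate() to confirm the conversion is supported. On failure, reasonIfUnsupported (if
// provided) receives a description of the problem.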
arm_compute::Status ClConvertFp32ToFp16WorkloadValidate(const TensorInfo& input,
                                                        const TensorInfo& output,
                                                        std::string* reasonIfUnsupported)
{
    // reasonIfUnsupported is optional; guard every write to it, as done for the ACL status below.
    if (input.GetDataType() != DataType::Float32)
    {
        if (reasonIfUnsupported)
        {
            *reasonIfUnsupported = "Input should be Float32";
        }
        return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR, "Input should be Float32");
    }
    if (output.GetDataType() != DataType::Float16)
    {
        if (reasonIfUnsupported)
        {
            *reasonIfUnsupported = "Output should be Float16";
        }
        return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR, "Output should be Float16");
    }

    const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
    const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);

    const arm_compute::Status aclStatus = arm_compute::CLDepthConvertLayer::validate(
        &aclInputInfo, &aclOutputInfo, g_AclConvertPolicy, 0);

    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        *reasonIfUnsupported = aclStatus.error_description();
    }

    return aclStatus;
}

} //namespace armnn