Release 18.03
[platform/upstream/armnn.git] / src / armnn / Runtime.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include "Runtime.hpp"
6
7 #include "armnn/Version.hpp"
8
9 #ifdef ARMCOMPUTECL_ENABLED
10 #include <arm_compute/core/CL/OpenCL.h>
11 #include <arm_compute/core/CL/CLKernelLibrary.h>
12 #include <arm_compute/runtime/CL/CLScheduler.h>
13 #endif
14
15 #include <boost/log/trivial.hpp>
16 #include <boost/polymorphic_cast.hpp>
17
18 using namespace armnn;
19 using namespace std;
20
21 namespace armnn
22 {
23
24 IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
25 {
26     return new Runtime(options);
27 }
28
29 IRuntimePtr IRuntime::Create(const CreationOptions& options)
30 {
31     return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
32 }
33
34 void IRuntime::Destroy(IRuntime* runtime)
35 {
36     delete boost::polymorphic_downcast<Runtime*>(runtime);
37 }
38
39 int Runtime::GenerateNetworkId()
40 {
41     return m_NetworkIdCounter++;
42 }
43
44 Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
45 {
46     IOptimizedNetwork* rawNetwork = inNetwork.release();
47     unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
48         std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
49         m_WorkloadFactories);
50
51     if (!loadedNetwork)
52     {
53         return Status::Failure;
54     }
55
56     networkIdOut = GenerateNetworkId();
57
58     // store the network
59     m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
60
61     return Status::Success;
62 }
63
64 Status Runtime::UnloadNetwork(NetworkId networkId)
65 {
66 #ifdef ARMCOMPUTECL_ENABLED
67     if (arm_compute::CLScheduler::get().context()() != NULL)
68     {
69         arm_compute::CLScheduler::get().sync();
70     }
71 #endif
72     if (m_LoadedNetworks.erase(networkId) == 0)
73     {
74         BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
75         return Status::Failure;
76     }
77 #ifdef ARMCOMPUTECL_ENABLED
78     if (arm_compute::CLScheduler::get().context()() != NULL && m_LoadedNetworks.empty())
79     {
80         m_WorkloadFactories.m_GpuAcc.get()->LoadOpenClRuntime();
81     }
82 #endif
83     BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
84     return Status::Success;
85 }
86
87 Runtime::Runtime(const CreationOptions& options)
88 : m_NetworkIdCounter(0)
89 {
90     BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
91     BOOST_LOG_TRIVIAL(info) << "Using compute device: " << options.m_DefaultComputeDevice << "\n";
92     m_DeviceSpec.DefaultComputeDevice = options.m_DefaultComputeDevice;
93
94    // If useCpuRefAsFallback is false, the reference workload factory will be prevented from creating
95    // operation workloads, unless the default compute device is precisely the reference backend.
96     m_WorkloadFactories.m_CpuRef = make_shared<RefWorkloadFactory>(
97         options.m_DefaultComputeDevice == Compute::CpuRef ? true : options.m_UseCpuRefAsFallback);
98     m_WorkloadFactories.m_CpuAcc = make_shared<NeonWorkloadFactory>();
99     m_WorkloadFactories.m_GpuAcc = make_shared<ClWorkloadFactory>(options.m_ClTunedParameters);
100
101     if (options.m_DefaultComputeDevice == Compute::GpuAcc)
102     {
103         m_WorkloadFactories.m_GpuAcc.get()->LoadOpenClRuntime();
104     }
105 }
106
107 Runtime::~Runtime()
108 {
109     std::vector<int> networkIDs;
110     std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
111                    std::back_inserter(networkIDs),
112                    [](const auto &pair) { return pair.first; });
113
114     for (auto networkID : networkIDs)
115     {
116         UnloadNetwork(networkID);
117     }
118 }
119
120 TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
121 {
122     LoadedNetwork* net = m_LoadedNetworks.at(networkId).get();
123     return net->GetInputTensorInfo(layerId);
124 }
125
126 TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
127 {
128     const LoadedNetwork* net = m_LoadedNetworks.at(networkId).get();
129     return net->GetOutputTensorInfo(layerId);
130 }
131
132 Status Runtime::EnqueueWorkload(NetworkId networkId,
133                                      const InputTensors& inputTensors,
134                                      const OutputTensors& outputTensors)
135 {
136     LoadedNetwork* loadedNetwork = m_LoadedNetworks.at(networkId).get();
137     return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors, m_WorkloadFactories);
138 }
139
140 }