b327a9ea609989a2f4917de4d451b37dae2dcd3e
[platform/upstream/armnn.git] / src / backends / backendsCommon / DynamicBackendUtils.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include "IBackendInternal.hpp"
9
10 #include <armnn/Exceptions.hpp>
11
12 #include <string>
13 #include <dlfcn.h>
14 #include <vector>
15
16 #include <boost/format.hpp>
17
18 #if !defined(DYNAMIC_BACKEND_PATHS)
19 #define DYNAMIC_BACKEND_PATHS ""
20 #endif
21
22 namespace armnn
23 {
24
25 class DynamicBackendUtils
26 {
27 public:
28     static void* OpenHandle(const std::string& sharedObjectPath);
29     static void CloseHandle(const void* sharedObjectHandle);
30
31     template<typename EntryPointType>
32     static EntryPointType GetEntryPoint(const void* sharedObjectHandle, const char* symbolName);
33
34     static bool IsBackendCompatible(const BackendVersion& backendVersion);
35
36     static std::vector<std::string> GetBackendPaths(const std::string& overrideBackendPath = "");
37     static bool IsPathValid(const std::string& path);
38     static std::vector<std::string> GetSharedObjects(const std::vector<std::string>& backendPaths);
39
40 protected:
41     /// Protected methods for testing purposes
42     static bool IsBackendCompatibleImpl(const BackendVersion& backendApiVersion, const BackendVersion& backendVersion);
43     static std::vector<std::string> GetBackendPathsImpl(const std::string& backendPaths);
44
45 private:
46     static std::string GetDlError();
47
48     /// This class is to hold utility functions only
49     DynamicBackendUtils() = delete;
50 };
51
52 template<typename EntryPointType>
53 EntryPointType DynamicBackendUtils::GetEntryPoint(const void* sharedObjectHandle, const char* symbolName)
54 {
55     if (sharedObjectHandle == nullptr)
56     {
57         throw RuntimeException("GetEntryPoint error: invalid handle");
58     }
59
60     if (symbolName == nullptr)
61     {
62         throw RuntimeException("GetEntryPoint error: invalid symbol");
63     }
64
65     auto entryPoint = reinterpret_cast<EntryPointType>(dlsym(const_cast<void*>(sharedObjectHandle), symbolName));
66     if (!entryPoint)
67     {
68         throw RuntimeException(boost::str(boost::format("GetEntryPoint error: %1%") % GetDlError()));
69     }
70
71     return entryPoint;
72 }
73
74 } // namespace armnn