f0bfd3b2201c2d4e894c8b2d1c8bd77cdb78dcdf
[platform/upstream/armnn.git] / src / backends / backendsCommon / DynamicBackendUtils.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include "IBackendInternal.hpp"
9
10 #include <armnn/Exceptions.hpp>
11
12 #include <string>
13 #include <dlfcn.h>
14 #include <vector>
15
16 #include <boost/format.hpp>
17
18 #if !defined(DYNAMIC_BACKEND_PATHS)
19 #define DYNAMIC_BACKEND_PATHS ""
20 #endif
21
22 namespace armnn
23 {
24
25 class DynamicBackendUtils
26 {
27 public:
28     static void* OpenHandle(const std::string& sharedObjectPath);
29     static void CloseHandle(const void* sharedObjectHandle);
30
31     template<typename EntryPointType>
32     static EntryPointType GetEntryPoint(const void* sharedObjectHandle, const char* symbolName);
33
34     static bool IsBackendCompatible(const BackendVersion& backendVersion);
35
36     static std::vector<std::string> GetBackendPaths(const std::string& overrideBackendPath = "");
37     static bool IsPathValid(const std::string& path);
38
39 protected:
40     /// Protected methods for testing purposes
41     static bool IsBackendCompatibleImpl(const BackendVersion& backendApiVersion, const BackendVersion& backendVersion);
42     static std::vector<std::string> GetBackendPathsImpl(const std::string& backendPaths);
43
44 private:
45     static std::string GetDlError();
46
47     /// This class is to hold utility functions only
48     DynamicBackendUtils() = delete;
49 };
50
51 template<typename EntryPointType>
52 EntryPointType DynamicBackendUtils::GetEntryPoint(const void* sharedObjectHandle, const char* symbolName)
53 {
54     if (sharedObjectHandle == nullptr)
55     {
56         throw RuntimeException("GetEntryPoint error: invalid handle");
57     }
58
59     if (symbolName == nullptr)
60     {
61         throw RuntimeException("GetEntryPoint error: invalid symbol");
62     }
63
64     auto entryPoint = reinterpret_cast<EntryPointType>(dlsym(const_cast<void*>(sharedObjectHandle), symbolName));
65     if (!entryPoint)
66     {
67         throw RuntimeException(boost::str(boost::format("GetEntryPoint error: %1%") % GetDlError()));
68     }
69
70     return entryPoint;
71 }
72
73 } // namespace armnn