2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <armnn/Types.hpp>
8 #include <armnn/IRuntime.hpp>
9 #include <armnn/Deprecated.hpp>
11 #include <ISubgraphViewConverter.hpp>
12 #include <SubgraphView.hpp>
13 #include <optimizations/Optimization.hpp>
15 #include "IBackendContext.hpp"
16 #include "IMemoryManager.hpp"
17 #include "ITensorHandleFactory.hpp"
18 #include "OptimizationViews.hpp"
24 class IWorkloadFactory;
28 class IBackendInternal : public IBackend
31 // Creation must be done through a specific
33 IBackendInternal() = default;
36 // Allow backends created by the factory function
37 // to be destroyed through IBackendInternal.
38 ~IBackendInternal() override = default;
40 using IWorkloadFactoryPtr = std::unique_ptr<IWorkloadFactory>;
41 using IBackendContextPtr = std::unique_ptr<IBackendContext>;
42 using OptimizationPtr = std::unique_ptr<Optimization>;
43 using Optimizations = std::vector<OptimizationPtr>;
44 using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>;
46 using IMemoryManagerUniquePtr = std::unique_ptr<IMemoryManager>;
47 using IMemoryManagerSharedPtr = std::shared_ptr<IMemoryManager>;
49 using GraphUniquePtr = std::unique_ptr<Graph>;
50 using SubgraphViewUniquePtr = std::unique_ptr<SubgraphView>;
52 ARMNN_NO_DEPRECATE_WARN_BEGIN
53 using ISubGraphConverterPtr ARMNN_DEPRECATED_MSG("This type is no longer supported")
54 = std::unique_ptr<ISubGraphConverter>;
55 using SubGraphUniquePtr ARMNN_DEPRECATED_MSG("SubGraph is deprecated, use SubgraphView instead")
56 = std::unique_ptr<SubGraph>;
58 ARMNN_DEPRECATED_MSG("This method is no longer supported")
59 virtual ISubGraphConverterPtr CreateSubGraphConverter(const std::shared_ptr<SubGraph>& subGraph) const
61 return ISubGraphConverterPtr{};
64 ARMNN_DEPRECATED_MSG("Use \"OptimizationViews OptimizeSubgraphView(const SubgraphView&)\" instead")
65 virtual Optimizations GetOptimizations() const
67 return Optimizations{};
70 ARMNN_DEPRECATED_MSG("Use \"OptimizationViews OptimizeSubgraphView(const SubgraphView&)\" instead")
71 virtual SubGraphUniquePtr OptimizeSubGraph(const SubGraph& subGraph, bool& optimizationAttempted) const
73 optimizationAttempted = false;
76 ARMNN_NO_DEPRECATE_WARN_END
79 virtual IMemoryManagerUniquePtr CreateMemoryManager() const
81 return IMemoryManagerUniquePtr();
84 virtual IWorkloadFactoryPtr CreateWorkloadFactory(
85 const IMemoryManagerSharedPtr& memoryManager = nullptr) const = 0;
87 virtual IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const
89 return IBackendContextPtr{};
92 virtual ILayerSupportSharedPtr GetLayerSupport() const = 0;
94 // Default implementation of OptimizeSubgraphView for backward compatibility with the old API.
95 // Override this method with a custom optimization implementation.
96 virtual OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const
98 bool optimizationAttempted = false;
100 ARMNN_NO_DEPRECATE_WARN_BEGIN
101 SubGraphUniquePtr optSubgraph = OptimizeSubGraph(subgraph, optimizationAttempted);
102 ARMNN_NO_DEPRECATE_WARN_END
104 OptimizationViews result;
105 if (!optimizationAttempted)
107 result.AddUntouchedSubgraph(SubgraphView(subgraph));
113 result.AddSubstitution({subgraph, SubgraphView(*optSubgraph.get())});
117 result.AddFailedSubgraph(SubgraphView(subgraph));
123 bool SupportsTensorAllocatorAPI() const { return GetHandleFactoryPreferences().empty() == false; }
125 ITensorHandleFactory::FactoryId GetBackwardCompatibleFavoriteHandleFactory()
127 auto favorites = GetHandleFactoryPreferences();
128 if (favorites.empty())
130 return ITensorHandleFactory::LegacyFactoryId;
135 /// (Optional) Returns a vector of supported TensorHandleFactory ids in preference order.
136 virtual std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const
138 return std::vector<ITensorHandleFactory::FactoryId>();
141 /// (Optional) Register TensorHandleFactories
142 /// Either this method or CreateMemoryManager() and
143 /// IWorkloadFactory::CreateTensor()/IWorkloadFactory::CreateSubtensor() methods must be implemented.
144 virtual void RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry) {}
147 using IBackendInternalUniquePtr = std::unique_ptr<IBackendInternal>;