Release 18.08
[platform/upstream/armnn.git] / src / armnn / optimizations / Optimization.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "Graph.hpp"
8 #include "LayersFwd.hpp"
9
10 namespace armnn
11 {
12
13 class Optimization
14 {
15 public:
16     Optimization() = default;
17     virtual ~Optimization() = default;
18     virtual void Run(Graph& graph, Layer& base) const = 0;
19 protected:
20 };
21
22 // Wrappers
23 // The implementation of the following wrappers make use of the CRTP C++ idiom
24 // (curiously recurring template pattern).
25 // For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
26
27 /// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType.
28 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
29 ///   after applying each optimization.
30 template <typename BaseType, typename Wrapped>
31 class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped
32 {
33 public:
34     using Wrapped::Wrapped;
35
36     void Run(Graph& graph, Layer& base) const override
37     {
38         if (base.GetType() == LayerEnumOf<BaseType>())
39         {
40             Wrapped::Run(graph, *boost::polymorphic_downcast<BaseType*>(&base));
41         }
42     }
43
44 protected:
45     ~OptimizeForTypeImpl() = default;
46 };
47
48 /// Specialization that calls Wrapped::Run() for any layer type.
49 template <typename Wrapped>
50 class OptimizeForTypeImpl<Layer, Wrapped> : public armnn::Optimization, public Wrapped
51 {
52 public:
53     using Wrapped::Wrapped;
54
55     void Run(Graph& graph, Layer& base) const override
56     {
57         Wrapped::Run(graph, base);
58     }
59
60 protected:
61     ~OptimizeForTypeImpl() = default;
62 };
63
64 template <typename BaseType, typename Wrapped>
65 class OptimizeForType final : public OptimizeForTypeImpl<BaseType, Wrapped>
66 {
67 public:
68     using OptimizeForTypeImpl<BaseType, Wrapped>::OptimizeForTypeImpl;
69 };
70
71 /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
72 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
73 ///   after applying each optimization.
74 /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
75 /// - Children layers are removed if left unconnected after applying the wrapped optimization.
76 template <typename BaseType, typename ChildType, typename Wrapped>
77 class OptimizeForConnectionImpl : public Wrapped
78 {
79 public:
80     using Wrapped::Wrapped;
81
82     void Run(Graph& graph, BaseType& base) const
83     {
84         for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
85         {
86             for (auto&& childInput : output->GetConnections())
87             {
88                 if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
89                 {
90                     Wrapped::Run(graph, *childInput);
91                 }
92             }
93
94             // Removes unconnected children.
95             for (unsigned int i = 0; i < output->GetNumConnections();)
96             {
97                 Layer* child = &output->GetConnection(i)->GetOwningLayer();
98
99                 if (child->IsOutputUnconnected())
100                 {
101                     graph.EraseLayer(child);
102                 }
103                 else
104                 {
105                     ++i;
106                 }
107             }
108         }
109     }
110
111 protected:
112     ~OptimizeForConnectionImpl() = default;
113 };
114
115 template <typename BaseType, typename ChildType, typename Wrapped>
116 class OptimizeForConnection final
117     : public OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>
118 {
119 public:
120     using OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl;
121 };
122
123 } // namespace armnn