2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
8 #include "LayersFwd.hpp"
/// Defaulted constructor.
16 Optimization() = default;
/// Virtual destructor, so concrete optimizations can be destroyed through an Optimization
/// pointer or reference.
17 virtual ~Optimization() = default;
/// Applies this optimization to @p base within @p graph (presumably @p base is a layer owned by
/// @p graph — confirm with the optimizer driver). Pure virtual: every concrete optimization must
/// override it. Const-qualified, so implementations cannot modify non-mutable members while running.
18 virtual void Run(Graph& graph, Layer& base) const = 0;
23 // The implementation of the following wrappers makes use of the CRTP C++ idiom
24 // (curiously recurring template pattern).
25 // For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
27 /// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType.
28 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
29 /// after applying each optimization.
/// @tparam BaseType Concrete layer class to match; its enum value is obtained via LayerEnumOf<BaseType>().
/// @tparam Wrapped  Class that performs the actual transformation. It must provide a Run(Graph&, BaseType&)
///                  callable from a const member function (e.g. a const or static member). It is also a
///                  public base class, so its constructors are inherited below.
30 template <typename BaseType, typename Wrapped>
31 class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped
// Inherit Wrapped's constructors so any arguments Wrapped needs can be passed straight through.
34 using Wrapped::Wrapped;
/// Optimization::Run override: forwards to Wrapped::Run only when @p base is a BaseType layer.
/// Layers of any other type are ignored.
36 void Run(Graph& graph, Layer& base) const override
38 if (base.GetType() == LayerEnumOf<BaseType>())
// The layer-type check above is what makes this downcast to BaseType valid.
40 Wrapped::Run(graph, *boost::polymorphic_downcast<BaseType*>(&base));
// Defaulted destructor. It is implicitly virtual because armnn::Optimization declares a virtual destructor.
45 ~OptimizeForTypeImpl() = default;
48 /// Specialization that calls Wrapped::Run() for any layer type.
/// Selected when BaseType is Layer itself. There is no layer-type filter and no downcast: every layer
/// is passed to Wrapped::Run as a plain Layer&.
/// @tparam Wrapped Provides a Run(Graph&, Layer&) callable from a const member function.
49 template <typename Wrapped>
50 class OptimizeForTypeImpl<Layer, Wrapped> : public armnn::Optimization, public Wrapped
// Inherit Wrapped's constructors.
53 using Wrapped::Wrapped;
/// Optimization::Run override: unconditionally forwards every layer to Wrapped::Run.
55 void Run(Graph& graph, Layer& base) const override
57 Wrapped::Run(graph, base);
// Defaulted destructor. It is implicitly virtual because armnn::Optimization declares a virtual destructor.
61 ~OptimizeForTypeImpl() = default;
/// Concrete optimization that runs Wrapped::Run on every layer of type BaseType, or on every layer
/// when BaseType is Layer. Marked final, so it cannot be derived from. Beyond finality it adds nothing
/// visible here over OptimizeForTypeImpl.
64 template <typename BaseType, typename Wrapped>
65 class OptimizeForType final : public OptimizeForTypeImpl<BaseType, Wrapped>
// Inherit OptimizeForTypeImpl's constructors, which are in turn Wrapped's inherited constructors.
68 using OptimizeForTypeImpl<BaseType, Wrapped>::OptimizeForTypeImpl;
71 /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
72 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
73 /// after applying each optimization.
74 /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
75 /// - Children layers are removed if left unconnected after applying the wrapped optimization.
/// @tparam BaseType  Layer type whose output slots are scanned (the producer side of each connection).
/// @tparam ChildType Layer type that the consuming layer must have for Wrapped::Run to be called.
/// @tparam Wrapped   Provides a Run(Graph&, <connection>) callable from a const member function. It is
///                   invoked once per matching connection, with the dereferenced element of GetConnections().
/// This class itself does not derive from armnn::Optimization. It is plugged in as the Wrapped parameter
/// of OptimizeForTypeImpl (see OptimizeForConnection), which supplies the Optimization interface and the
/// BaseType filter.
76 template <typename BaseType, typename ChildType, typename Wrapped>
77 class OptimizeForConnectionImpl : public Wrapped
// Inherit Wrapped's constructors.
80 using Wrapped::Wrapped;
/// Called with a layer already known to be a BaseType. For each output slot it does two things:
/// first, it applies Wrapped::Run to every connection whose consuming layer is a ChildType;
/// second, it erases consuming layers whose own outputs are left unconnected.
82 void Run(Graph& graph, BaseType& base) const
84 for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
86 for (auto&& childInput : output->GetConnections())
// NOTE(review): the class contract allows Wrapped to add new connections to this output. If
// GetConnections() returns a reference to the live container (e.g. a std::vector), additions made
// while this range-for iterates could invalidate its iterators — confirm this is safe.
88 if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
90 Wrapped::Run(graph, *childInput);
94 // Removes unconnected children.
// Index-based loop with no increment in the header. Presumably erasing a child disconnects it from
// this output, shrinking GetNumConnections(), so the index should only advance past children that are kept.
// NOTE(review): that increment (an else-branch ++i) is not visible in this view. Without it, this loop
// never terminates when a child remains connected — confirm it is present.
// NOTE: this cleanup is not restricted to ChildType children; any consumer with unconnected outputs is erased.
95 for (unsigned int i = 0; i < output->GetNumConnections();)
97 Layer* child = &output->GetConnection(i)->GetOwningLayer();
99 if (child->IsOutputUnconnected())
101 graph.EraseLayer(child);
// Defaulted destructor.
112 ~OptimizeForConnectionImpl() = default;
/// Concrete optimization (final) with two steps for every BaseType layer. It runs Wrapped::Run on each
/// connection from that layer's outputs to a ChildType layer. It then erases consuming layers left
/// with unconnected outputs.
/// It is composed from two templates:
/// - OptimizeForTypeImpl supplies the Optimization interface and the BaseType filter/downcast.
/// - OptimizeForConnectionImpl supplies the per-connection traversal.
115 template <typename BaseType, typename ChildType, typename Wrapped>
116 class OptimizeForConnection final
117 : public OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>
// Inherit the base's constructors. Through the chain of inheriting-constructor using-declarations,
// these are ultimately Wrapped's constructors.
120 using OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl;