Release 18.03
[platform/upstream/armnn.git] / src / armnn / optimizations / SquashEqualSiblings.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "Optimization.hpp"
8
9 namespace armnn
10 {
11 namespace optimizations
12 {
13
14 template <typename Comparable>
15 class SquashEqualSiblingsImpl
16 {
17 public:
18     /// Run for every connection between a base Layer (any) and a child ComparableLayer.
19     /// For all siblings of the child layer that compare equal to it, bypasses and removes
20     /// them. I.e., moves the connections in the outputs of the siblings to the outputs of
21     /// the child layer, so the siblings are left unconnected (and later removed).
22     void Run(Graph& graph, InputSlot& connection) const
23     {
24         auto& child = connection.GetOwningLayer();
25
26         if (!child.IsOutputUnconnected())
27         {
28             OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
29
30             if (baseOutput.GetNumConnections() > 1)
31             {
32                 auto& comparableChild = *boost::polymorphic_downcast<Comparable*>(&child);
33
34                 Layer* lowestPriorityChild = &child;
35                 for (auto&& it : baseOutput.GetConnections())
36                 {
37                     Layer* sibling = &it->GetOwningLayer();
38                     if ((sibling != lowestPriorityChild) && comparableChild.IsEqual(*sibling))
39                     {
40                         if (sibling->GetPriority() < lowestPriorityChild->GetPriority())
41                         {
42                             std::swap(sibling, lowestPriorityChild);
43                         }
44                         // Bypass sibling. It will be removed as it's left unconnected.
45                         auto siblingOut = sibling->BeginOutputSlots();
46                         for (auto lowestPriorityChildOut = lowestPriorityChild->BeginOutputSlots();
47                              lowestPriorityChildOut != lowestPriorityChild->EndOutputSlots(); ++lowestPriorityChildOut)
48                         {
49                             siblingOut->MoveAllConnections(*lowestPriorityChildOut);
50                             ++siblingOut;
51                         }
52                     }
53                 }
54             }
55         }
56     }
57
58 protected:
59     SquashEqualSiblingsImpl() = default;
60     ~SquashEqualSiblingsImpl() = default;
61 };
62
63 using SquashEqualPermuteSiblings = OptimizeForConnection<Layer, PermuteLayer, SquashEqualSiblingsImpl<PermuteLayer>>;
64 using SquashEqualReshapeSiblings = OptimizeForConnection<Layer, ReshapeLayer, SquashEqualSiblingsImpl<ReshapeLayer>>;
65
66 } // namespace optimizations
67 } // namespace armnn