Release 18.02
[platform/upstream/armnn.git] / src / armnn / optimizations / OptimizeConsecutiveReshapes.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "Optimization.hpp"
8
9 namespace armnn
10 {
11 namespace optimizations
12 {
13
14 class OptimizeConsecutiveReshapesImpl
15 {
16 public:
17     /// Run for every connection between a base RashapeLayer and a child ReshapeLayer.
18     /// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
19     void Run(Graph& graph, InputSlot& connection) const
20     {
21         auto& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
22         auto& child = connection.GetOwningLayer();
23
24         BOOST_ASSERT(base.GetType() == LayerType::Reshape);
25         BOOST_ASSERT(child.GetType() == LayerType::Reshape);
26
27         OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
28
29         const TensorInfo& inInfo = parentOut->GetTensorInfo();
30         const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
31
32         if (inInfo.GetShape() != outInfo.GetShape())
33         {
34             // Insert equivalent reshape before base layer
35             const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
36             const ReshapeDescriptor descriptor{outInfo.GetShape()};
37             auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
38             // Set tensor info for new layer
39             newReshape.GetOutputHandler().SetTensorInfo(outInfo);
40             // Reconnect base with original parent
41             newReshape.GetOutputSlot().MoveAllConnections(*parentOut);
42             // Parent is now the new layer
43             parentOut = &newReshape.GetOutputSlot();
44         }
45
46         // Move connections in child output to parent layer.
47         // Child layer will be removed as it's left unconnected.
48         // Base layer will be removed if left unconnected.
49         child.GetOutputSlot().MoveAllConnections(*parentOut);
50     }
51
52 protected:
53     OptimizeConsecutiveReshapesImpl() = default;
54     ~OptimizeConsecutiveReshapesImpl() = default;
55 };
56
57 using OptimizeConsecutiveReshapes = OptimizeForConnection<ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl>;
58
59 } // namespace optimizations
60 } // namespace armnn