2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
7 #include "Optimization.hpp"
11 namespace optimizations
14 class OptimizeConsecutiveReshapesImpl
17 /// Run for every connection between a base RashapeLayer and a child ReshapeLayer.
18 /// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
19 void Run(Graph& graph, InputSlot& connection) const
21 auto& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
22 auto& child = connection.GetOwningLayer();
24 BOOST_ASSERT(base.GetType() == LayerType::Reshape);
25 BOOST_ASSERT(child.GetType() == LayerType::Reshape);
27 OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
29 const TensorInfo& inInfo = parentOut->GetTensorInfo();
30 const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
32 if (inInfo.GetShape() != outInfo.GetShape())
34 // Insert equivalent reshape before base layer
35 const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
36 const ReshapeDescriptor descriptor{outInfo.GetShape()};
37 auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
38 // Set tensor info for new layer
39 newReshape.GetOutputHandler().SetTensorInfo(outInfo);
40 // Reconnect base with original parent
41 newReshape.GetOutputSlot().MoveAllConnections(*parentOut);
42 // Parent is now the new layer
43 parentOut = &newReshape.GetOutputSlot();
46 // Move connections in child output to parent layer.
47 // Child layer will be removed as it's left unconnected.
48 // Base layer will be removed if left unconnected.
49 child.GetOutputSlot().MoveAllConnections(*parentOut);
53 OptimizeConsecutiveReshapesImpl() = default;
54 ~OptimizeConsecutiveReshapesImpl() = default;
57 using OptimizeConsecutiveReshapes = OptimizeForConnection<ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl>;
59 } // namespace optimizations