Release 18.08
[platform/upstream/armnn.git] / src / armnn / optimizations / OptimizeInverseConversions.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "Optimization.hpp"
8
9 namespace armnn
10 {
11 namespace optimizations
12 {
13
14 class OptimizeInverseConversionsImpl
15 {
16 public:
17     /// Run for every connection between two inverse data type conversion layers, i.e.
18     /// Fp16ToFp32 followed by Fp32ToFp16 or vice-versa.
19     void Run(Graph& graph, InputSlot& connection) const
20     {
21         Layer& base  = connection.GetConnectedOutputSlot()->GetOwningLayer();
22         Layer& child = connection.GetOwningLayer();
23
24         BOOST_ASSERT((base.GetType() == LayerType::ConvertFp16ToFp32 &&
25                      child.GetType() == LayerType::ConvertFp32ToFp16) ||
26                      (base.GetType() == LayerType::ConvertFp32ToFp16 &&
27                      child.GetType() == LayerType::ConvertFp16ToFp32));
28
29         // Bypass both conversion layers
30         child.GetOutputSlot().MoveAllConnections(*base.GetInputSlot(0).GetConnectedOutputSlot());
31     }
32
33 protected:
34     OptimizeInverseConversionsImpl()  = default;
35     ~OptimizeInverseConversionsImpl() = default;
36 };
37
38 using OptimizeInverseConversionsFp16 =
39     OptimizeForConnection<ConvertFp16ToFp32Layer, ConvertFp32ToFp16Layer, OptimizeInverseConversionsImpl>;
40 using OptimizeInverseConversionsFp32 =
41     OptimizeForConnection<ConvertFp32ToFp16Layer, ConvertFp16ToFp32Layer, OptimizeInverseConversionsImpl>;
42
43 } // namespace optimizations
44 } // namespace armnn