Release 18.08
[platform/upstream/armnn.git] / src / armnn / Network.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include <armnn/DescriptorsFwd.hpp>
8 #include <armnn/LstmParams.hpp>
9 #include <armnn/TensorFwd.hpp>
10 #include <armnn/Types.hpp>
11
12 #include <armnn/INetwork.hpp>
13
14 #include <string>
15 #include <vector>
16 #include <memory>
17
18 #include "Layer.hpp"
19
20 namespace armnn
21 {
22 class Graph;
23
24 /// Private implementation of INetwork.
25 class Network final : public INetwork
26 {
27 public:
28     Network();
29     ~Network();
30
31     const Graph& GetGraph() const { return *m_Graph; }
32
33     Status PrintGraph() override;
34
35     IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override;
36
37     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
38         const ConstTensor& weights,
39         const char* name = nullptr) override;
40
41     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
42         const ConstTensor& weights,
43         const ConstTensor& biases,
44         const char* name = nullptr) override;
45
46     IConnectableLayer* AddDepthwiseConvolution2dLayer(
47         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
48         const ConstTensor&                      weights,
49         const char*                             name = nullptr) override;
50
51     IConnectableLayer* AddDepthwiseConvolution2dLayer(
52         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
53         const ConstTensor&                      weights,
54         const ConstTensor&                      biases,
55         const char*                             name = nullptr) override;
56
57     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
58         const ConstTensor& weights,
59         const char* name = nullptr) override;
60
61     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
62         const ConstTensor& weights,
63         const ConstTensor& biases,
64         const char* name = nullptr) override;
65
66     IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
67                                        const char* name = nullptr) override;
68
69     IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
70         const char* name = nullptr) override;
71
72     IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
73         const char* name = nullptr) override;
74
75     IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
76         const char* name = nullptr) override;
77
78     IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
79         const char* name = nullptr) override;
80
81     IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
82         const char* name = nullptr) override;
83
84     IConnectableLayer* AddMergerLayer(const OriginsDescriptor& mergerDescriptor,
85         const char* name = nullptr) override;
86
87     IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
88
89     IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
90
91     IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
92                                                   const ConstTensor&                  mean,
93                                                   const ConstTensor&                  variance,
94                                                   const ConstTensor&                  beta,
95                                                   const ConstTensor&                  gamma,
96                                                   const char*                         name = nullptr) override;
97
98     IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
99                                               const char* name = nullptr) override;
100
101     IConnectableLayer* AddL2NormalizationLayer(const char* name = nullptr) override;
102
103     IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
104
105     IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
106                                        const char* name = nullptr) override;
107
108     IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
109
110     IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
111
112     IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
113                                     const LstmInputParams& params,
114                                     const char* name = nullptr) override;
115
116 private:
117     IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
118         const ConstTensor& weights,
119         const ConstTensor* biases,
120         const char* name);
121
122     IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
123         const ConstTensor& weights,
124         const ConstTensor* biases,
125         const char* name);
126
127     IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
128         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
129         const ConstTensor& weights,
130         const ConstTensor* biases,
131         const char* name);
132
133     std::unique_ptr<Graph> m_Graph;
134 };
135
136 class OptimizedNetwork final : public IOptimizedNetwork
137 {
138 public:
139     OptimizedNetwork(std::unique_ptr<Graph> graph);
140     ~OptimizedNetwork();
141
142     Status PrintGraph() override;
143     Status SerializeToDot(std::ostream& stream) const override;
144
145     Graph& GetGraph() { return *m_Graph; }
146
147 private:
148     std::unique_ptr<Graph> m_Graph;
149 };
150
151 } // namespace armnn