2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
7 #include <armnn/DescriptorsFwd.hpp>
8 #include <armnn/LstmParams.hpp>
9 #include <armnn/TensorFwd.hpp>
10 #include <armnn/Types.hpp>
12 #include <armnn/INetwork.hpp>
24 /// Private implementation of INetwork.
25 class Network final : public INetwork
/// Read-only access to the Graph owned by this Network (held in m_Graph).
/// m_Graph is dereferenced without a null check, so it must be non-null when this is called.
31 const Graph& GetGraph() const { return *m_Graph; }
/// Override of INetwork::PrintGraph; returns a Status.
/// NOTE(review): presumably writes a textual dump of m_Graph — confirm output target in Network.cpp.
33 Status PrintGraph() override;
/// Override of INetwork::AddInputLayer.
/// @param id   LayerBindingId for this input (presumably the id used to bind input data at runtime — see INetwork).
/// @param name Optional layer name; defaults to nullptr.
/// @return The IConnectableLayer for the added input layer.
35 IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name = nullptr) override;
/// Overrides of INetwork::AddConvolution2dLayer.
/// The first overload takes weights only; the second additionally takes a biases tensor.
/// Both take an optional layer name (defaults to nullptr) and return an IConnectableLayer*
/// (presumably the newly added layer — confirm in Network.cpp).
37 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
38 const ConstTensor& weights,
39 const char* name = nullptr) override;

41 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
42 const ConstTensor& weights,
43 const ConstTensor& biases,
44 const char* name = nullptr) override;
/// Overrides of INetwork::AddDepthwiseConvolution2dLayer: without and with a biases tensor.
/// NOTE(review): the descriptor parameter is named `convolution2dDescriptor` even though its type is
/// DepthwiseConvolution2dDescriptor. Rename it together with the definition in Network.cpp (presumably
/// the same name) so declaration and definition parameter names stay consistent.
46 IConnectableLayer* AddDepthwiseConvolution2dLayer(
47 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
48 const ConstTensor& weights,
49 const char* name = nullptr) override;

51 IConnectableLayer* AddDepthwiseConvolution2dLayer(
52 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
53 const ConstTensor& weights,
54 const ConstTensor& biases,
55 const char* name = nullptr) override;
/// Overrides of INetwork::AddFullyConnectedLayer.
/// The first overload takes weights only; the second additionally takes a biases tensor.
/// Both accept an optional layer name (defaults to nullptr).
57 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
58 const ConstTensor& weights,
59 const char* name = nullptr) override;

61 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
62 const ConstTensor& weights,
63 const ConstTensor& biases,
64 const char* name = nullptr) override;
/// Overrides of INetwork layer adders that are configured purely by a descriptor
/// (permute, pooling, activation, normalization, softmax, splitter, merger).
/// Each takes its descriptor by const reference plus an optional layer name (defaults to nullptr).
66 IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
67 const char* name = nullptr) override;

69 IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
70 const char* name = nullptr) override;

72 IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
73 const char* name = nullptr) override;

75 IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
76 const char* name = nullptr) override;

78 IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
79 const char* name = nullptr) override;

/// Splitter is configured by a ViewsDescriptor; merger by an OriginsDescriptor.
81 IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
82 const char* name = nullptr) override;

84 IConnectableLayer* AddMergerLayer(const OriginsDescriptor& mergerDescriptor,
85 const char* name = nullptr) override;
/// Overrides of INetwork::AddAdditionLayer / AddMultiplicationLayer.
/// These layers take no descriptor, only an optional name (defaults to nullptr).
87 IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;

89 IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
/// Override of INetwork::AddBatchNormalizationLayer.
/// Takes the descriptor plus four constant tensors in the order mean, variance, beta, gamma,
/// and an optional layer name (defaults to nullptr).
/// NOTE(review): the descriptor parameter is named `desc`, unlike the `<layer>Descriptor` naming
/// used by the other Add*Layer declarations; rename together with the definition if desired.
91 IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
92 const ConstTensor& mean,
93 const ConstTensor& variance,
94 const ConstTensor& beta,
95 const ConstTensor& gamma,
96 const char* name = nullptr) override;
/// Override of INetwork::AddResizeBilinearLayer, configured by a ResizeBilinearDescriptor.
98 IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
99 const char* name = nullptr) override;

/// Override of INetwork::AddL2NormalizationLayer; takes no descriptor in this interface.
101 IConnectableLayer* AddL2NormalizationLayer(const char* name = nullptr) override;

/// Override of INetwork::AddConstantLayer; `input` is the constant tensor the layer holds
/// (presumably copied into the layer — confirm ownership in Network.cpp).
103 IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;

/// Override of INetwork::AddReshapeLayer, configured by a ReshapeDescriptor.
105 IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
106 const char* name = nullptr) override;

/// Override of INetwork::AddFloorLayer; takes no descriptor.
108 IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;

/// Override of INetwork::AddOutputLayer; `id` is the LayerBindingId for this output
/// (counterpart of AddInputLayer).
110 IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
/// Override of INetwork::AddLstmLayer.
/// @param descriptor LstmDescriptor configuring the layer.
/// @param params     LstmInputParams (declared in armnn/LstmParams.hpp) carrying the layer's tensors.
/// @param name       Optional layer name; defaults to nullptr.
112 IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
113 const LstmInputParams& params,
114 const char* name = nullptr) override;
/// Shared implementation helpers for the public overload pairs above.
/// Unlike the public overloads, biases is taken as `const ConstTensor*`;
/// NOTE(review): presumably nullptr means "no bias", letting the weights-only overloads
/// forward here — confirm in Network.cpp.
117 IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
118 const ConstTensor& weights,
119 const ConstTensor* biases,
122 IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
123 const ConstTensor& weights,
124 const ConstTensor* biases,
127 IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
128 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
129 const ConstTensor& weights,
130 const ConstTensor* biases,
/// The Graph exclusively owned by this Network.
/// GetGraph() dereferences it without a null check, so callers rely on it being non-null.
133 std::unique_ptr<Graph> m_Graph;
136 class OptimizedNetwork final : public IOptimizedNetwork
/// Constructs an OptimizedNetwork that takes ownership of `graph`.
/// Marked explicit so a std::unique_ptr<Graph> is never silently converted into an
/// OptimizedNetwork. Direct-initialization (e.g. `new OptimizedNetwork(std::move(g))`) is unaffected.
139 explicit OptimizedNetwork(std::unique_ptr<Graph> graph);
/// Overrides of IOptimizedNetwork. PrintGraph returns a Status.
/// SerializeToDot writes to the caller-supplied stream (presumably a Graphviz DOT rendering of
/// m_Graph — confirm in Network.cpp); it is const, so it does not modify this object.
142 Status PrintGraph() override;
143 Status SerializeToDot(std::ostream& stream) const override;
145 Graph& GetGraph() { return *m_Graph; }
/// The optimized Graph exclusively owned by this object (transferred in via the constructor).
/// GetGraph() dereferences it without a null check.
148 std::unique_ptr<Graph> m_Graph;