2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
7 #include <armnn/DescriptorsFwd.hpp>
8 #include <armnn/TensorFwd.hpp>
9 #include <armnn/Types.hpp>
11 #include <armnn/INetwork.hpp>
23 /// Private implementation of INetwork
/// Declared `final`, so no further derivation is possible. The graph is held through m_Graph.
/// Each Add*Layer declaration marked `override` below returns an IConnectableLayer* and takes a
/// trailing `name` that defaults to nullptr. The comments describe only what the signatures show.
/// NOTE(review): ownership and lifetime of the returned IConnectableLayer* are not visible in
/// these declarations - presumably the layer remains owned by the graph; confirm against the definitions.
24 class Network final : public INetwork
/// Read-only access to the graph. Dereferences m_Graph without a null check.
30 const Graph& GetGraph() const { return *m_Graph; }
/// Overrides the base-class PrintGraph and reports a Status.
/// Presumably prints the held graph - confirm against the definition.
32 Status PrintGraph() override;
/// Input layer; `id` is the LayerBindingId associated with it.
34 IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override;
/// Convolution2d overload taking a descriptor and weights only (no bias tensor parameter).
36 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
37 const ConstTensor& weights,
38 const char* name = nullptr) override;
/// Convolution2d overload taking a descriptor, weights and biases.
40 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
41 const ConstTensor& weights,
42 const ConstTensor& biases,
43 const char* name = nullptr) override;
/// DepthwiseConvolution2d overload taking a descriptor and weights only (no bias tensor parameter).
45 IConnectableLayer* AddDepthwiseConvolution2dLayer(
46 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
47 const ConstTensor& weights,
48 const char* name = nullptr) override;
/// DepthwiseConvolution2d overload taking a descriptor, weights and biases.
50 IConnectableLayer* AddDepthwiseConvolution2dLayer(
51 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
52 const ConstTensor& weights,
53 const ConstTensor& biases,
54 const char* name = nullptr) override;
/// FullyConnected overload taking a descriptor and weights only (no bias tensor parameter).
56 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
57 const ConstTensor& weights,
58 const char* name = nullptr) override;
/// FullyConnected overload taking a descriptor, weights and biases.
60 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
61 const ConstTensor& weights,
62 const ConstTensor& biases,
63 const char* name = nullptr) override;
/// Permute layer configured by `permuteDescriptor`.
65 IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
66 const char* name = nullptr) override;
/// Pooling2d layer configured by `pooling2dDescriptor`.
68 IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
69 const char* name = nullptr) override;
/// Activation layer configured by `activationDescriptor`.
71 IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
72 const char* name = nullptr) override;
/// Normalization layer configured by `normalizationDescriptor`.
74 IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
75 const char* name = nullptr) override;
/// Softmax layer configured by `softmaxDescriptor`.
77 IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
78 const char* name = nullptr) override;
/// Splitter layer; note the descriptor type is ViewsDescriptor.
80 IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
81 const char* name = nullptr) override;
/// Merger layer; note the descriptor type is OriginsDescriptor.
83 IConnectableLayer* AddMergerLayer(const OriginsDescriptor& mergerDescriptor,
84 const char* name = nullptr) override;
/// Addition layer; takes no descriptor.
86 IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
/// Multiplication layer; takes no descriptor.
88 IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
/// Batch normalization layer; besides the descriptor it takes four constant tensors,
/// in the order mean, variance, beta, gamma.
90 IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
91 const ConstTensor& mean,
92 const ConstTensor& variance,
93 const ConstTensor& beta,
94 const ConstTensor& gamma,
95 const char* name = nullptr) override;
/// ResizeBilinear layer configured by `resizeDesc`.
97 IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
98 const char* name = nullptr) override;
/// L2Normalization layer; takes no descriptor.
100 IConnectableLayer* AddL2NormalizationLayer(const char* name = nullptr) override;
/// Constant layer; the constant tensor is passed as `input`.
102 IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
/// Reshape layer configured by `reshapeDescriptor`.
104 IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
105 const char* name = nullptr) override;
/// Floor layer; takes no descriptor.
107 IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
/// Output layer; `id` is the LayerBindingId associated with it.
109 IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
/// FullyConnected helper that takes `biases` by pointer instead of by reference.
/// Presumably shared by the two AddFullyConnectedLayer overloads, with nullptr meaning
/// "no bias" - confirm against the definition.
112 IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
113 const ConstTensor& weights,
114 const ConstTensor* biases,
/// Convolution2d helper that takes `biases` by pointer instead of by reference.
/// Presumably shared by the two AddConvolution2dLayer overloads, with nullptr meaning
/// "no bias" - confirm against the definition.
117 IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
118 const ConstTensor& weights,
119 const ConstTensor* biases,
/// DepthwiseConvolution2d helper that takes `biases` by pointer instead of by reference.
/// Presumably shared by the two AddDepthwiseConvolution2dLayer overloads, with nullptr meaning
/// "no bias" - confirm against the definition.
122 IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
123 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
124 const ConstTensor& weights,
125 const ConstTensor* biases,
/// Owning pointer to the Graph returned by GetGraph().
128 std::unique_ptr<Graph> m_Graph;
/// Implementation of IOptimizedNetwork that holds a Graph through m_Graph.
/// Declared `final`, so no further derivation is possible.
131 class OptimizedNetwork final : public IOptimizedNetwork
/// Takes the std::unique_ptr<Graph> by value, so the caller hands over ownership of the graph.
/// Presumably the pointer is stored in m_Graph - confirm against the definition.
/// NOTE(review): this single-argument constructor is not marked `explicit`; confirm that
/// implicit conversion from std::unique_ptr<Graph> is intended.
134 OptimizedNetwork(std::unique_ptr<Graph> graph);
/// Overrides the base-class PrintGraph and reports a Status.
/// Presumably prints the held graph - confirm against the definition.
137 Status PrintGraph() override;
/// Mutable access to the graph. Dereferences m_Graph without a null check.
139 Graph& GetGraph() { return *m_Graph; }
/// Owning pointer to the Graph returned by GetGraph().
142 std::unique_ptr<Graph> m_Graph;