Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / primitive.hpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19
20 #include "cldnn_defs.h"
21 #include "compounds.h"
22 #include "layout.hpp"
23
24 #include <algorithm>
25 #include <string>
26 #include <vector>
27 #include <iostream>
28
29 namespace cldnn
30 {
31 /// @addtogroup cpp_api C++ API
32 /// @{
33
34 /// @addtogroup cpp_topology Network Topology
35 /// @{
36
37 /// @brief Globally unique primitive type id.
38 using primitive_type_id = cldnn_primitive_type_id;
39 /// @brief C API compatible unique @p id of a primitive within a topology.
40 using primitive_id_ref = cldnn_primitive_id;
41 /// @brief Unique @p id of a primitive within a topology.
42 using primitive_id = std::string;
43
44 /// @brief Dynamic cast to specified primitive description type.
45 template<class PType>
46 typename PType::dto* as_dto(CLDNN_PRIMITIVE_DESC(primitive)* dto)
47 {
48     if (dto->type != PType::type_id()) throw std::invalid_argument("type");
49     return reinterpret_cast<typename PType::dto*>(dto);
50 }
51
52 /// @brief Dynamic cast to specified primitive description type.
53 template<class PType>
54 const typename PType::dto* as_dto(const CLDNN_PRIMITIVE_DESC(primitive)* dto)
55 {
56     if (dto->type != PType::type_id()) throw std::invalid_argument("type");
57     return reinterpret_cast<const typename PType::dto*>(dto);
58 }
59
60 /// @brief Base class of network primitive description.
61 struct primitive
62 {
63     /// @brief Initialize fields common for all primitives.
64     struct fixed_size_vector_ref
65     {
66     private:
67         std::vector<primitive_id>& vref;
68
69     public:
70         fixed_size_vector_ref(std::vector<primitive_id>& ref) : vref(ref)
71         {}
72
73         auto size() const -> decltype(vref.size()) { return vref.size(); }
74         auto begin() const -> decltype(vref.begin()) { return vref.begin(); }
75         auto end() const -> decltype(vref.end()) { return vref.end(); }
76         auto cbegin() const -> decltype(vref.cbegin()) { return vref.cbegin(); }
77         auto cned() const -> decltype(vref.cend()) { return vref.cend(); }
78
79         primitive_id& operator[](size_t idx) { return vref[idx]; }
80         primitive_id const& operator[](size_t idx) const { return vref[idx]; }
81
82         primitive_id& at(size_t idx) { return vref.at(idx); }
83         primitive_id const& at(size_t idx) const { return vref.at(idx); }
84
85         primitive_id* data() { return vref.data(); }
86         const primitive_id* data() const { return vref.data(); }
87
88         const std::vector<primitive_id>& ref() const { return vref; }
89     };
90 public:
91     primitive(
92         const primitive_type_id& type,
93         const primitive_id& id,
94         const std::vector<primitive_id>& input,
95         const padding& output_padding = padding()
96     )
97         :type(type), id(id), input(_input.cpp_ids), output_padding(output_padding), _input(input)
98     {}
99
100     /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{primitive}
101     primitive(const CLDNN_PRIMITIVE_DESC(primitive)* dto)
102         :type(dto->type), id(dto->id), input(_input.cpp_ids), output_padding(dto->output_padding), _input(dto->input)
103     {}
104
105     virtual ~primitive() = default;
106
107     /// @brief Requested output padding.
108     /// @brief Requested output padding.
109     /// @brief Returns pointer to a C API primitive descriptor casted to @CLDNN_PRIMITIVE_DESC{primitive}.
110     virtual const CLDNN_PRIMITIVE_DESC(primitive)* get_dto() const = 0;
111
112     /// @brief Returns references to all primitive ids on which this primitive depends - inputs, weights, biases, etc.
113     std::vector<std::reference_wrapper<primitive_id>> dependencies()
114     {
115         std::vector<std::reference_wrapper<primitive_id>> result;
116         auto&& deps = get_dependencies();
117         
118         result.reserve(_input.size() + deps.size());
119         for (auto& pid : _input.cpp_ids)
120             result.push_back(std::ref(pid));
121         for (auto& pid : deps)
122             result.push_back(std::ref(const_cast<primitive_id&>(pid.get())));
123
124         return result;
125     }
126
127     /// @brief Returns copy of all primitive ids on which this primitive depends - inputs, weights, biases, etc.
128     std::vector<primitive_id> dependencies() const
129     {
130         auto result = input.ref();
131         auto deps = get_dependencies();
132         result.insert(result.end(), deps.begin(), deps.end());
133         return result;
134     }
135
136     /// @brief Implicit conversion to primiitive id.
137     operator primitive_id() const { return id; }
138
139     /// @brief Primitive's type id.
140     const primitive_type_id type;
141
142     /// @brief Primitive's id.
143     const primitive_id id;
144
145     /// @brief List of ids of input primitives.
146     fixed_size_vector_ref input;
147
148     /// @brief Requested output padding.
149     padding output_padding;
150
151 protected:
152     struct primitive_id_arr
153     {
154         primitive_id_arr(std::vector<primitive_id> const& vec) : cpp_ids(vec)
155         {}
156
157         primitive_id_arr(std::vector<primitive_id>&& vec) : cpp_ids(std::move(vec))
158         {}
159
160         //create from C API id array
161         primitive_id_arr(cldnn_primitive_id_arr c_id_arr)
162         {
163             cpp_ids.resize(c_id_arr.size);
164             for (size_t i = 0; i < c_id_arr.size; ++i)
165                 cpp_ids[i] = c_id_arr.data[i];
166         }
167
168         std::vector<primitive_id> cpp_ids;
169         mutable std::vector<cldnn_primitive_id> c_ids;
170         //get C API id array
171         auto ref() const -> decltype(cldnn_primitive_id_arr{c_ids.data(), c_ids.size()})
172         {
173             c_ids.resize(cpp_ids.size());
174             for (size_t i = 0; i < cpp_ids.size(); ++i)
175                 c_ids[i] = cpp_ids[i].c_str();
176
177             return cldnn_primitive_id_arr{ c_ids.data(), c_ids.size() };
178         }
179
180         size_t size() const { return cpp_ids.size(); }
181     };
182
183     primitive_id_arr _input;
184
185     virtual std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const { return{}; }
186 };
187
188 /// @brief base class for all primitives implementations.
189 template<class PType, class DTO>
190 class primitive_base : public primitive
191 {
192 public:
193     /// @brief Returns pointer to a C API primitive descriptor casted to @CLDNN_PRIMITIVE_DESC{primitive}.
194     const CLDNN_PRIMITIVE_DESC(primitive)* get_dto() const override
195     {
196         //update common dto fields
197         _dto.id = id.c_str();
198         _dto.type = type;
199         _dto.input = _input.ref();
200         _dto.output_padding = output_padding;
201
202         //call abstract method to update primitive-specific fields
203         update_dto(_dto);
204         return reinterpret_cast<const CLDNN_PRIMITIVE_DESC(primitive)*>(&_dto);
205     }
206
207 protected:
208     explicit primitive_base(
209         const primitive_id& id,
210         const std::vector<primitive_id>& input,
211         const padding& output_padding = padding())
212         : primitive(PType::type_id(), id, input, output_padding)
213     {}
214
215     primitive_base(const DTO* dto)
216         : primitive(reinterpret_cast<const CLDNN_PRIMITIVE_DESC(primitive)*>(dto))
217     {
218         if (dto->type != PType::type_id()) 
219             throw std::invalid_argument("DTO type mismatch");
220     }
221
222 private:
223     mutable DTO _dto;
224
225     virtual void update_dto(DTO& dto) const = 0;
226 };
227
228 #define CLDNN_DEFINE_TYPE_ID(PType) static primitive_type_id type_id()\
229     {\
230         return check_status<primitive_type_id>( #PType " type id failed", [](status_t* status)\
231         {\
232             return cldnn_##PType##_type_id(status);\
233         });\
234     }
235
236 #define CLDNN_DECLARE_PRIMITIVE(PType) typedef CLDNN_PRIMITIVE_DESC(PType) dto;\
237     CLDNN_DEFINE_TYPE_ID(PType)
238 /// @}
239 /// @}
240 }