Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / memory.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <stddef.h>
19 #include <stdint.h>
20
21 #include "mkldnn.h"
22
23 #include "c_types_map.hpp"
24 #include "engine.hpp"
25 #include "memory_pd.hpp"
26 #include "type_helpers.hpp"
27 #include "utils.hpp"
28
29 using namespace mkldnn::impl;
30 using namespace mkldnn::impl::utils;
31 using namespace mkldnn::impl::status;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::data_type;
34
35 namespace {
36 bool memory_desc_sanity_check(int ndims,const dims_t dims,
37         data_type_t data_type, memory_format_t format) {
38     if (ndims == 0) return true;
39
40     bool ok = true
41         && dims != nullptr
42         && 0 < ndims && ndims <= TENSOR_MAX_DIMS
43         && one_of(data_type, f32, s32, s16, s8, u8, bin)
44         && format != memory_format::undef;
45     if (!ok) return false;
46     for (int d = 0; d < ndims; ++d)
47         if (dims[d] < 0) return false;
48
49     return true;
50 }
51
52 bool memory_desc_sanity_check(const memory_desc_t *md) {
53     if (md == nullptr) return false;
54     return memory_desc_sanity_check(md->ndims, md->dims, md->data_type,
55             md->format);
56 }
57 }
58
59 status_t mkldnn_memory_desc_init(memory_desc_t *memory_desc, int ndims,
60         const dims_t dims, data_type_t data_type, memory_format_t format) {
61     if (any_null(memory_desc)) return invalid_arguments;
62     if (ndims == 0 || format == memory_format::undef) {
63         *memory_desc = types::zero_md();
64         return success;
65     }
66
67     /* memory_desc != 0 */
68     bool args_ok = !any_null(memory_desc)
69         && memory_desc_sanity_check(ndims, dims, data_type, format);
70     if (!args_ok) return invalid_arguments;
71
72     memory_desc_t md;
73     md.ndims = ndims;
74     array_copy(md.dims, dims, ndims);
75     md.primitive_kind = primitive_kind::memory;
76     md.data_type = data_type;
77     md.format = format;
78
79     status_t status = success;
80     if (one_of(format, memory_format::undef, blocked, wino_fmt, rnn_packed)) {
81         status = invalid_arguments;
82     } else if (format == any) {
83         // nop
84     } else if (types::format_normalize(format) == blocked) {
85         status = memory_desc_wrapper::compute_blocking(md);
86     } else {
87         assert(!"unreachable");
88         status = invalid_arguments;
89     }
90
91     if (status == success)
92         *memory_desc = md;
93
94     return status;
95 }
96
97 status_t mkldnn_memory_primitive_desc_create(primitive_desc_t **memory_pd,
98         const memory_desc_t *memory_desc, engine_t *engine) {
99     bool args_ok = !any_null(memory_pd, memory_desc, engine)
100         && memory_desc_sanity_check(memory_desc)
101         && memory_desc_wrapper(*memory_desc).is_defined();
102     if (!args_ok) return invalid_arguments;
103     return engine->memory_primitive_desc_create(
104             (memory_pd_t**)memory_pd, memory_desc);
105 }
106
107 status_t mkldnn_view_primitive_desc_create(primitive_desc_t **view_pd,
108         const primitive_desc_t *memory_pd, const dims_t dims,
109         const dims_t offsets) {
110     const memory_pd_t *mpd =
111         (const memory_pd_t*)memory_pd;
112
113     bool args_ok = !any_null(view_pd, memory_pd, dims, offsets)
114         && memory_pd->kind() == primitive_kind::memory
115         && memory_desc_sanity_check(mpd->desc());
116     if (!args_ok) return invalid_arguments;
117
118     memory_desc_wrapper md(*mpd->desc());
119     for (int d = 0; d < md.ndims(); ++d) {
120         if (dims[d] < 0 || offsets[d] < 0
121                 || (offsets[d] + dims[d] > md.dims()[d]))
122             return invalid_arguments;
123     }
124     return memory_pd->engine()->view_primitive_desc_create(
125             (view_pd_t**)view_pd, mpd, dims, offsets);
126 }
127
128 int mkldnn_memory_primitive_desc_equal(const primitive_desc_t *lhs,
129         const primitive_desc_t *rhs) {
130     bool args_ok = !any_null(lhs, rhs)
131         && lhs->engine() == rhs->engine()
132         && one_of(lhs->kind(), primitive_kind::memory, primitive_kind::view)
133         && one_of(rhs->kind(), primitive_kind::memory, primitive_kind::view);
134     if (!args_ok) return 0;
135     auto l = (const memory_pd_t *)lhs;
136     auto r = (const memory_pd_t *)rhs;
137     /* FIXME: view! */
138     return l->is_equal(r);
139 }
140
141 size_t mkldnn_memory_primitive_desc_get_size(const primitive_desc_t *memory_pd)
142 {
143     bool args_ok = !any_null(memory_pd)
144         && memory_pd->kind() == primitive_kind::memory;
145     if (!args_ok) return 0;
146     /* FIXME: view? */
147     return ((memory_pd_t*)memory_pd)->get_size();
148 }
149
150 status_t mkldnn_memory_get_data_handle(const primitive_t *memory,
151         void **handle) {
152     if (any_null(handle))
153         return invalid_arguments;
154     if (memory == nullptr) {
155         *handle = nullptr;
156         return success;
157     }
158     if (memory->kind() != primitive_kind::memory)
159         return invalid_arguments;
160     return memory->get_data_handle(handle);
161 }
162
163 status_t mkldnn_memory_set_data_handle(primitive_t *memory, void *handle) {
164     if (any_null(memory) || memory->kind() != primitive_kind::memory)
165         return invalid_arguments;
166     return memory->set_data_handle(handle);
167 }
168
169 status_t mkldnn_concat_primitive_desc_create_v2(primitive_desc_t **concat_pd,
170         const memory_desc_t *output_d, int n, int concat_dim,
171         const primitive_desc_t **input_pds, const primitive_attr_t *attr) {
172     bool args_ok = !any_null(concat_pd, input_pds) && n > 0;
173     if (!args_ok) return invalid_arguments;
174     for (int i = 0; i < n; ++i) {
175         if (input_pds[i] == nullptr ||
176                 input_pds[i]->kind() != primitive_kind::memory)
177             return invalid_arguments;
178     }
179
180     const primitive_attr_t dummy_attr;
181     if (attr == NULL)
182         attr = &dummy_attr;
183
184     auto i_mpds = (const memory_pd_t **)input_pds;
185     engine_t *engine = i_mpds[0]->engine();
186     const int ndims = i_mpds[0]->desc()->ndims;
187     const dims_t &dims = i_mpds[0]->desc()->dims;
188     const data_type_t dt = i_mpds[0]->desc()->data_type;
189
190     int concat_dim_sz = dims[concat_dim];
191     for (int i = 1; i < n; ++i) {
192         if (i_mpds[i]->engine() != engine) return invalid_arguments;
193         if (i_mpds[i]->desc()->ndims != ndims) return invalid_arguments;
194         for (int d = 0; d < ndims; ++d) {
195             if (d == concat_dim) continue;
196             if (i_mpds[i]->desc()->dims[d] != dims[d])
197                 return invalid_arguments;
198         }
199         if (i_mpds[i]->desc()->data_type != dt) return invalid_arguments;
200         concat_dim_sz += i_mpds[i]->desc()->dims[concat_dim];
201     }
202
203     memory_desc_t dummy_output_d;
204     if (output_d) {
205         if (output_d->ndims != ndims) return invalid_arguments;
206         for (int d = 0; d < ndims; ++d) {
207             if (output_d->dims[d] !=
208                     (d == concat_dim ? concat_dim_sz : dims[d]))
209                 return invalid_arguments;
210         }
211     } else {
212         dummy_output_d = *i_mpds[0]->desc();
213         dummy_output_d.dims[concat_dim] = concat_dim_sz;
214         dummy_output_d.format = memory_format::any;
215         output_d = &dummy_output_d;
216     }
217
218     auto c_pd = reinterpret_cast<concat_pd_t **>(concat_pd);
219
220     for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
221         if ((*c)(c_pd, output_d, n, concat_dim, i_mpds, attr) == success) {
222             (*c_pd)->init_info();
223             return success;
224         }
225     }
226     return unimplemented;
227 }
228
229 status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd,
230         const memory_desc_t *output_d, int n, int concat_dim,
231         const primitive_desc_t **input_pds) {
232     return mkldnn_concat_primitive_desc_create_v2(concat_pd, output_d, n,
233             concat_dim, input_pds, nullptr);
234 }
235
236 status_t mkldnn_sum_primitive_desc_create_v2(primitive_desc_t **sum_pd,
237         const memory_desc_t *output_d, int n, const float *scales,
238         const primitive_desc_t **input_pds, const primitive_attr_t *attr) {
239     bool args_ok = !any_null(sum_pd, input_pds, scales) && n > 0;
240     if (!args_ok) return invalid_arguments;
241     for (int i = 0; i < n; ++i) {
242         if (input_pds[i] == nullptr ||
243                 input_pds[i]->kind() != primitive_kind::memory)
244             return invalid_arguments;
245     }
246
247     const primitive_attr_t dummy_attr;
248     if (attr == NULL)
249         attr = &dummy_attr;
250
251     auto i_mpds = (const memory_pd_t **)input_pds;
252     engine_t *engine = i_mpds[0]->engine();
253     const int ndims = i_mpds[0]->desc()->ndims;
254     const dims_t &dims = i_mpds[0]->desc()->dims;
255     const data_type_t dt = i_mpds[0]->desc()->data_type;
256
257     for (int i = 1; i < n; ++i) {
258         if (i_mpds[i]->engine() != engine) return invalid_arguments;
259         if (i_mpds[i]->desc()->ndims != ndims) return invalid_arguments;
260         for (int d = 0; d < ndims; ++d) {
261             if (i_mpds[i]->desc()->dims[d] != dims[d])
262                 return invalid_arguments;
263         }
264         if (i_mpds[i]->desc()->data_type != dt) return invalid_arguments;
265     }
266
267     memory_desc_t dummy_output_d;
268     if (output_d) {
269         if (output_d->ndims != ndims) return invalid_arguments;
270         for (int d = 0; d < ndims; ++d) {
271             if (output_d->dims[d] != dims[d])
272                 return invalid_arguments;
273         }
274     } else {
275         dummy_output_d = *i_mpds[0]->desc();
276         dummy_output_d.format = memory_format::any;
277         output_d = &dummy_output_d;
278     }
279
280     auto s_pd = reinterpret_cast<sum_pd_t **>(sum_pd);
281
282     for (auto s = engine->get_sum_implementation_list(); *s; ++s) {
283         if ((*s)(s_pd, output_d, n, scales, i_mpds, attr) == success) {
284             (*s_pd)->init_info();
285             return success;
286         }
287     }
288     return unimplemented;
289 }
290
291 status_t mkldnn_sum_primitive_desc_create(primitive_desc_t **sum_pd,
292         const memory_desc_t *output_d, int n, const float *scales,
293         const primitive_desc_t **input_pds) {
294     return mkldnn_sum_primitive_desc_create_v2(sum_pd, output_d, n, scales,
295             input_pds, nullptr);
296 }
297
298 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s