1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
23 #include "c_types_map.hpp"
25 #include "memory_pd.hpp"
26 #include "type_helpers.hpp"
29 using namespace mkldnn::impl;
30 using namespace mkldnn::impl::utils;
31 using namespace mkldnn::impl::status;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::data_type;
36 bool memory_desc_sanity_check(int ndims,const dims_t dims,
37 data_type_t data_type, memory_format_t format) {
38 if (ndims == 0) return true;
42 && 0 < ndims && ndims <= TENSOR_MAX_DIMS
43 && one_of(data_type, f32, s32, s16, s8, u8, bin)
44 && format != memory_format::undef;
45 if (!ok) return false;
46 for (int d = 0; d < ndims; ++d)
47 if (dims[d] < 0) return false;
52 bool memory_desc_sanity_check(const memory_desc_t *md) {
53 if (md == nullptr) return false;
54 return memory_desc_sanity_check(md->ndims, md->dims, md->data_type,
59 status_t mkldnn_memory_desc_init(memory_desc_t *memory_desc, int ndims,
60 const dims_t dims, data_type_t data_type, memory_format_t format) {
61 if (any_null(memory_desc)) return invalid_arguments;
62 if (ndims == 0 || format == memory_format::undef) {
63 *memory_desc = types::zero_md();
67 /* memory_desc != 0 */
68 bool args_ok = !any_null(memory_desc)
69 && memory_desc_sanity_check(ndims, dims, data_type, format);
70 if (!args_ok) return invalid_arguments;
74 array_copy(md.dims, dims, ndims);
75 md.primitive_kind = primitive_kind::memory;
76 md.data_type = data_type;
79 status_t status = success;
80 if (one_of(format, memory_format::undef, blocked, wino_fmt, rnn_packed)) {
81 status = invalid_arguments;
82 } else if (format == any) {
84 } else if (types::format_normalize(format) == blocked) {
85 status = memory_desc_wrapper::compute_blocking(md);
87 assert(!"unreachable");
88 status = invalid_arguments;
91 if (status == success)
97 status_t mkldnn_memory_primitive_desc_create(primitive_desc_t **memory_pd,
98 const memory_desc_t *memory_desc, engine_t *engine) {
99 bool args_ok = !any_null(memory_pd, memory_desc, engine)
100 && memory_desc_sanity_check(memory_desc)
101 && memory_desc_wrapper(*memory_desc).is_defined();
102 if (!args_ok) return invalid_arguments;
103 return engine->memory_primitive_desc_create(
104 (memory_pd_t**)memory_pd, memory_desc);
107 status_t mkldnn_view_primitive_desc_create(primitive_desc_t **view_pd,
108 const primitive_desc_t *memory_pd, const dims_t dims,
109 const dims_t offsets) {
110 const memory_pd_t *mpd =
111 (const memory_pd_t*)memory_pd;
113 bool args_ok = !any_null(view_pd, memory_pd, dims, offsets)
114 && memory_pd->kind() == primitive_kind::memory
115 && memory_desc_sanity_check(mpd->desc());
116 if (!args_ok) return invalid_arguments;
118 memory_desc_wrapper md(*mpd->desc());
119 for (int d = 0; d < md.ndims(); ++d) {
120 if (dims[d] < 0 || offsets[d] < 0
121 || (offsets[d] + dims[d] > md.dims()[d]))
122 return invalid_arguments;
124 return memory_pd->engine()->view_primitive_desc_create(
125 (view_pd_t**)view_pd, mpd, dims, offsets);
128 int mkldnn_memory_primitive_desc_equal(const primitive_desc_t *lhs,
129 const primitive_desc_t *rhs) {
130 bool args_ok = !any_null(lhs, rhs)
131 && lhs->engine() == rhs->engine()
132 && one_of(lhs->kind(), primitive_kind::memory, primitive_kind::view)
133 && one_of(rhs->kind(), primitive_kind::memory, primitive_kind::view);
134 if (!args_ok) return 0;
135 auto l = (const memory_pd_t *)lhs;
136 auto r = (const memory_pd_t *)rhs;
138 return l->is_equal(r);
141 size_t mkldnn_memory_primitive_desc_get_size(const primitive_desc_t *memory_pd)
143 bool args_ok = !any_null(memory_pd)
144 && memory_pd->kind() == primitive_kind::memory;
145 if (!args_ok) return 0;
147 return ((memory_pd_t*)memory_pd)->get_size();
150 status_t mkldnn_memory_get_data_handle(const primitive_t *memory,
152 if (any_null(handle))
153 return invalid_arguments;
154 if (memory == nullptr) {
158 if (memory->kind() != primitive_kind::memory)
159 return invalid_arguments;
160 return memory->get_data_handle(handle);
163 status_t mkldnn_memory_set_data_handle(primitive_t *memory, void *handle) {
164 if (any_null(memory) || memory->kind() != primitive_kind::memory)
165 return invalid_arguments;
166 return memory->set_data_handle(handle);
169 status_t mkldnn_concat_primitive_desc_create_v2(primitive_desc_t **concat_pd,
170 const memory_desc_t *output_d, int n, int concat_dim,
171 const primitive_desc_t **input_pds, const primitive_attr_t *attr) {
172 bool args_ok = !any_null(concat_pd, input_pds) && n > 0;
173 if (!args_ok) return invalid_arguments;
174 for (int i = 0; i < n; ++i) {
175 if (input_pds[i] == nullptr ||
176 input_pds[i]->kind() != primitive_kind::memory)
177 return invalid_arguments;
180 const primitive_attr_t dummy_attr;
184 auto i_mpds = (const memory_pd_t **)input_pds;
185 engine_t *engine = i_mpds[0]->engine();
186 const int ndims = i_mpds[0]->desc()->ndims;
187 const dims_t &dims = i_mpds[0]->desc()->dims;
188 const data_type_t dt = i_mpds[0]->desc()->data_type;
190 int concat_dim_sz = dims[concat_dim];
191 for (int i = 1; i < n; ++i) {
192 if (i_mpds[i]->engine() != engine) return invalid_arguments;
193 if (i_mpds[i]->desc()->ndims != ndims) return invalid_arguments;
194 for (int d = 0; d < ndims; ++d) {
195 if (d == concat_dim) continue;
196 if (i_mpds[i]->desc()->dims[d] != dims[d])
197 return invalid_arguments;
199 if (i_mpds[i]->desc()->data_type != dt) return invalid_arguments;
200 concat_dim_sz += i_mpds[i]->desc()->dims[concat_dim];
203 memory_desc_t dummy_output_d;
205 if (output_d->ndims != ndims) return invalid_arguments;
206 for (int d = 0; d < ndims; ++d) {
207 if (output_d->dims[d] !=
208 (d == concat_dim ? concat_dim_sz : dims[d]))
209 return invalid_arguments;
212 dummy_output_d = *i_mpds[0]->desc();
213 dummy_output_d.dims[concat_dim] = concat_dim_sz;
214 dummy_output_d.format = memory_format::any;
215 output_d = &dummy_output_d;
218 auto c_pd = reinterpret_cast<concat_pd_t **>(concat_pd);
220 for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
221 if ((*c)(c_pd, output_d, n, concat_dim, i_mpds, attr) == success) {
222 (*c_pd)->init_info();
226 return unimplemented;
229 status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd,
230 const memory_desc_t *output_d, int n, int concat_dim,
231 const primitive_desc_t **input_pds) {
232 return mkldnn_concat_primitive_desc_create_v2(concat_pd, output_d, n,
233 concat_dim, input_pds, nullptr);
236 status_t mkldnn_sum_primitive_desc_create_v2(primitive_desc_t **sum_pd,
237 const memory_desc_t *output_d, int n, const float *scales,
238 const primitive_desc_t **input_pds, const primitive_attr_t *attr) {
239 bool args_ok = !any_null(sum_pd, input_pds, scales) && n > 0;
240 if (!args_ok) return invalid_arguments;
241 for (int i = 0; i < n; ++i) {
242 if (input_pds[i] == nullptr ||
243 input_pds[i]->kind() != primitive_kind::memory)
244 return invalid_arguments;
247 const primitive_attr_t dummy_attr;
251 auto i_mpds = (const memory_pd_t **)input_pds;
252 engine_t *engine = i_mpds[0]->engine();
253 const int ndims = i_mpds[0]->desc()->ndims;
254 const dims_t &dims = i_mpds[0]->desc()->dims;
255 const data_type_t dt = i_mpds[0]->desc()->data_type;
257 for (int i = 1; i < n; ++i) {
258 if (i_mpds[i]->engine() != engine) return invalid_arguments;
259 if (i_mpds[i]->desc()->ndims != ndims) return invalid_arguments;
260 for (int d = 0; d < ndims; ++d) {
261 if (i_mpds[i]->desc()->dims[d] != dims[d])
262 return invalid_arguments;
264 if (i_mpds[i]->desc()->data_type != dt) return invalid_arguments;
267 memory_desc_t dummy_output_d;
269 if (output_d->ndims != ndims) return invalid_arguments;
270 for (int d = 0; d < ndims; ++d) {
271 if (output_d->dims[d] != dims[d])
272 return invalid_arguments;
275 dummy_output_d = *i_mpds[0]->desc();
276 dummy_output_d.format = memory_format::any;
277 output_d = &dummy_output_d;
280 auto s_pd = reinterpret_cast<sum_pd_t **>(sum_pd);
282 for (auto s = engine->get_sum_implementation_list(); *s; ++s) {
283 if ((*s)(s_pd, output_d, n, scales, i_mpds, attr) == success) {
284 (*s_pd)->init_info();
288 return unimplemented;
291 status_t mkldnn_sum_primitive_desc_create(primitive_desc_t **sum_pd,
292 const memory_desc_t *output_d, int n, const float *scales,
293 const primitive_desc_t **input_pds) {
294 return mkldnn_sum_primitive_desc_create_v2(sum_pd, output_d, n, scales,
298 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s