Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / primitive_desc.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn.h"
18
19 #include "c_types_map.hpp"
20 #include "nstl.hpp"
21 #include "primitive_desc.hpp"
22 #include "memory_pd.hpp"
23
24 using namespace mkldnn::impl;
25 using namespace mkldnn::impl::status;
26
27 status_t primitive_desc_t::query(query_t what, int idx, void *result) const {
28     auto safe_ret_pd = [&](const memory_pd_t *_) {
29         if (_ == nullptr) return not_required;
30         *(const primitive_desc_t **)result = _;
31         return success;
32     };
33
34     switch (what) {
35         case query::engine: *(engine_t**)result = engine(); break;
36         case query::primitive_kind: *(primitive_kind_t*)result = kind(); break;
37
38         case query::memory_consumption_s64:
39             *(ptrdiff_t*)result = scratchpad_registry().size(); break;
40
41         case query::op_d:
42             if (idx != 0 || op_desc() == nullptr) return invalid_arguments;
43             *(const_c_op_desc_t *)result
44                 = static_cast<const_c_op_desc_t>(op_desc()); break;
45
46         case query::input_pd: return safe_ret_pd(input_pd(idx));
47         case query::output_pd: return safe_ret_pd(output_pd(idx));
48         case query::src_pd: return safe_ret_pd(src_pd(idx));
49         case query::diff_src_pd: return safe_ret_pd(diff_src_pd(idx));
50         case query::dst_pd: return safe_ret_pd(dst_pd(idx));
51         case query::diff_dst_pd: return safe_ret_pd(diff_dst_pd(idx));
52         case query::weights_pd: return safe_ret_pd(weights_pd(idx));
53         case query::diff_weights_pd: return safe_ret_pd(diff_weights_pd(idx));
54         case query::workspace_pd:
55             if (idx != 0) return status::invalid_arguments;
56             return safe_ret_pd(workspace_pd(idx));
57
58         case query::num_of_inputs_s32: *(int*)result = n_inputs(); break;
59         case query::num_of_outputs_s32: *(int*)result = n_outputs(); break;
60
61         case query::impl_info_str: *(const char **)result = name(); break;
62
63         default: return unimplemented;
64     }
65     return success;
66 }
67
68 status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc,
69         const primitive_attr_t **attr) {
70     if (utils::any_null(primitive_desc, attr))
71         return invalid_arguments;
72
73     *attr = primitive_desc->attr();
74     return success;
75 }