Publishing R5 content (#72)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / primitive_desc.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn.h"
18
19 #include "c_types_map.hpp"
20 #include "nstl.hpp"
21 #include "primitive_desc.hpp"
22 #include "memory_pd.hpp"
23
24 using namespace mkldnn::impl;
25 using namespace mkldnn::impl::status;
26
27 status_t primitive_desc_t::query(query_t what, int idx, void *result) const {
28     auto safe_ret_pd = [&](const memory_pd_t *_) {
29         if (_ == nullptr) return not_required;
30         *(const primitive_desc_t **)result = _;
31         return success;
32     };
33
34     switch (what) {
35         case query::engine: *(engine_t**)result = engine(); break;
36         case query::primitive_kind: *(primitive_kind_t*)result = kind(); break;
37
38         case query::op_d:
39             if (idx != 0 || op_desc() == nullptr) return invalid_arguments;
40             *(const_c_op_desc_t *)result
41                 = static_cast<const_c_op_desc_t>(op_desc()); break;
42
43         case query::input_pd: return safe_ret_pd(input_pd(idx));
44         case query::output_pd: return safe_ret_pd(output_pd(idx));
45         case query::src_pd: return safe_ret_pd(src_pd(idx));
46         case query::diff_src_pd: return safe_ret_pd(diff_src_pd(idx));
47         case query::dst_pd: return safe_ret_pd(dst_pd(idx));
48         case query::diff_dst_pd: return safe_ret_pd(diff_dst_pd(idx));
49         case query::weights_pd: return safe_ret_pd(weights_pd(idx));
50         case query::diff_weights_pd: return safe_ret_pd(diff_weights_pd(idx));
51         case query::workspace_pd:
52             if (idx != 0) return status::invalid_arguments;
53             return safe_ret_pd(workspace_pd(idx));
54
55         case query::num_of_inputs_s32: *(int*)result = n_inputs(); break;
56         case query::num_of_outputs_s32: *(int*)result = n_outputs(); break;
57
58         case query::impl_info_str: *(const char **)result = name(); break;
59
60         default: return unimplemented;
61     }
62     return success;
63 }
64
65 status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc,
66         const primitive_attr_t **attr) {
67     if (utils::any_null(primitive_desc, attr))
68         return invalid_arguments;
69
70     *attr = primitive_desc->attr();
71     return success;
72 }