1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "c_types_map.hpp"
21 #include "primitive_desc.hpp"
22 #include "memory_pd.hpp"
24 using namespace mkldnn::impl;
25 using namespace mkldnn::impl::status;
27 status_t primitive_desc_t::query(query_t what, int idx, void *result) const {
28 auto safe_ret_pd = [&](const memory_pd_t *_) {
29 if (_ == nullptr) return not_required;
30 *(const primitive_desc_t **)result = _;
35 case query::engine: *(engine_t**)result = engine(); break;
36 case query::primitive_kind: *(primitive_kind_t*)result = kind(); break;
39 if (idx != 0 || op_desc() == nullptr) return invalid_arguments;
40 *(const_c_op_desc_t *)result
41 = static_cast<const_c_op_desc_t>(op_desc()); break;
43 case query::input_pd: return safe_ret_pd(input_pd(idx));
44 case query::output_pd: return safe_ret_pd(output_pd(idx));
45 case query::src_pd: return safe_ret_pd(src_pd(idx));
46 case query::diff_src_pd: return safe_ret_pd(diff_src_pd(idx));
47 case query::dst_pd: return safe_ret_pd(dst_pd(idx));
48 case query::diff_dst_pd: return safe_ret_pd(diff_dst_pd(idx));
49 case query::weights_pd: return safe_ret_pd(weights_pd(idx));
50 case query::diff_weights_pd: return safe_ret_pd(diff_weights_pd(idx));
51 case query::workspace_pd:
52 if (idx != 0) return status::invalid_arguments;
53 return safe_ret_pd(workspace_pd(idx));
55 case query::num_of_inputs_s32: *(int*)result = n_inputs(); break;
56 case query::num_of_outputs_s32: *(int*)result = n_outputs(); break;
58 case query::impl_info_str: *(const char **)result = name(); break;
60 default: return unimplemented;
65 status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc,
66 const primitive_attr_t **attr) {
67 if (utils::any_null(primitive_desc, attr))
68 return invalid_arguments;
70 *attr = primitive_desc->attr();