1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef BINARIZATION_PD_HPP
18 #define BINARIZATION_PD_HPP
20 #include <mkldnn_types.h>
23 #include "c_types_map.hpp"
24 #include "primitive_desc.hpp"
25 #include "memory_pd.hpp"
30 struct binarization_fwd_pd_t: public primitive_desc_t {
31 typedef binarization_fwd_pd_t base_class;
32 typedef binarization_fwd_pd_t hint_class;
33 static constexpr auto base_pkind = primitive_kind::binarization;
35 binarization_fwd_pd_t(mkldnn::impl::engine_t *engine,
36 const binarization_desc_t *adesc, const primitive_attr_t *attr,
37 const binarization_fwd_pd_t *hint_fwd_pd)
38 : primitive_desc_t(engine, attr, primitive_kind::binarization)
39 , desc_(*adesc), hint_fwd_pd_(hint_fwd_pd) {}
40 virtual ~binarization_fwd_pd_t() {}
42 const binarization_desc_t *desc() const { return &desc_; }
43 virtual const op_desc_t *op_desc() const override
44 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
45 virtual void init_info() override { init_info_binarization(this, this->info_); }
47 virtual const memory_pd_t *input_pd(int index = 0) const override {
49 case 0: return src_pd();
50 case 1: return weights_pd(index - 1);
51 default: return nullptr;
54 virtual const memory_pd_t *output_pd(int index = 0) const override
55 { return index == 0 ? dst_pd() : nullptr; }
57 virtual int n_inputs() const override { return 2; }
58 virtual int n_outputs() const override { return 1; }
60 virtual status_t query(query_t what, int idx, void *result) const override
63 case query::binarization_d:
64 *(const binarization_desc_t**)result = desc(); break;
65 default: return primitive_desc_t::query(what, idx, result);
67 return status::success;
70 /* common binarization aux functions */
72 inline int MB() const { return input_pd()->desc()->ndims > 0 ? input_pd()->desc()->dims[0] : 1; }
73 inline int C() const { return input_pd()->desc()->ndims > 1 ? input_pd()->desc()->dims[1] : 1; }
74 inline int D() const { return input_pd()->desc()->ndims > 4 ? input_pd()->desc()->dims[2] : 1; }
75 inline int H() const { return input_pd()->desc()->ndims > 4 ? input_pd()->desc()->dims[3] :
76 input_pd()->desc()->ndims > 2 ? input_pd()->desc()->dims[2] : 1; }
77 inline int W() const { return input_pd()->desc()->ndims > 4 ? input_pd()->desc()->dims[4] :
78 input_pd()->desc()->ndims > 3 ? input_pd()->desc()->dims[3] : 1; }
// NOTE(review): no access specifier precedes these members in this view, so
// as written they are public (struct default); confirm whether they were
// intended to be protected.
// Copy of the operation descriptor taken in the constructor (desc_(*adesc));
// returned by desc() and op_desc().
binarization_desc_t desc_;
// Hint pd pointer stored verbatim from the constructor argument; not read by
// any code visible in this section.
const binarization_fwd_pd_t *hint_fwd_pd_;