/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #ifndef SOFTMAX_FWD_PD_HPP
18 #define SOFTMAX_FWD_PD_HPP
22 #include "c_types_map.hpp"
23 #include "primitive_desc.hpp"
24 #include "memory_pd.hpp"
29 struct softmax_fwd_pd_t: public primitive_desc_t {
30 typedef softmax_fwd_pd_t base_class;
31 typedef softmax_fwd_pd_t hint_class;
32 static constexpr auto base_pkind = primitive_kind::softmax;
34 softmax_fwd_pd_t(mkldnn::impl::engine_t *engine,
35 const softmax_desc_t *adesc, const primitive_attr_t *attr,
36 const softmax_fwd_pd_t *hint_fwd_pd)
37 : primitive_desc_t(engine, attr, primitive_kind::softmax)
38 , desc_(*adesc), hint_fwd_pd_(hint_fwd_pd) {}
39 virtual ~softmax_fwd_pd_t() {}
41 const softmax_desc_t *desc() const { return &desc_; }
42 virtual const op_desc_t *op_desc() const override
43 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
44 virtual void init_info() override { init_info_softmax(this, this->info_); }
46 virtual const memory_pd_t *input_pd(int index = 0) const override
47 { return index == 0 ? src_pd() : nullptr; }
48 virtual const memory_pd_t *output_pd(int index = 0) const override {
49 if (index == 0) return dst_pd();
50 if (index == 1) return workspace_pd();
54 virtual int n_inputs() const override { return 1; }
55 virtual int n_outputs() const override
56 { return 1 + (workspace_pd() != nullptr); }
58 virtual status_t query(query_t what, int idx, void *result) const override
61 case query::softmax_d:
62 *(const softmax_desc_t**)result = desc(); break;
63 default: return primitive_desc_t::query(what, idx, result);
65 return status::success;
68 /* common softmax aux functions */
70 inline int MB() const { return desc_.data_desc.dims[0]; }
71 inline int C() const { return desc_.data_desc.dims[1]; }
72 inline int H() const { return desc_.data_desc.dims[2]; }
73 inline int W() const { return desc_.data_desc.dims[3]; }
77 const softmax_fwd_pd_t *hint_fwd_pd_;
80 struct softmax_bwd_pd_t: public primitive_desc_t {
81 typedef softmax_bwd_pd_t base_class;
82 typedef softmax_fwd_pd_t hint_class;
83 static constexpr auto base_pkind = primitive_kind::softmax;
85 softmax_bwd_pd_t(mkldnn::impl::engine_t *engine,
86 const softmax_desc_t *adesc, const primitive_attr_t *attr,
87 const softmax_fwd_pd_t *hint_fwd_pd) // FWD?
88 : primitive_desc_t(engine, attr, primitive_kind::softmax)
89 , desc_(*adesc), hint_fwd_pd_(hint_fwd_pd) {}
90 virtual ~softmax_bwd_pd_t() {}
92 const softmax_desc_t *desc() const { return &desc_; }
93 virtual const op_desc_t *op_desc() const override
94 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
95 virtual void init_info() override { init_info_softmax(this, this->info_); }
97 virtual const memory_pd_t *input_pd(int index = 0) const override {
98 if (index == 0) return dst_pd();
99 if (index == 1) return diff_dst_pd();
102 virtual const memory_pd_t *output_pd(int index = 0) const override
103 { return index == 0 ? diff_src_pd() : nullptr; }
105 virtual int n_inputs() const override
106 { return 2 + (workspace_pd() != nullptr); }
107 virtual int n_outputs() const override { return 1; }
109 virtual status_t query(query_t what, int idx, void *result) const override
112 case query::softmax_d:
113 *(const softmax_desc_t**)result = desc(); break;
114 default: return primitive_desc_t::query(what, idx, result);
116 return status::success;
119 /* common softmax aux functions */
121 inline int MB() const { return desc_.data_desc.dims[0]; }
122 inline int C() const { return desc_.data_desc.dims[1]; }
123 inline int H() const { return desc_.data_desc.dims[2]; }
124 inline int W() const { return desc_.data_desc.dims[3]; }
127 softmax_desc_t desc_;
128 const softmax_fwd_pd_t *hint_fwd_pd_;
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s