updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_pooling_pd.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_POOLING_PD_HPP
18 #define CPU_POOLING_PD_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "pooling_pd.hpp"
24 #include "cpu_engine.hpp"
25 #include "cpu_memory.hpp"
26 #include "cpu_primitive.hpp"
27 #include "type_helpers.hpp"
28 #include "utils.hpp"
29 #include "nstl.hpp"
30
31 namespace mkldnn {
32 namespace impl {
33 namespace cpu {
34
35 inline data_type_t pooling_index_data_type(const pooling_desc_t *p) {
36     using nstl::numeric_limits;
37     /* the simplest way to express 256... */
38     const int u8_max =
39         numeric_limits<typename prec_traits<data_type::u8>::type>::max();
40     /* value u8_max in the case of data_type::u8 is reserved for
41        designation of invalid index when pooling window is fully placed
42        outside of source domain */
43     if( p->src_desc.ndims == 5 || p->diff_src_desc.ndims == 5 ) {
44         return p->kernel[0] * p->kernel[1] * p->kernel[2] < u8_max
45             ? data_type::u8 : data_type::s32;
46     } else {
47         return p->kernel[0] * p->kernel[1] < u8_max
48             ? data_type::u8 : data_type::s32;
49     }
50 }
51
52 struct cpu_pooling_fwd_pd_t: public pooling_fwd_pd_t {
53     using cpu_memory_pd_t = cpu_memory_t::pd_t;
54
55     cpu_pooling_fwd_pd_t(engine_t *engine, const pooling_desc_t *adesc,
56             const primitive_attr_t *attr, const pooling_fwd_pd_t *hint_fwd_pd)
57         : pooling_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
58         , src_pd_(engine_, &desc_.src_desc), dst_pd_(engine_, &desc_.dst_desc)
59         , ws_pd_(engine_) {}
60     virtual ~cpu_pooling_fwd_pd_t() {}
61
62     virtual const cpu_memory_pd_t *src_pd(int index = 0) const override
63     { return index == 0 ? &src_pd_ : nullptr; }
64     virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override
65     { return index == 0 ? &dst_pd_ : nullptr; }
66     virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override
67     { return (index == 0 && !ws_pd_.is_zero()) ? &ws_pd_ : nullptr; }
68
69 protected:
70     cpu_memory_pd_t src_pd_;
71     cpu_memory_pd_t dst_pd_;
72     cpu_memory_pd_t ws_pd_;
73
74     virtual status_t init() = 0;
75
76     virtual status_t set_default_params() {
77         using namespace memory_format;
78         if (dst_pd_.desc()->format == any)
79             CHECK(dst_pd_.set_format(src_pd_.desc()->format));
80         return status::success;
81     }
82 };
83
84 struct cpu_pooling_bwd_pd_t: public pooling_bwd_pd_t {
85     using cpu_memory_pd_t = cpu_memory_t::pd_t;
86
87     cpu_pooling_bwd_pd_t(engine_t *engine, const pooling_desc_t *adesc,
88             const primitive_attr_t *attr, const pooling_fwd_pd_t *hint_fwd_pd)
89         : pooling_bwd_pd_t(engine, adesc, attr, hint_fwd_pd)
90         , diff_src_pd_(engine_, &desc_.diff_src_desc)
91         , diff_dst_pd_(engine_, &desc_.diff_dst_desc)
92         , ws_pd_(engine_) {}
93     virtual ~cpu_pooling_bwd_pd_t() {}
94
95     virtual const cpu_memory_pd_t *diff_src_pd(int index = 0) const override
96     { return index == 0 ? &diff_src_pd_ : nullptr; }
97     virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override
98     { return index == 0 ? &diff_dst_pd_ : nullptr; }
99     virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override
100     { return (index == 0 && !ws_pd_.is_zero()) ? &ws_pd_ : nullptr; }
101
102 protected:
103     cpu_memory_pd_t diff_src_pd_;
104     cpu_memory_pd_t diff_dst_pd_;
105     cpu_memory_pd_t ws_pd_;
106
107     virtual status_t init() = 0;
108
109     virtual status_t set_default_params() {
110         using namespace memory_format;
111         if (diff_src_pd_.desc()->format == any)
112             CHECK(diff_src_pd_.set_format(diff_dst_pd_.desc()->format));
113         return status::success;
114     }
115 };
116
117 }
118 }
119 }
120
121 #endif
122
123 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s