Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_memory.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_MEMORY_HPP
18 #define CPU_MEMORY_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "cpu_primitive.hpp"
24 #include "event.hpp"
25 #include "memory_pd.hpp"
26 #include "type_helpers.hpp"
27 #include "utils.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl;
34 using namespace mkldnn::impl::status;
35
36 struct cpu_memory_t: public cpu_primitive_t {
37     struct pd_t: public memory_pd_t {
38         pd_t(engine_t *engine): memory_pd_t(engine) {}
39         pd_t(engine_t *engine, const memory_desc_t *adesc)
40             : memory_pd_t(engine, adesc) {}
41         virtual ~pd_t() {}
42         virtual pd_t *clone() const { return new pd_t(engine(), desc()); }
43         virtual status_t create_primitive(primitive_t **primitive,
44                 const primitive_at_t *inputs, const primitive_t **outputs) const
45         {
46             UNUSED(inputs); UNUSED(outputs);
47             return safe_ptr_assign<primitive_t>(*primitive,
48                     new cpu_memory_t(this));
49         }
50     };
51
52     cpu_memory_t(const pd_t *apd)
53         : cpu_primitive_t(apd, input_vector(), output_vector(1, this))
54         , data_(nullptr) {}
55     virtual ~cpu_memory_t() {}
56
57     virtual void execute(mkldnn::impl::event_t *e) const
58     { e->set_state(event_t::ready); }
59
60     virtual status_t get_data_handle(void **handle) const {
61         *handle = static_cast<void *>(data_);
62         return success;
63     }
64     virtual mkldnn::impl::status_t set_data_handle(void *handle) {
65         data_ = static_cast<char *>(handle);
66         return zero_pad();
67     }
68
69     virtual char *memory(size_t output_index = 0) const
70     { assert(output_index == 0); return data_; }
71     virtual const char* const_memory(size_t output_index = 0) const
72     { assert(output_index == 0); return data_; }
73
74     mkldnn::impl::status_t zero_pad() const;
75
76 private:
77     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
78     char *data_;
79
80     template <mkldnn::impl::data_type_t>
81     mkldnn::impl::status_t typed_zero_pad() const;
82 };
83
84 struct cpu_view_t: public cpu_primitive_t {
85     struct pd_t: public view_pd_t {
86         pd_t(engine_t *engine)
87             : view_pd_t(engine), src_pd_(engine), dst_pd_(engine) {}
88         virtual ~pd_t() {}
89
90         status_t init(const cpu_memory_t::pd_t *src_pd, const dims_t dims,
91                 const dims_t offsets) {
92             using namespace memory_format;
93             using namespace status;
94
95             if (src_pd->engine() != engine()) return invalid_arguments;
96
97             src_pd_ = *src_pd;
98             const memory_desc_t &src_d = *src_pd_.desc();
99             if (src_d.format == wino_fmt) return unimplemented;
100             const auto &src_d_blk = src_d.layout_desc.blocking;
101
102             memory_desc_t dst_d = src_d;
103             auto &dst_d_blk = dst_d.layout_desc.blocking;
104
105             const int ndims = dst_d.ndims;
106             for (int d = 0; d < ndims; ++d) {
107                 /* very limited functionality for now */
108                 const bool ok = true
109                     && offsets[d] % src_d_blk.block_dims[d] == 0 /* [r1] */
110                     && src_d_blk.offset_padding_to_data[d] == 0
111                     && (false
112                             || dims[d] % src_d_blk.block_dims[d] == 0
113                             || dims[d] < src_d_blk.block_dims[d]);
114                 if (!ok)
115                     return unimplemented;
116
117                 const bool is_right_border
118                     = offsets[d] + dims[d] == src_d.dims[d];
119
120                 dst_d.dims[d] = dims[d];
121                 dst_d_blk.padding_dims[d] = is_right_border
122                     ? src_d_blk.padding_dims[d] - offsets[d]
123                     : dst_d.dims[d];
124                 dst_d_blk.offset_padding_to_data[d] =
125                     src_d_blk.offset_padding_to_data[d];
126                 dst_d_blk.offset_padding +=
127                     offsets[d] / src_d_blk.block_dims[d] /* [r1] */
128                     * dst_d_blk.strides[0][d];
129             }
130
131             dst_pd_ = cpu_memory_t::pd_t(engine(), &dst_d);
132             return success;
133         }
134
135         static status_t create(pd_t **cpu_view_pd,
136                 const cpu_memory_t::pd_t *src_pd, const dims_t dims,
137                 const dims_t offsets) {
138             pd_t *pd;
139             status_t status = safe_ptr_assign<pd_t>(pd,
140                     new pd_t(src_pd->engine()));
141             if (status != success) return status;
142             status = pd->init(src_pd, dims, offsets);
143             if (status != success) return status;
144             *cpu_view_pd = pd;
145             return success;
146         }
147
148         virtual pd_t *clone() const override { return new pd_t(*this); }
149         virtual status_t create_primitive(primitive_t **primitive,
150                 const primitive_at_t *inputs, const primitive_t **outputs)
151             const override
152         {
153             primitive_t::input_vector ins(inputs, inputs + 1);
154             UNUSED(outputs);
155             return safe_ptr_assign<primitive_t>(*primitive,
156                     new cpu_view_t(this, ins));
157         }
158
159         virtual const cpu_memory_t::pd_t *src_pd(int index = 0) const override
160         { return index == 0 ? &src_pd_ : nullptr; }
161         virtual const cpu_memory_t::pd_t *dst_pd(int index = 0) const override
162         { return index == 0 ? &dst_pd_ : nullptr; }
163
164         cpu_memory_t::pd_t src_pd_;
165         cpu_memory_t::pd_t dst_pd_;
166
167     protected:
168         pd_t(const cpu_memory_t::pd_t &src_pd, const cpu_memory_t::pd_t &dst_pd)
169             : view_pd_t(src_pd.engine()), src_pd_(src_pd), dst_pd_(dst_pd) {}
170     };
171
172     cpu_view_t(const pd_t *apd, const input_vector &inputs)
173         : cpu_primitive_t(apd, inputs, output_vector(1, this))
174     {}
175     virtual ~cpu_view_t() {}
176
177     virtual void execute(mkldnn::impl::event_t *e) const
178     { e->set_state(event_t::ready); }
179
180     virtual char *memory(size_t output_index = 0) const
181     { assert(output_index == 0); return const_cast<char *>(input_memory()); }
182     virtual const char* const_memory(size_t output_index = 0) const
183     { assert(output_index == 0); return input_memory(); }
184
185 private:
186     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
187 };
188
189 }
190 }
191 }
192
193 #endif
194
195 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s