1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_MEMORY_HPP
18 #define CPU_MEMORY_HPP
22 #include "c_types_map.hpp"
23 #include "cpu_primitive.hpp"
25 #include "memory_pd.hpp"
26 #include "type_helpers.hpp"
33 using namespace mkldnn::impl;
34 using namespace mkldnn::impl::status;
/* CPU memory primitive, derived from cpu_primitive_t.
 * The members shown here read and write a raw byte pointer (data_) through
 * the data-handle and memory accessors; execute() only marks its event as
 * ready. */
36 struct cpu_memory_t: public cpu_primitive_t {
/* Primitive descriptor for cpu_memory_t. Both constructors simply forward
 * their arguments to the memory_pd_t base. */
37 struct pd_t: public memory_pd_t {
38 pd_t(engine_t *engine): memory_pd_t(engine) {}
39 pd_t(engine_t *engine, const memory_desc_t *adesc)
40 : memory_pd_t(engine, adesc) {}
/* Returns a newly allocated pd_t built from this descriptor's engine()
 * and desc(). */
42 virtual pd_t *clone() const { return new pd_t(engine(), desc()); }
/* Allocates a cpu_memory_t bound to this descriptor and passes it to
 * safe_ptr_assign together with *primitive; the status of that call is
 * returned. inputs and outputs are intentionally unused. */
43 virtual status_t create_primitive(primitive_t **primitive,
44 const primitive_at_t *inputs, const primitive_t **outputs) const
46 UNUSED(inputs); UNUSED(outputs);
47 return safe_ptr_assign<primitive_t>(*primitive,
48 new cpu_memory_t(this));
/* Constructs the primitive with a default-constructed input_vector and
 * output_vector(1, this) -- presumably a single-element vector that refers
 * to this primitive itself; confirm against output_vector's definition. */
52 cpu_memory_t(const pd_t *apd)
53 : cpu_primitive_t(apd, input_vector(), output_vector(1, this))
55 virtual ~cpu_memory_t() {}
/* Nothing to compute: only sets the event state to event_t::ready. */
57 virtual void execute(mkldnn::impl::event_t *e) const
58 { e->set_state(event_t::ready); }
/* Writes the current data_ pointer, cast to void *, into *handle.
 * handle is dereferenced without a null check. */
60 virtual status_t get_data_handle(void **handle) const {
61 *handle = static_cast<void *>(data_);
/* Stores handle, cast to char *, in data_. */
64 virtual mkldnn::impl::status_t set_data_handle(void *handle) {
65 data_ = static_cast<char *>(handle);
/* Both accessors return data_; output_index must be 0 and is checked only
 * with assert. */
69 virtual char *memory(size_t output_index = 0) const
70 { assert(output_index == 0); return data_; }
71 virtual const char* const_memory(size_t output_index = 0) const
72 { assert(output_index == 0); return data_; }
/* Declared here only; the definition is not part of this section. */
74 mkldnn::impl::status_t zero_pad() const;
/* Returns primitive_t::pd() cast to this primitive's own pd_t type. */
77 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
/* Declaration of a helper templated on a data_type_t value; the definition
 * is not part of this section. Presumably the per-data-type worker behind
 * zero_pad() -- confirm in the implementation file. */
80 template <mkldnn::impl::data_type_t>
81 mkldnn::impl::status_t typed_zero_pad() const;
/* CPU view primitive, derived from cpu_primitive_t.
 * It allocates no storage of its own in the code shown: memory() and
 * const_memory() return the pointer obtained from input_memory(). Its
 * descriptor (pd_t) holds a source memory descriptor and a derived
 * destination descriptor built from the requested dims and offsets. */
84 struct cpu_view_t: public cpu_primitive_t {
/* Primitive descriptor for cpu_view_t; owns src_pd_ and dst_pd_. */
85 struct pd_t: public view_pd_t {
/* Initializes the base and both memory descriptors with the engine only;
 * init() fills in dst_pd_ afterwards. */
86 pd_t(engine_t *engine)
87 : view_pd_t(engine), src_pd_(engine), dst_pd_(engine) {}
/* Builds the destination descriptor for a view of `dims` extent located at
 * `offsets` within the source.
 *   - returns invalid_arguments when src_pd's engine differs from engine()
 *   - returns unimplemented for wino_fmt sources and for dims/offsets that
 *     the per-dimension check rejects
 * NOTE(review): the source descriptor is read from the member src_pd_, not
 * from the src_pd argument; the statement that copies the argument into
 * src_pd_ is not visible in this section -- confirm it precedes the read. */
90 status_t init(const cpu_memory_t::pd_t *src_pd, const dims_t dims,
91 const dims_t offsets) {
92 using namespace memory_format;
93 using namespace status;
95 if (src_pd->engine() != engine()) return invalid_arguments;
98 const memory_desc_t &src_d = *src_pd_.desc();
99 if (src_d.format == wino_fmt) return unimplemented;
100 const auto &src_d_blk = src_d.layout_desc.blocking;
/* dst_d starts as a copy of the source descriptor and is adjusted per
 * dimension in the loop below. */
102 memory_desc_t dst_d = src_d;
103 auto &dst_d_blk = dst_d.layout_desc.blocking;
105 const int ndims = dst_d.ndims;
/* Per-dimension check and adjustment. The visible terms of the check are:
 * offsets[d] is a multiple of block_dims[d]; offset_padding_to_data[d] is
 * zero; and dims[d] is either a multiple of block_dims[d] or smaller than
 * it. How these terms are combined, and the test that guards the
 * `return unimplemented`, are only partly visible here -- presumably the
 * function bails out when the combined check fails. */
106 for (int d = 0; d < ndims; ++d) {
107 /* very limited functionality for now */
109 && offsets[d] % src_d_blk.block_dims[d] == 0 /* [r1] */
110 && src_d_blk.offset_padding_to_data[d] == 0
112 || dims[d] % src_d_blk.block_dims[d] == 0
113 || dims[d] < src_d_blk.block_dims[d]);
115 return unimplemented;
/* True when the view ends exactly at the source's extent in dimension d. */
117 const bool is_right_border
118 = offsets[d] + dims[d] == src_d.dims[d];
120 dst_d.dims[d] = dims[d];
/* On the right border the padded extent is the source's padded extent minus
 * the offset; the other branch of the conditional is not visible here. */
121 dst_d_blk.padding_dims[d] = is_right_border
122 ? src_d_blk.padding_dims[d] - offsets[d]
124 dst_d_blk.offset_padding_to_data[d] =
125 src_d_blk.offset_padding_to_data[d];
/* Advances the base offset by (offsets[d] / block_dims[d]) multiplied by
 * strides[0][d] for this dimension. */
126 dst_d_blk.offset_padding +=
127 offsets[d] / src_d_blk.block_dims[d] /* [r1] */
128 * dst_d_blk.strides[0][d];
/* Replaces dst_pd_ with a descriptor built from the adjusted dst_d. */
131 dst_pd_ = cpu_memory_t::pd_t(engine(), &dst_d);
/* Factory: allocates a pd_t on src_pd's engine, runs init() on it, and
 * returns the first non-success status encountered.
 * NOTE(review): in the lines visible here, a failing init() returns without
 * releasing the freshly allocated pd -- confirm whether that is a leak or
 * whether cleanup happens in code not shown. */
135 static status_t create(pd_t **cpu_view_pd,
136 const cpu_memory_t::pd_t *src_pd, const dims_t dims,
137 const dims_t offsets) {
139 status_t status = safe_ptr_assign<pd_t>(pd,
140 new pd_t(src_pd->engine()));
141 if (status != success) return status;
142 status = pd->init(src_pd, dims, offsets);
143 if (status != success) return status;
/* Returns a newly allocated copy of this descriptor (copy construction). */
148 virtual pd_t *clone() const override { return new pd_t(*this); }
/* Creates a cpu_view_t whose input vector is built from the range
 * [inputs, inputs + 1), i.e. from inputs[0] only. */
149 virtual status_t create_primitive(primitive_t **primitive,
150 const primitive_at_t *inputs, const primitive_t **outputs)
153 primitive_t::input_vector ins(inputs, inputs + 1);
155 return safe_ptr_assign<primitive_t>(*primitive,
156 new cpu_view_t(this, ins));
/* Only index 0 is valid for either accessor; any other index yields
 * nullptr. */
159 virtual const cpu_memory_t::pd_t *src_pd(int index = 0) const override
160 { return index == 0 ? &src_pd_ : nullptr; }
161 virtual const cpu_memory_t::pd_t *dst_pd(int index = 0) const override
162 { return index == 0 ? &dst_pd_ : nullptr; }
/* Source descriptor and the view descriptor derived from it by init(). */
164 cpu_memory_t::pd_t src_pd_;
165 cpu_memory_t::pd_t dst_pd_;
/* Builds a descriptor directly from an existing source/destination pair,
 * taking the engine from src_pd. */
168 pd_t(const cpu_memory_t::pd_t &src_pd, const cpu_memory_t::pd_t &dst_pd)
169 : view_pd_t(src_pd.engine()), src_pd_(src_pd), dst_pd_(dst_pd) {}
/* Constructs the primitive with the given inputs and output_vector(1, this)
 * -- presumably a single-element vector that refers to this primitive
 * itself; confirm against output_vector's definition. */
172 cpu_view_t(const pd_t *apd, const input_vector &inputs)
173 : cpu_primitive_t(apd, inputs, output_vector(1, this))
175 virtual ~cpu_view_t() {}
/* Nothing to compute: only sets the event state to event_t::ready. */
177 virtual void execute(mkldnn::impl::event_t *e) const
178 { e->set_state(event_t::ready); }
/* Both accessors return the pointer from input_memory(); memory() strips
 * its constness with const_cast. output_index must be 0 and is checked only
 * with assert. */
180 virtual char *memory(size_t output_index = 0) const
181 { assert(output_index == 0); return const_cast<char *>(input_memory()); }
182 virtual const char* const_memory(size_t output_index = 0) const
183 { assert(output_index == 0); return input_memory(); }
/* Returns primitive_t::pd() cast to this primitive's own pd_t type. */
186 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
195 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s