1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
18 #include "c_types_map.hpp"
19 #include "jit_uni_deformable_convolution.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "type_helpers.hpp"
29 using namespace mkldnn::impl::status;
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::memory_tracking::names;
32 using namespace mkldnn::impl::utils;
34 template <cpu_isa_t isa>
// Forward pass of the JIT deformable convolution.
// Inputs (by input_memory index): 0 = src, 1 = per-position sampling
// offsets, 2 = weights, 3 = bias (may be null). Output: dst.
// Work is partitioned across threads over (mb, ngroups, oh); each
// (batch, group, output-row) triple is handed to the JIT kernel.
// NOTE(review): this chunk is a partial view — closing braces, the
// parallel dispatch, and some lines between the visible ones live
// outside this excerpt.
35 void jit_uni_deformable_convolution_fwd_t<isa>::execute_forward() const {
36 auto src = reinterpret_cast<const float *>(this->input_memory(0));
37 auto offsets = reinterpret_cast<const float *>(this->input_memory(1));
38 auto weights = reinterpret_cast<const float *>(this->input_memory(2));
39 auto bias = reinterpret_cast<const float *>(this->input_memory(3));
40 auto dst = reinterpret_cast<float *>(this->memory());
// Memory-descriptor wrappers used to compute blocked offsets below.
42 const memory_desc_wrapper src_d(pd()->src_pd(0));
43 const memory_desc_wrapper offsets_d(pd()->src_pd(1));
44 const memory_desc_wrapper dst_d(pd()->dst_pd());
45 // const memory_desc_wrapper weights_d(pd()->weights_pd(0));
46 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
// Kernel-side convolution parameters (jit_conv_conf) set up at pd creation.
48 const auto &jcp = kernel_->jcp;
// If oc is not a multiple of the block size, copy the bias into a
// zero-padded scratchpad buffer so the kernel can load full oc blocks
// without reading past the user-provided bias.
50 if (bias && jcp.oc != jcp.oc_padded) {
51 auto padded_bias = this->scratchpad().template get<float>(key_conv_padded_bias);
52 utils::array_copy(padded_bias, (float*)bias, jcp.oc);
53 utils::array_set(padded_bias + jcp.oc, 0, jcp.oc_padded - jcp.oc);
54 bias = (float *)padded_bias;
// Per-thread scratch buffer; each thread gets a slice of size
// ur_w * kh * kw * ic (see the ithr-based offset below) — presumably
// for the deformably-sampled input values; TODO confirm against kernel.
57 auto input_buffer = this->scratchpad().template get<float>(key_def_conv_buffer);
// Total number of independent work items: one per (n, g, oh) triple.
59 const size_t work_amount = jcp.mb * jcp.ngroups * jcp.oh;
// Thread body: process this thread's [start, end) share of work items.
61 auto ker = [&](const int ithr, const int nthr) {
62 size_t start{0}, end{0};
63 balance211(work_amount, nthr, ithr, start, end);
65 size_t n{0}, g{0}, oh{0};
66 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, oh, jcp.oh);
67 for (size_t iwork = start; iwork < end; ++iwork) {
68 auto par_conv = jit_def_conv_call_s();
// First oc/ic block index for this group.
70 const size_t _oc = g * jcp.nb_oc;
71 const size_t _ic = g * jcp.nb_ic;
// src pointer is pre-shifted by the top/left padding so the kernel
// indexes relative to the padded origin of this output row.
73 par_conv.src = &src[src_d.blk_off(n, _ic*jcp.ic_block, oh * jcp.stride_h - jcp.t_pad, 0 - jcp.l_pad)];
74 par_conv.off = &offsets[offsets_d.blk_off(n, 0, oh, 0)];
75 par_conv.filt = weights;//weights_d(0, 0, 0, 0);
// NOTE(review): jcp.typesize_bia inside blk_off on a float* looks like
// byte-vs-element double scaling — verify against bias_d's layout.
77 par_conv.bias = &bias[bias_d.blk_off(_oc * jcp.oc_block*jcp.typesize_bia)];
78 par_conv.dst = &dst[dst_d.blk_off(n, _oc*jcp.oc_block, oh, 0)];
// Hand this thread its private slice of the sampling buffer.
80 par_conv.buf = input_buffer + ithr * jcp.ur_w * jcp.kh * jcp.kw * jcp.ic;
84 kernel_->jit_ker(&par_conv);
// Advance (n, g, oh) to the next work item.
85 nd_iterator_step(n, jcp.mb, g, jcp.ngroups, oh, jcp.oh);
// Explicit instantiations for every ISA this implementation supports.
92 template struct jit_uni_deformable_convolution_fwd_t<avx512_common>;
93 template struct jit_uni_deformable_convolution_fwd_t<avx2>;
94 template struct jit_uni_deformable_convolution_fwd_t<sse42>;