updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_deformable_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18 #include "c_types_map.hpp"
19 #include "jit_uni_deformable_convolution.hpp"
20 #include "utils.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "type_helpers.hpp"
23 #include <cstring>
24
25 namespace mkldnn {
26 namespace impl {
27 namespace cpu {
28
29 using namespace mkldnn::impl::status;
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::memory_tracking::names;
32 using namespace mkldnn::impl::utils;
33
34 template <cpu_isa_t isa>
35 void jit_uni_deformable_convolution_fwd_t<isa>::execute_forward() const {
36     auto src = reinterpret_cast<const float *>(this->input_memory(0));
37     auto offsets = reinterpret_cast<const float *>(this->input_memory(1));
38     auto weights = reinterpret_cast<const float *>(this->input_memory(2));
39     auto bias = reinterpret_cast<const float *>(this->input_memory(3));
40     auto dst = reinterpret_cast<float *>(this->memory());
41
42     const memory_desc_wrapper src_d(pd()->src_pd(0));
43     const memory_desc_wrapper offsets_d(pd()->src_pd(1));
44     const memory_desc_wrapper dst_d(pd()->dst_pd());
45 //    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
46     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
47
48     const auto &jcp = kernel_->jcp;
49
50     if (bias && jcp.oc != jcp.oc_padded) {
51         auto padded_bias = this->scratchpad().template get<float>(key_conv_padded_bias);
52         utils::array_copy(padded_bias, (float*)bias, jcp.oc);
53         utils::array_set(padded_bias + jcp.oc, 0, jcp.oc_padded - jcp.oc);
54         bias = (float *)padded_bias;
55     }
56
57     auto input_buffer = this->scratchpad().template get<float>(key_def_conv_buffer);
58
59     const size_t work_amount = jcp.mb * jcp.ngroups * jcp.oh;
60
61     auto ker = [&](const int ithr, const int nthr) {
62         size_t start{0}, end{0};
63         balance211(work_amount, nthr, ithr, start, end);
64
65         size_t n{0}, g{0}, oh{0};
66         nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, oh, jcp.oh);
67         for (size_t iwork = start; iwork < end; ++iwork) {
68             auto par_conv = jit_def_conv_call_s();
69
70             const size_t _oc = g * jcp.nb_oc;
71             const size_t _ic = g * jcp.nb_ic;
72
73             par_conv.src = &src[src_d.blk_off(n, _ic*jcp.ic_block, oh * jcp.stride_h - jcp.t_pad, 0 - jcp.l_pad)];
74             par_conv.off = &offsets[offsets_d.blk_off(n, 0, oh, 0)];
75             par_conv.filt = weights;//weights_d(0, 0, 0, 0);
76             if (bias)
77                 par_conv.bias = &bias[bias_d.blk_off(_oc * jcp.oc_block*jcp.typesize_bia)];
78             par_conv.dst = &dst[dst_d.blk_off(n, _oc*jcp.oc_block, oh, 0)];
79
80             par_conv.buf = input_buffer + ithr * jcp.ur_w * jcp.kh * jcp.kw * jcp.ic;
81
82             par_conv.oh_pos = oh;
83
84             kernel_->jit_ker(&par_conv);
85             nd_iterator_step(n, jcp.mb, g, jcp.ngroups, oh, jcp.oh);
86         }
87     };
88
89     parallel(0, ker);
90 }
91
92 template struct jit_uni_deformable_convolution_fwd_t<avx512_common>;
93 template struct jit_uni_deformable_convolution_fwd_t<avx2>;
94 template struct jit_uni_deformable_convolution_fwd_t<sse42>;
95
96 }
97 }
98 }