Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / blob_transform.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "cpu_detector.hpp"
6 #include "blob_transform.hpp"
7 #ifdef HAVE_SSE
8 #include "blob_transform_sse42.hpp"
9 #endif
10
11 #include <cstdint>
12 #include <cstdlib>
13
14 //----------------------------------------------------------------------
15
16 namespace InferenceEngine {
17
18 template <InferenceEngine::Precision::ePrecision PRC>
19 static void blob_copy_4d_t(Blob::Ptr src, Blob::Ptr dst) {
20     using data_t = typename InferenceEngine::PrecisionTrait<PRC>::value_type;
21
22     auto *src_ptr = src->buffer().as<data_t*>();
23     auto *dst_ptr = dst->buffer().as<data_t*>();
24
25     SizeVector dims = src->getTensorDesc().getDims();
26
27     size_t N = dims[0];
28     size_t C = dims[1];
29     size_t H = dims[2];
30     size_t W = dims[3];
31
32     const Layout src_l = src->layout();
33     const auto &src_blk_dsc = src->getTensorDesc().getBlockingDesc();
34     const auto &src_strides = src_blk_dsc.getStrides();
35     const auto N_src_stride = src_strides[0];
36     const auto C_src_stride = src_l == NHWC ? src_strides[3] : src_strides[1];
37     const auto H_src_stride = src_l == NHWC ? src_strides[1] : src_strides[2];
38     const auto W_src_stride = src_l == NHWC ? src_strides[2] : src_strides[3];
39     src_ptr += src_blk_dsc.getOffsetPadding();
40
41     const Layout dst_l = dst->layout();
42     const auto &dst_blk_desc = dst->getTensorDesc().getBlockingDesc();
43     const auto &dst_strides = dst_blk_desc.getStrides();
44     const auto N_dst_stride = dst_strides[0];
45     const auto C_dst_stride = dst_l == NHWC ? dst_strides[3] : dst_strides[1];
46     const auto H_dst_stride = dst_l == NHWC ? dst_strides[1] : dst_strides[2];
47     const auto W_dst_stride = dst_l == NHWC ? dst_strides[2] : dst_strides[3];
48
49     src_ptr += dst_blk_desc.getOffsetPadding();
50
51 #ifdef HAVE_SSE
52     if (src->layout() == NHWC && dst->layout() == NCHW && C == 3
53         && C_src_stride == 1 && W_src_stride == 3 && W_dst_stride == 1 &&
54         with_cpu_x86_sse42()) {
55         if (PRC == Precision::U8) {
56             blob_copy_4d_split_u8c3(reinterpret_cast<const uint8_t*>(src_ptr),
57                                     reinterpret_cast<      uint8_t*>(dst_ptr),
58                                     N_src_stride, H_src_stride,
59                                     N_dst_stride, H_dst_stride, C_dst_stride,
60                                     static_cast<int>(N), static_cast<int>(H),
61                                     static_cast<int>(W));
62             return;
63         }
64
65         if (PRC == Precision::FP32) {
66             blob_copy_4d_split_f32c3(reinterpret_cast<const float*>(src_ptr),
67                                      reinterpret_cast<      float*>(dst_ptr),
68                                      N_src_stride, H_src_stride,
69                                      N_dst_stride, H_dst_stride, C_dst_stride,
70                                      static_cast<int>(N), static_cast<int>(H),
71                                      static_cast<int>(W));
72             return;
73         }
74     }
75
76     if (src->layout() == NCHW && dst->layout() == NHWC && C == 3 &&
77         C_dst_stride == 1 && W_dst_stride == 3 && W_src_stride == 1 &&
78         with_cpu_x86_sse42()) {
79         if (PRC == Precision::U8) {
80             blob_copy_4d_merge_u8c3(reinterpret_cast<const uint8_t*>(src_ptr),
81                                     reinterpret_cast<      uint8_t*>(dst_ptr),
82                                     N_src_stride, H_src_stride, C_src_stride,
83                                     N_dst_stride, H_dst_stride,
84                                     static_cast<int>(N), static_cast<int>(H),
85                                     static_cast<int>(W));
86             return;
87         }
88
89         if (PRC == Precision::FP32) {
90             blob_copy_4d_merge_f32c3(reinterpret_cast<const float*>(src_ptr),
91                                      reinterpret_cast<      float*>(dst_ptr),
92                                      N_src_stride, H_src_stride, C_src_stride,
93                                      N_dst_stride, H_dst_stride,
94                                      static_cast<int>(N), static_cast<int>(H),
95                                      static_cast<int>(W));
96             return;
97         }
98     }
99 #endif  // HAVE_SSE
100
101     if (src->layout() == NHWC && dst->layout() == NCHW) {
102         for (int n = 0; n < N; n++) {
103             for (int c = 0; c < C; c++) {
104                 data_t *dst_ptr_l = dst_ptr + n * N_dst_stride + c * C_dst_stride;
105                 data_t *src_ptr_l = src_ptr + n * N_src_stride + c * C_src_stride;
106                 for (int h = 0; h < H; h++) {
107                     data_t *src_ptr_l_l = src_ptr_l + h*H_src_stride;
108                     for (int w = 0; w < W; w++) {
109                         *dst_ptr_l = *src_ptr_l_l;
110                         src_ptr_l_l += W_src_stride;
111                         dst_ptr_l++;
112                     }
113                 }
114             }
115         }
116     } else if (src->layout() == NCHW && dst->layout() == NHWC) {
117         for (int n = 0; n < N; n++) {
118             for (int c = 0; c < C; c++) {
119                 data_t *src_ptr_l = src_ptr + n * N_src_stride + c * C_src_stride;
120                 data_t *dst_ptr_l = dst_ptr + n * N_dst_stride + c;
121                 for (int h = 0; h < H; h++) {
122                     data_t *src_ptr_l_l = src_ptr_l + h*H_src_stride;
123                     for (int w = 0; w < W; w++) {
124                         *dst_ptr_l = *src_ptr_l_l;
125                         dst_ptr_l += W_dst_stride;
126                         src_ptr_l_l++;
127                     }
128                 }
129             }
130         }
131     } else {
132         for (int i = 0; i < N*C*H*W; i++) {
133             dst_ptr[i] = src_ptr[i];
134         }
135     }
136 }
137
138 static inline void blob_copy_4d(Blob::Ptr src, Blob::Ptr dst) {
139     switch (src->precision()) {
140         case Precision::FP32:
141         case Precision::I32:
142             blob_copy_4d_t<Precision::FP32>(src, dst);
143             break;
144
145         case Precision::FP16:
146         case Precision::U16:
147         case Precision::I16:
148             blob_copy_4d_t<Precision::U16>(src, dst);
149             break;
150
151         case Precision::U8:
152         case Precision::I8:
153             blob_copy_4d_t<Precision::U8>(src, dst);
154             break;
155
156         default:
157             THROW_IE_EXCEPTION << "Unsupported blob transformation for precision " << src->precision();
158     }
159 }
160
161 void blob_copy(Blob::Ptr src, Blob::Ptr dst) {
162     if (src->buffer() == nullptr)
163         THROW_IE_EXCEPTION << "Cannot copy blob data. Source is not allocated.";
164
165     if (dst->buffer() == nullptr)
166         THROW_IE_EXCEPTION << "Cannot copy blob data. Destination is not allocated.";
167
168     if (src->precision() != dst->precision())
169         THROW_IE_EXCEPTION << "Unimplemented blob transformation from precision "
170                            << src->precision() << " to " << src->precision();
171
172     if (src->dims() != dst->dims())
173         THROW_IE_EXCEPTION << "Unimplemented blob transformation from different shapes ";
174
175     if (src->dims().size() == 4)
176         blob_copy_4d(src, dst);
177     else
178         THROW_IE_EXCEPTION << "Unimplemented blob transformation. Only 4d supported.";
179 }
180
181 }  // namespace InferenceEngine