arm_compute v18.02
[platform/upstream/armcl.git] / arm_compute / core / NEON / kernels / assembly / transforms / a64_block16_interleave4_8bit.hpp
1 /*
2  * Copyright (c) 2017-2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #pragma once
25
26 #ifdef __aarch64__
27
28 #include <arm_neon.h>
29 #include "asmlib.hpp"
30
31 template<>
32 template<typename T>
33 inline void TransformImpl<4, 16, false, 1, 1>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
34     uint8_t *outptr = (uint8_t *)out;
35     const uint8_t *inptr = (uint8_t *)in;
36
37     uint8_t zerobuff[16];
38
39     for (int y=y0; y<ymax; y+=4) {
40         const uint8_t *inptr0 = inptr + y * ldin + k0;
41         const uint8_t *inptr1 = inptr0 + ldin;
42         const uint8_t *inptr2 = inptr1 + ldin;
43         const uint8_t *inptr3 = inptr2 + ldin;
44
45         prefetch_2x(inptr0);
46         prefetch_2x(inptr1);
47         prefetch_2x(inptr2);
48         prefetch_2x(inptr3);
49
50         int x=(kmax-k0);
51         for (;x>15;x-=16) {
52             /* Cope with ragged cases by copying from a buffer of zeroes instead */
53             if ((y + 3) >= ymax) {
54                 switch ((y + 3) - ymax) {
55                     /* Everything falls through in here */
56                     case 2:
57                         inptr1 = zerobuff;
58                     case 1:
59                         inptr2 = zerobuff;
60                     case 0:
61                         inptr3 = zerobuff;
62                     default:
63                         break;
64                 }
65             }
66
67             __asm __volatile (
68                 "LDR    q0, [%[inptr0]], #16\n"
69                 ASM_PREFETCH("[%[inptr0], #176]")
70                 "LDR    q1, [%[inptr1]], #16\n"
71                 ASM_PREFETCH("[%[inptr1], #176]")
72                 "STP    q0, q1, [%[outptr]], #32\n"
73                 "LDR    q0, [%[inptr2]], #16\n"
74                 ASM_PREFETCH("[%[inptr2], #176]")
75                 "LDR    q1, [%[inptr3]], #16\n"
76                 ASM_PREFETCH("[%[inptr3], #176]")
77                 "STP    q0, q1, [%[outptr]], #32\n"
78                 : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3),
79                   [outptr] "+r" (outptr)
80                 :
81                 : "v0", "v1"
82             );
83         }
84
85         if (x>0) {
86             /* Need to duplicate this here, in case we didn't run the main loop. */
87             if ((y + 3) >= ymax) {
88                 switch ((y + 3) - ymax) {
89                     /* Everything falls through in here */
90                     case 2:
91                         inptr1 = zerobuff;
92                     case 1:
93                         inptr2 = zerobuff;
94                     case 0:
95                         inptr3 = zerobuff;
96                     default:
97                         break;
98                 }
99             }
100
101             /* We have to write out 16 values, copy as many legal values as there are and pad with 0 */
102             auto f = [&outptr, x](const uint8_t *&p) {
103                 for (int i=0; i<16; i++) {
104                     if (i < x) {
105                         *outptr++ = *p++;
106                     } else {
107                         *outptr++ = 0;
108                     }
109                 }
110             };
111
112             f(inptr0);
113             f(inptr1);
114             f(inptr2);
115             f(inptr3);
116         }
117     }
118 }
119
120 #endif  // __aarch64__