arm_compute v18.05
[platform/upstream/armcl.git] / src / core / NEON / kernels / arm_gemm / mergeresults.hpp
1 /*
2  * Copyright (c) 2017-2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #pragma once
25
26 /* As some of the merges need these headers, but are all included in the
27  * arm_gemm namespace, put these headers here.  */
28 #include <arm_neon.h>
29
30 #include "asmlib.hpp"
31 #include "utils.hpp"
32
33 namespace arm_gemm
34 {
35 template <unsigned int width, unsigned int height, typename Tin, typename Tout>
36 inline void MergeResults(Tout *out, const Tin *in, int ldc, int y0, int ymax, int x0, int xmax, const Tout alpha, const Tout beta)
37 {
38     int full_y_blocks = (ymax - y0) / height;
39     int y_remainder   = (ymax - y0) % height;
40     int y_blocks      = full_y_blocks + (y_remainder ? 1 : 0);
41
42     int full_x_blocks = (xmax - x0) / width;
43     int x_remainder   = (xmax - x0) % width;
44     int x_blocks      = full_x_blocks + (x_remainder ? 1 : 0);
45
46     for(int y_block = 0; y_block < y_blocks; y_block++)
47     {
48         int ybase = y0 + (y_block * height);
49
50         int fill_rows = (y_block < full_y_blocks) ? height : y_remainder;
51
52         for(int x_block = 0; x_block < x_blocks; x_block++)
53         {
54             int xbase = x0 + (x_block * width);
55
56             int fill_cols = (x_block < full_x_blocks) ? width : x_remainder;
57
58             for(int row = 0; row < fill_rows; row++)
59             {
60                 for(int col = 0; col < fill_cols; col++)
61                 {
62                     Tout &p = out[(ybase + row) * ldc + xbase + col];
63
64                     p = (p * beta) + (alpha * in[row * width + col]);
65                 }
66             }
67
68             in += (width * height);
69         }
70     }
71 }
72
73 #include "merges/list.hpp"
74
75 } // namespace arm_gemm