5674ff3ef240fdcdcece3bb4822c3025a89c05d6
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / MatrixBandPart.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_MATRIX_BAND_PART_H__
19 #define __NNFW_CKER_MATRIX_BAND_PART_H__
20
21 #include "cker/Shape.h"
22
23 #include <algorithm>
24
25 namespace nnfw
26 {
27 namespace cker
28 {
29 template <typename T>
30 void MatrixBandPart(const T num_lower_diags, const T num_upper_diags, const Shape &input_shape,
31                     const float *input_data, const Shape &output_shape, float *output_data)
32 {
33   auto last_dim = input_shape.DimensionsCount() - 1;
34
35   T batch_num = 1;
36   for (int dim = 0; dim < input_shape.DimensionsCount() - 2; dim++)
37   {
38     batch_num *= input_shape.Dims(dim);
39   }
40
41   const T row_num = input_shape.Dims(last_dim - 1);
42   const T col_num = input_shape.Dims(last_dim);
43
44   if (!(num_lower_diags <= row_num))
45     throw std::runtime_error(
46         "MatrixBandPart : num_lower must be negative or less or equal to number of rows");
47
48   if (!(num_upper_diags <= col_num))
49     throw std::runtime_error(
50         "MatrixBandPart : num_upper must be negative or less or equal to number of columns");
51
52   std::fill(output_data, output_data + output_shape.FlatSize(), 0); // output matrix init
53
54   // reference code, without multithreading
55   for (T batch = 0; batch < batch_num; ++batch)
56   {
57     for (T row = 0; row < row_num; ++row)
58     {
59       auto output = output_data + (batch * row_num * col_num + row * col_num);
60       auto input = input_data + (batch * row_num * col_num + row * col_num);
61
62       const T band_start =
63           num_lower_diags < 0 ? 0 : std::min(col_num, std::max(T{0}, row - num_lower_diags));
64       const T band_end = num_upper_diags < 0 ? col_num : std::min(static_cast<T>(col_num),
65                                                                   row + num_upper_diags + 1);
66
67       for (T band_idx = band_start; band_idx < band_end; band_idx++)
68       {
69         output[band_idx] = input[band_idx];
70       }
71     }
72   }
73 }
74 } // namespace cker
75 } // namespace nnfw
76
77 #endif // __NNFW_CKER_MATRIX_BAND_PART_H__