Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / SpaceToBatchND.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/SpaceToBatchND.h"
19 #include "kernels/Utils.h"
20
21 #include "PALSpaceToBatchND.h"
22
23 #include <stdexcept>
24
25 namespace luci_interpreter
26 {
27 namespace kernels
28 {
29 namespace
30 {
31
32 const int kInputMinDimensionNum = 3;
33 const int kInputMaxDimensionNum = 4;
34
35 } // namespace
36
37 SpaceToBatchND::SpaceToBatchND(const Tensor *input, const Tensor *block_shape,
38                                const Tensor *paddings, Tensor *output)
39   : Kernel({input, block_shape, paddings}, {output})
40 {
41 }
42
43 void SpaceToBatchND::configure()
44 {
45   const auto *block_shape_data = block_shape()->data<int32_t>();
46   const auto *paddings_data = paddings()->data<int32_t>();
47   LUCI_INTERPRETER_CHECK(input()->shape().num_dims() >= kInputMinDimensionNum);
48   LUCI_INTERPRETER_CHECK(input()->shape().num_dims() <= kInputMaxDimensionNum);
49   LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
50
51   int spatial_dims_num = input()->shape().num_dims() - 2;
52
53   LUCI_INTERPRETER_CHECK(block_shape()->shape().num_dims() == 1);
54   LUCI_INTERPRETER_CHECK(block_shape()->shape().dim(0) == spatial_dims_num);
55
56   LUCI_INTERPRETER_CHECK(paddings()->shape().num_dims() == 2);
57   LUCI_INTERPRETER_CHECK(paddings()->shape().dim(0) == spatial_dims_num);
58   LUCI_INTERPRETER_CHECK(paddings()->shape().dim(1) == 2);
59
60   Shape output_shape = Shape(input()->shape().num_dims());
61   int output_batch_size = input()->shape().dim(0);
62   for (int i = 0; i < spatial_dims_num; ++i)
63   {
64     int final_dim_size =
65       (input()->shape().dim(i + 1) + paddings_data[i * 2] + paddings_data[i * 2 + 1]);
66     LUCI_INTERPRETER_CHECK(final_dim_size % block_shape_data[i] == 0);
67     output_shape.dim(i + 1) = final_dim_size / block_shape_data[i];
68     output_batch_size = output_batch_size * block_shape_data[i];
69   }
70   output_shape.dim(0) = output_batch_size;
71   output_shape.dim(input()->shape().num_dims() - 1) =
72     input()->shape().dim(input()->shape().num_dims() - 1);
73   output()->resize(output_shape);
74 }
75
76 void SpaceToBatchND::execute() const
77 {
78   switch (input()->element_type())
79   {
80     tflite::SpaceToBatchParams op_params;
81     case DataType::FLOAT32:
82       op_params.output_offset = 0;
83       luci_interpreter_pal::SpaceToBatchND(
84         op_params, getTensorShape(input()), getTensorData<float>(input()),
85         getTensorShape(block_shape()), getTensorData<int32_t>(block_shape()),
86         getTensorShape(paddings()), getTensorData<int32_t>(paddings()), getTensorShape(output()),
87         getTensorData<float>(output()));
88       break;
89     case DataType::U8:
90       op_params.output_offset = output()->zero_point();
91       luci_interpreter_pal::SpaceToBatchND(
92         op_params, getTensorShape(input()), getTensorData<uint8_t>(input()),
93         getTensorShape(block_shape()), getTensorData<int32_t>(block_shape()),
94         getTensorShape(paddings()), getTensorData<int32_t>(paddings()), getTensorShape(output()),
95         getTensorData<uint8_t>(output()));
96       break;
97     default:
98       throw std::runtime_error("Unsupported type.");
99   }
100 }
101
102 } // namespace kernels
103 } // namespace luci_interpreter