cd81a3a4dfbc188e87ac0d0b188a51321931ce15
[platform/core/ml/nnfw.git] / compiler / locomotiv / src / Node / TensorConstantPad.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "NodeExecution.h"
18
19 #include "NodeDataImpl.h"
20 #include "NodeDomain.h"
21 #include "Validation.h"
22
23 #include <nncc/core/ADT/tensor/IndexEnumerator.h>
24 #include <nncc/core/ADT/tensor/LexicalLayout.h>
25
26 #include <cassert>
27
28 using nncc::core::ADT::tensor::Shape;
29 using nncc::core::ADT::tensor::Index;
30 using nncc::core::ADT::tensor::IndexEnumerator;
31 using nncc::core::ADT::tensor::LexicalLayout;
32 using nncc::core::ADT::tensor::make_buffer;
33
34 namespace locomotiv
35 {
36
37 void NodeExecution::execute(loco::TensorConstantPad *pad)
38 {
39   validate(pad, "TensorConstantPad is nullptr");
40
41   auto input_data = annot_data(pad->input());
42   auto input_domain = annot_domain(pad->input());
43   validate(input_data, "Input not ready");
44   validate(input_domain == loco::Domain::Tensor, "Input domain of TensorConstantPad is not Tensor");
45
46   auto input_shape = input_data->shape();
47   const uint32_t input_rank = input_shape->rank();
48
49   auto padding = pad->padding();
50   validate(input_rank == padding->rank(), "input and padding should have same rank");
51
52   auto constant_node = pad->constant();
53   auto constant_data = annot_data(constant_node);
54   validate(constant_data->dtype() == input_data->dtype(), "constant and input have same data type");
55   validate(constant_data->shape()->rank() == 1 && constant_data->shape()->dim(0) == 1,
56            "constant should have one rank with one dimension at zero axis");
57
58   std::unique_ptr<NodeData> pad_data = nullptr;
59   Index base_index;
60   base_index.resize(input_rank);
61
62   // Tensor is padded by relocating its base.
63   // padded output index = input index + base index
64   for (uint32_t axis = 0; axis < padding->rank(); axis++)
65   {
66     base_index.at(axis) = padding->front(axis);
67   }
68
69   // calculate output shape
70   Shape output_shape;
71   output_shape.resize(input_rank);
72   for (uint32_t i = 0; i < input_rank; i++)
73   {
74     output_shape.dim(i) = input_shape->dim(i) + padding->front(i) + padding->back(i);
75   }
76
77   switch (input_data->dtype())
78   {
79     case loco::DataType::FLOAT32:
80     {
81       auto input_buf = input_data->as_f32_bufptr();
82       auto constant_data_buf = constant_data->as_f32_bufptr();
83       const auto constant_value = constant_data_buf->at(Index{0});
84
85       auto output_buf = make_buffer<float, LexicalLayout>(output_shape);
86
87       for (IndexEnumerator ie{*input_shape}, oe{output_shape}; oe.valid(); oe.advance())
88       {
89         auto input_index = ie.current();
90         auto output_index = oe.current();
91
92         if ((input_index + base_index) == output_index)
93         {
94           output_buf.at(output_index) = input_buf->at(input_index);
95           ie.advance();
96         }
97         else
98         {
99           output_buf.at(output_index) = constant_value;
100         }
101       }
102
103       pad_data = make_data(output_buf);
104       break;
105     }
106     default:
107       throw std::runtime_error("NYI for this DataType");
108   }
109
110   assert(pad_data != nullptr);
111   annot_data(pad, std::move(pad_data));
112   annot_domain(pad, annot_domain(pad->input()));
113 }
114
115 } // namespace locomotiv