dfe32ca92ba173b82d0058988cd4de03bfca6d20
[platform/core/ml/nnfw.git] / compiler / locomotiv / src / Node / BiasAdd.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "NodeExecution.h"
18
19 #include "NodeDataImpl.h"
20 #include "NodeDomain.h"
21 #include "Validation.h"
22
23 #include <nncc/core/ADT/tensor/Shape.h>
24 #include <nncc/core/ADT/tensor/Buffer.h>
25 #include <nncc/core/ADT/tensor/IndexEnumerator.h>
26 #include <nncc/core/ADT/tensor/LexicalLayout.h>
27
28 using nncc::core::ADT::tensor::IndexEnumerator;
29 using nncc::core::ADT::tensor::LexicalLayout;
30 using nncc::core::ADT::tensor::make_buffer;
31
32 #include <cassert>
33 #include <stdexcept>
34
35 namespace
36 {
37 using locomotiv::NodeData;
38
39 std::unique_ptr<NodeData> calc(const NodeData *input_data, const NodeData *bias_data,
40                                uint32_t axis);
41
42 } // namespace
43
44 namespace locomotiv
45 {
46
47 void NodeExecution::execute(loco::BiasAdd<loco::Domain::Tensor> *bias_add)
48 {
49   validate(bias_add, "BiasAdd is nullptr");
50
51   auto input_data = locomotiv::annot_data(bias_add->value());
52   auto bias_data = locomotiv::annot_data(bias_add->bias());
53
54   validate(input_data && bias_data, "Input not ready");
55   validate(locomotiv::annot_domain(bias_add->value()) == loco::Domain::Tensor &&
56                locomotiv::annot_domain(bias_add->bias()) == loco::Domain::Bias,
57            "Wrong input domain");
58
59   std::unique_ptr<NodeData> bias_add_data = calc(input_data, bias_data, bias_add->axis());
60
61   assert(bias_add_data != nullptr);
62   annot_data(bias_add, std::move(bias_add_data));
63   annot_domain(bias_add, annot_domain(bias_add->value()));
64 }
65
66 void NodeExecution::execute(loco::BiasAdd<loco::Domain::Feature> *bias_add)
67 {
68   validate(bias_add, "BiasAdd is nullptr");
69
70   auto input_data = locomotiv::annot_data(bias_add->value());
71   auto bias_data = locomotiv::annot_data(bias_add->bias());
72
73   validate(input_data && bias_data, "Input not ready");
74   validate(locomotiv::annot_domain(bias_add->value()) == loco::Domain::Feature &&
75                locomotiv::annot_domain(bias_add->bias()) == loco::Domain::Bias,
76            "Wrong input domain");
77
78   std::unique_ptr<NodeData> bias_add_data = calc(input_data, bias_data, 3);
79
80   assert(bias_add_data != nullptr);
81   annot_data(bias_add, std::move(bias_add_data));
82   annot_domain(bias_add, loco::Domain::Feature);
83 }
84
85 } // namespace locomotiv
86
87 namespace
88 {
89 using locomotiv::NodeData;
90 using locomotiv::validate;
91 using locomotiv::make_data;
92
93 std::unique_ptr<NodeData> calc(const NodeData *input_data, const NodeData *bias_data, uint32_t axis)
94 {
95   validate(input_data->shape()->dim(axis) == bias_data->shape()->dim(0), "Bias size mismatch");
96
97   std::unique_ptr<NodeData> bias_add_data = nullptr;
98
99   switch (input_data->dtype())
100   {
101     case loco::DataType::FLOAT32:
102     {
103       auto input_bufptr = input_data->as_f32_bufptr();
104       auto bias_bufptr = bias_data->as_f32_bufptr();
105       auto bias_add_buf = make_buffer<float, LexicalLayout>(*input_data->shape());
106
107       auto *shape = input_data->shape();
108
109       for (IndexEnumerator e{*shape}; e.valid(); e.advance())
110       {
111         const auto &index = e.current();
112         nncc::core::ADT::tensor::Index bias_index({index.at(axis)});
113         bias_add_buf.at(index) = input_bufptr->at(index) + bias_bufptr->at(bias_index);
114       }
115
116       bias_add_data = make_data(bias_add_buf);
117       break;
118     }
119     default:
120       throw std::runtime_error("NYI for this DataType");
121   }
122   return bias_add_data;
123 }
124
125 } // namespace