Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / locomotiv / src / NodeExecution.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "NodeExecution.h"
18
19 #include "NodeDomain.h"
20 #include "NodeDataImpl.h"
21 #include "Validation.h"
22
23 #include <nncc/core/ADT/tensor/Shape.h>
24 #include <nncc/core/ADT/tensor/Buffer.h>
25 #include <nncc/core/ADT/tensor/Index.h>
26 #include <nncc/core/ADT/tensor/IndexEnumerator.h>
27 #include <nncc/core/ADT/tensor/LexicalLayout.h>
28
29 #include <cassert>
30 #include <stdexcept>
31
32 using nncc::core::ADT::tensor::Index;
33 using nncc::core::ADT::tensor::IndexEnumerator;
34 using nncc::core::ADT::tensor::LexicalLayout;
35 using nncc::core::ADT::tensor::make_buffer;
36
37 namespace locomotiv
38 {
39
40 float UnaryFunc::apply(float) const { throw std::runtime_error{"F32 is not supported yet"}; }
41 int32_t UnaryFunc::apply(int32_t) const { throw std::runtime_error{"S32 is not supported yet"}; }
42
43 float BinaryFunc::apply(float, float) const
44 {
45   throw std::runtime_error{"F32 is not supported yet"};
46 }
47
48 int32_t BinaryFunc::apply(int32_t, int32_t) const
49 {
50   throw std::runtime_error{"S32 is not supported yet"};
51 }
52
53 // TODO Use visitor pattern of loco when available
54 void NodeExecution::run(loco::Node *node)
55 {
56   erase_annot_data(node);
57
58 #define NODE(Name)                 \
59   if (as<loco::Name>(node))        \
60   {                                \
61     execute(as<loco::Name>(node)); \
62     return;                        \
63   }
64 #include "Node.lst"
65 #undef NODE
66
67   throw std::runtime_error("Not supported loco::Node type");
68 }
69
70 void NodeExecution::eltwise_unary(loco::Node *node, const UnaryFunc &f)
71 {
72   auto input_node = node->arg(0);
73   auto input_domain = annot_domain(input_node);
74   auto input_data = annot_data(input_node);
75   validate(input_data, "Input is not ready");
76   auto input_dtype = input_data->dtype();
77
78   validate(input_domain != loco::Domain::Unknown, "Input domain is unknown");
79
80   auto output_node = node;
81   // Element-wise Unary Operation does not affect Domain
82   auto output_domain = input_domain;
83   // Eltwise-wise Unary Operation does not affet Data Type (ASSUMPTION)
84   //
85   // TODO Check this assumption
86   auto output_dtype = input_dtype;
87   std::unique_ptr<NodeData> output_data = nullptr;
88
89   switch (output_dtype)
90   {
91     case loco::DataType::FLOAT32:
92     {
93       auto input_bufptr = input_data->as_f32_bufptr();
94       auto output_buf = make_buffer<float, LexicalLayout>(*input_data->shape());
95       auto *shape = input_data->shape();
96
97       for (IndexEnumerator e{*shape}; e.valid(); e.advance())
98       {
99         const auto &index = e.current();
100         output_buf.at(index) = f.apply(input_bufptr->at(index));
101       }
102
103       output_data = make_data(output_buf);
104       break;
105     }
106     default:
107       throw std::runtime_error("NYI for this DataType");
108   }
109
110   assert(output_data != nullptr);
111   annot_data(output_node, std::move(output_data));
112   annot_domain(output_node, output_domain);
113 }
114
115 void NodeExecution::eltwise_binary(loco::Node *node, const BinaryFunc &f)
116 {
117   auto lhs_node = node->arg(0);
118   auto rhs_node = node->arg(1);
119   auto lhs_data = annot_data(lhs_node);
120   auto rhs_data = annot_data(rhs_node);
121
122   validate(lhs_data && rhs_data, "Input not ready");
123   validate(annot_domain(lhs_node) == annot_domain(rhs_node), "Wrong input domain");
124   validate(lhs_data->dtype() == rhs_data->dtype(), "Wrong input type");
125   validate(*lhs_data->shape() == *rhs_data->shape(), "Wrong input shape");
126
127   auto out_node = node;
128   std::unique_ptr<NodeData> out_data = nullptr;
129
130   switch (lhs_data->dtype())
131   {
132     case loco::DataType::FLOAT32:
133     {
134       auto lhs_bufptr = lhs_data->as_f32_bufptr();
135       auto rhs_bufptr = rhs_data->as_f32_bufptr();
136       auto out_bufptr = make_buffer<float, LexicalLayout>(*lhs_data->shape());
137
138       auto *shape = lhs_data->shape();
139
140       for (IndexEnumerator e{*shape}; e.valid(); e.advance())
141       {
142         const auto &index = e.current();
143         out_bufptr.at(index) = f.apply(lhs_bufptr->at(index), rhs_bufptr->at(index));
144       }
145
146       out_data = make_data(out_bufptr);
147       break;
148     }
149     default:
150       throw std::runtime_error("NYI for this DataType");
151   }
152
153   assert(out_data != nullptr);
154   annot_data(out_node, std::move(out_data));
155   annot_domain(out_node, annot_domain(lhs_node));
156 }
157
158 } // namespace locomotiv