2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "NodeExecution.h"
19 #include "NodeDomain.h"
20 #include "NodeDataImpl.h"
21 #include "Validation.h"
23 #include <nncc/core/ADT/tensor/Shape.h>
24 #include <nncc/core/ADT/tensor/Buffer.h>
25 #include <nncc/core/ADT/tensor/Index.h>
26 #include <nncc/core/ADT/tensor/IndexEnumerator.h>
27 #include <nncc/core/ADT/tensor/LexicalLayout.h>
32 using nncc::core::ADT::tensor::Index;
33 using nncc::core::ADT::tensor::IndexEnumerator;
34 using nncc::core::ADT::tensor::LexicalLayout;
35 using nncc::core::ADT::tensor::make_buffer;
40 float UnaryFunc::apply(float) const { throw std::runtime_error{"F32 is not supported yet"}; }
41 int32_t UnaryFunc::apply(int32_t) const { throw std::runtime_error{"S32 is not supported yet"}; }
43 float BinaryFunc::apply(float, float) const
45 throw std::runtime_error{"F32 is not supported yet"};
48 int32_t BinaryFunc::apply(int32_t, int32_t) const
50 throw std::runtime_error{"S32 is not supported yet"};
// NodeExecution::run -- evaluate a single loco node.
//
// What the visible code does:
//  1. Calls erase_annot_data(node) first; presumably this drops any data
//     annotation left on `node` by an earlier evaluation (confirm in NodeDataImpl.h).
//  2. Runs a preprocessor-generated chain of `if (as<loco::Name>(node))` checks
//     (note the `Name` placeholder and the trailing `\` continuations). Each match
//     forwards to the execute() overload for that concrete node type.
//     as<T>() is presumably a checked downcast that yields null on mismatch.
//  3. Throws std::runtime_error("Not supported loco::Node type") as the fallback.
//
// NOTE(review): the macro definition and the list of node types it is expanded
// over are not visible here. Each expansion presumably returns right after
// execute(); otherwise the final throw would fire for every node — confirm.
53 // TODO Use visitor pattern of loco when available
54 void NodeExecution::run(loco::Node *node)
// Clear any data annotation currently attached to this node before (re)executing it.
56 erase_annot_data(node);
// Per-node-type dispatch; this is a macro body expanded once per supported type.
59 if (as<loco::Name>(node)) \
61 execute(as<loco::Name>(node)); \
67 throw std::runtime_error("Not supported loco::Node type");
70 void NodeExecution::eltwise_unary(loco::Node *node, const UnaryFunc &f)
72 auto input_node = node->arg(0);
73 auto input_domain = annot_domain(input_node);
74 auto input_data = annot_data(input_node);
75 validate(input_data, "Input is not ready");
76 auto input_dtype = input_data->dtype();
78 validate(input_domain != loco::Domain::Unknown, "Input domain is unknown");
80 auto output_node = node;
81 // Element-wise Unary Operation does not affect Domain
82 auto output_domain = input_domain;
83 // Eltwise-wise Unary Operation does not affet Data Type (ASSUMPTION)
85 // TODO Check this assumption
86 auto output_dtype = input_dtype;
87 std::unique_ptr<NodeData> output_data = nullptr;
91 case loco::DataType::FLOAT32:
93 auto input_bufptr = input_data->as_f32_bufptr();
94 auto output_buf = make_buffer<float, LexicalLayout>(*input_data->shape());
95 auto *shape = input_data->shape();
97 for (IndexEnumerator e{*shape}; e.valid(); e.advance())
99 const auto &index = e.current();
100 output_buf.at(index) = f.apply(input_bufptr->at(index));
103 output_data = make_data(output_buf);
107 throw std::runtime_error("NYI for this DataType");
110 assert(output_data != nullptr);
111 annot_data(output_node, std::move(output_data));
112 annot_domain(output_node, output_domain);
115 void NodeExecution::eltwise_binary(loco::Node *node, const BinaryFunc &f)
117 auto lhs_node = node->arg(0);
118 auto rhs_node = node->arg(1);
119 auto lhs_data = annot_data(lhs_node);
120 auto rhs_data = annot_data(rhs_node);
122 validate(lhs_data && rhs_data, "Input not ready");
123 validate(annot_domain(lhs_node) == annot_domain(rhs_node), "Wrong input domain");
124 validate(lhs_data->dtype() == rhs_data->dtype(), "Wrong input type");
125 validate(*lhs_data->shape() == *rhs_data->shape(), "Wrong input shape");
127 auto out_node = node;
128 std::unique_ptr<NodeData> out_data = nullptr;
130 switch (lhs_data->dtype())
132 case loco::DataType::FLOAT32:
134 auto lhs_bufptr = lhs_data->as_f32_bufptr();
135 auto rhs_bufptr = rhs_data->as_f32_bufptr();
136 auto out_bufptr = make_buffer<float, LexicalLayout>(*lhs_data->shape());
138 auto *shape = lhs_data->shape();
140 for (IndexEnumerator e{*shape}; e.valid(); e.advance())
142 const auto &index = e.current();
143 out_bufptr.at(index) = f.apply(lhs_bufptr->at(index), rhs_bufptr->at(index));
146 out_data = make_data(out_bufptr);
150 throw std::runtime_error("NYI for this DataType");
153 assert(out_data != nullptr);
154 annot_data(out_node, std::move(out_data));
155 annot_domain(out_node, annot_domain(lhs_node));
158 } // namespace locomotiv