Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / luci / import / src / Nodes / CircleIf.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Import/Nodes/CircleIf.h"
18
19 #include <luci/IR/Nodes/CircleIf.h>
20 #include <luci/IR/Nodes/CircleIfOut.h>
21
22 #include <loco.h>
23 #include <oops/UserExn.h>
24
25 namespace luci
26 {
27
28 bool CircleIfGraphBuilder::validate(const ValidateArgs &args) const
29 {
30   const auto &inputs = args.op.inputs;
31   const auto *options = args.op.builtin_options.AsIfOptions();
32
33   if (inputs.size() < 2) // cond + input
34     return false;
35   if (args.op.outputs.size() < 1) // output
36     return false;
37
38   auto num_graphs = static_cast<int32_t>(args.reader.num_subgraph());
39   if (options->then_subgraph_index >= num_graphs)
40     return false;
41   if (options->else_subgraph_index >= num_graphs)
42     return false;
43
44   // input 0 should be BOOL type
45   const auto &tensors = args.reader.tensors();
46   const auto &tensor = tensors.at(inputs.at(0));
47   if (tensor->type != circle::TensorType_BOOL)
48     return false;
49
50   const auto &shape = tensor->shape;
51   if (shape.size() != 1 && shape.size() != 0)
52     return false;
53
54   return true;
55 }
56
57 /**
58  * @brief  If Node builder
59  *
60  * @note   Current loco does not provide multiple outputs
61  *         We will create multiple CircleIfOut nodes to emulate this
62  *         For two outputs that may look like this
63  *
64  *         --- CircleIf --- Node ---
65  *                       \- Node ---
66  *
67  *         will be created like this
68  *
69  *         --- CircleIf --- CircleIfOut --- Node ---
70  *                       \- CircleIfOut --- Node ---
71  */
72
73 void CircleIfGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const
74 {
75   assert(context != nullptr);
76
77   auto graph = context->graph();
78
79   const std::vector<int32_t> &inputs = op.inputs;
80   const std::vector<int32_t> &outputs = op.outputs;
81   const auto &tensors = context->reader()->tensors();
82   const auto &opcodes = context->reader()->opcodes();
83   auto tensors_ptr = context->reader()->tensors_ptr();
84   assert(tensors_ptr != nullptr);
85
86   std::vector<CircleNode *> input_nodes;
87   for (const int32_t input_tensor_index : inputs)
88   {
89     input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
90   }
91
92   uint32_t input_count = inputs.size() - 1;
93   uint32_t output_count = outputs.size();
94
95   // Create CircleIf
96   CircleIf *node = graph->nodes()->create<CircleIf>(input_count, output_count);
97
98   node->cond(input_nodes[0]);
99   for (uint32_t idx = 0; idx < input_count; ++idx)
100   {
101     node->input(idx, input_nodes[idx + 1]);
102   }
103
104   const auto *options = op.builtin_options.AsIfOptions();
105   node->then_branch(options->then_subgraph_index);
106   node->else_branch(options->else_subgraph_index);
107
108   assert(outputs.size() > 0);
109   {
110     // Lets use name of output 0 as If name
111     const circle::TensorT &output_tensor = *tensors[outputs[0]];
112     node->name(tensor_name(output_tensor));
113     node->op_version(opcodes[op.opcode_index].get()->version);
114
115     // NOTE We don't set quantization for If itself but to virtual outputs
116   }
117
118   // Create virtual outputs of If
119   for (uint32_t n = 0; n < output_count; ++n)
120   {
121     const circle::TensorT &output_tensor = *tensors[outputs[n]];
122
123     auto *nodeout = graph->nodes()->create<CircleIfOut>();
124     copy_tensor_attributes(output_tensor, nodeout);
125     // mark shape_status
126     if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
127       nodeout->shape_status(ShapeStatus::NOSHAPE);
128     else
129       nodeout->shape_status(ShapeStatus::VALID);
130
131     nodeout->input(node);
132     nodeout->index(n);
133
134     context->nodefinder()->enroll(outputs[n], nodeout);
135   }
136 }
137
138 } // namespace luci