2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Import/Nodes/CircleUnpack.h"
19 #include <luci/IR/Nodes/CircleUnpack.h>
20 #include <luci/IR/Nodes/CircleUnpackOut.h>
22 #include <luci/UserSettings.h>
26 #include <oops/UserExn.h>
// Checks whether an Unpack operator is acceptable for import.
// Inspects the number of inputs, the number of outputs against the 'num'
// option, and the 'axis' option against the rank of the input tensor.
// NOTE(review): the statements executed when a check fails are not visible in
// this block; presumably each failing check rejects the operator - confirm.
31 bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const
35 auto settings = luci::UserSettings::settings();
37 const auto &inputs = args.op.inputs;
38 const auto &outputs = args.op.outputs;
// NOTE(review): options is dereferenced below without a null check - confirm
// AsUnpackOptions() cannot return nullptr for this operator.
39 const auto *options = args.op.builtin_options.AsUnpackOptions();
// Unpack is expected to have exactly one input
41 if (inputs.size() != 1)
44 // NOTE real models may have mismatch
45 if (static_cast<int32_t>(outputs.size()) != options->num)
// When the DisableValidation user setting is on, the mismatch is reported as
// a warning that names the first output tensor
47 if (settings->get(luci::UserSettings::Key::DisableValidation))
49 const auto &tensors = args.reader.tensors();
// NOTE(review): outputs[0] is read on the mismatch path, where outputs may be
// empty (e.g. num != 0 with no outputs) - confirm this cannot happen.
50 const circle::TensorT &output_tensor = *tensors[outputs[0]];
51 auto name = tensor_name(output_tensor);
52 WARN(l) << "Warning: import Unpack(" << name << ") 'num' is not same as outputs used";
// Look up the input tensor and take its rank as a signed value so it can be
// compared with the (possibly negative) axis
61 const auto &tensors = args.reader.tensors();
62 const auto &tensor = tensors.at(inputs[0]);
63 const auto &shape = tensor->shape;
64 auto shape_size = static_cast<int32_t>(shape.size());
67 // NOTE for unknown shape, shape_size is 0
// axis must lie in [-shape_size, shape_size)
// NOTE(review): with shape_size == 0 this condition holds for every axis;
// presumably a guard skips the check for unknown shapes - confirm.
68 if (options->axis < -shape_size || options->axis >= shape_size)
76 * @brief Unpack Node builder
78 * @note Current loco does not provide multiple outputs
79 * We will create multiple CircleUnpackOut nodes to emulate this
80 * For two outputs that may look like this
82 * --- CircleUnpack --- FullyConnected ---
83 * \- FullyConnected ---
85 * will be created like this
87 * --- CircleUnpack --- CircleUnpackOut --- FullyConnected ---
88 * \- CircleUnpackOut --- FullyConnected ---
// Builds the graph nodes for one Unpack operator: a single CircleUnpack node
// fed by the operator's input, plus one CircleUnpackOut node per output, each
// enrolled in the context's node finder under its output tensor index.
91 void CircleUnpackGraphBuilder::build(const circle::OperatorT &op,
92 GraphBuilderContext *context) const
94 assert(context != nullptr);
96 auto graph = context->graph();
98 const std::vector<int32_t> &inputs = op.inputs;
99 const std::vector<int32_t> &outputs = op.outputs;
100 const auto &tensors = context->reader()->tensors();
101 const auto &opcodes = context->reader()->opcodes();
102 auto tensors_ptr = context->reader()->tensors_ptr();
103 assert(tensors_ptr != nullptr);
105 // NOTE Unpack has only one input so running a loop is not necessary
106 // This is provided as a reference for other Ops
107 std::vector<CircleNode *> input_nodes;
108 for (const int32_t input_tensor_index : inputs)
110 input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
113 // Create CircleUnpack
114 CircleUnpack *node = graph->nodes()->create<CircleUnpack>();
// NOTE(review): input_nodes[0] assumes the operator has at least one input;
// nothing in this block checks that - confirm the caller guarantees it.
115 node->value(input_nodes[0]);
// Copy the 'num' and 'axis' options onto the node
// NOTE(review): options is dereferenced without a null check - confirm.
117 const auto *options = op.builtin_options.AsUnpackOptions();
118 node->num(options->num);
119 node->axis(options->axis);
121 assert(outputs.size() > 0);
123 // Let's use name of output 0 as Unpack name
124 const circle::TensorT &output_tensor = *tensors[outputs[0]];
125 node->name(tensor_name(output_tensor));
// The operator version is taken from the opcode entry this operator refers to
126 node->op_version(opcodes[op.opcode_index].get()->version);
128 // NOTE We don't set quantization for Unpack itself but to virtual outputs
131 // Create virtual outputs of Unpack
// NOTE(review): outputs[n] is indexed up to options->num, while validate()
// visibly tolerates outputs.size() != num when DisableValidation is set;
// confirm outputs always holds at least 'num' entries here.
132 for (int32_t n = 0; n < options->num; ++n)
134 const circle::TensorT &output_tensor = *tensors[outputs[n]];
136 auto *nodeout = graph->nodes()->create<CircleUnpackOut>();
137 copy_tensor_attributes(output_tensor, nodeout);
// Mark the shape status: NOSHAPE when the serialized tensor has no shape
// NOTE(review): VALID is presumably the alternative branch - confirm.
139 if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
140 nodeout->shape_status(ShapeStatus::NOSHAPE);
142 nodeout->shape_status(ShapeStatus::VALID);
// Attach the virtual output to the CircleUnpack node
144 nodeout->input(node);
// Register the virtual output under its output tensor index
147 context->nodefinder()->enroll(outputs[n], nodeout);