Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / luci / import / src / Nodes / CircleCast.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Import/Nodes/CircleCast.h"
18
19 #include <luci/IR/Nodes/CircleCast.h>
20
21 #include <luci/UserSettings.h>
22 #include <luci/Log.h>
23
24 #include <loco.h>
25
26 namespace luci
27 {
28
29 bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const
30 {
31   LOGGER(l);
32
33   auto settings = luci::UserSettings::settings();
34
35   const auto &inputs = args.op.inputs;
36   const auto &outputs = args.op.outputs;
37   if (inputs.size() != 1)
38     return false;
39   if (outputs.size() != 1)
40     return false;
41
42   // NOTE real models do have type mismatch
43   const auto *options = args.op.builtin_options.AsCastOptions();
44   if (options != nullptr)
45   {
46     const auto &tensors = args.reader.tensors();
47     const circle::TensorT &output_tensor = *tensors[outputs[0]];
48     auto name = tensor_name(output_tensor);
49
50     const auto &tensor_in = tensors.at(inputs.at(0));
51     if (tensor_in->type != options->in_data_type)
52     {
53       if (settings->get(luci::UserSettings::Key::DisableValidation))
54       {
55         WARN(l) << "Warning: import Cast(" << name << ") dtype mismatch";
56       }
57       else
58         return false;
59     }
60     const auto &tensor_out = tensors.at(outputs[0]);
61     if (tensor_out->type != options->out_data_type)
62     {
63       if (settings->get(luci::UserSettings::Key::DisableValidation))
64       {
65         WARN(l) << "Warning: import Cast(" << name << ") dtype mismatch";
66       }
67       else
68         return false;
69     }
70   }
71
72   return true;
73 }
74
75 CircleNode *CircleCastGraphBuilder::build_node(const circle::OperatorT &op,
76                                                const std::vector<CircleNode *> &inputs,
77                                                loco::Graph *graph) const
78 {
79   auto *node = graph->nodes()->create<CircleCast>();
80   node->x(inputs.at(0));
81
82   const auto *options = op.builtin_options.AsCastOptions();
83   if (options != nullptr)
84   {
85     node->in_data_type(luci_datatype(options->in_data_type));
86     node->out_data_type(luci_datatype(options->out_data_type));
87   }
88   else
89   {
90     node->in_data_type(inputs.at(0)->dtype());
91     node->out_data_type(loco::DataType::Unknown);
92     // type inference should use node->dtype() for Unknown
93     // export should use BuiltinOptions_NONE for Unknown
94   }
95
96   return node;
97 }
98
99 } // namespace luci