Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / luci / import / src / Nodes / CircleGreater.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Import/Nodes/CircleGreater.h"
18
19 #include <luci/IR/Nodes/CircleGreater.h>
20
21 #include <luci/UserSettings.h>
22 #include <luci/Log.h>
23
24 #include <loco.h>
25
26 namespace luci
27 {
28
29 bool CircleGreaterGraphBuilder::validate(const ValidateArgs &args) const
30 {
31   LOGGER(l);
32
33   auto settings = luci::UserSettings::settings();
34
35   const auto &inputs = args.op.inputs;
36   const auto &outputs = args.op.outputs;
37
38   if (inputs.size() != 2)
39     return false;
40
41   if (outputs.size() != 1)
42     return false;
43
44   const auto &tensors = args.reader.tensors();
45
46   if (tensors[inputs[0]]->type != tensors[inputs[1]]->type)
47     return false;
48
49   // NOTE: real models do have output dtype NOT BOOL
50   if (tensors[outputs[0]]->type != circle::TensorType_BOOL)
51   {
52     if (settings->get(luci::UserSettings::Key::DisableValidation))
53     {
54       const circle::TensorT &output_tensor = *tensors[outputs[0]];
55       auto name = tensor_name(output_tensor);
56       WARN(l) << "Warning: import Greater(" << name << ") output dtype is not boolean";
57     }
58     else
59       return false;
60   }
61
62   return true;
63 }
64
65 CircleNode *CircleGreaterGraphBuilder::build_node(const circle::OperatorT &,
66                                                   const std::vector<CircleNode *> &inputs,
67                                                   loco::Graph *graph) const
68 {
69   auto *node = graph->nodes()->create<CircleGreater>();
70   node->x(inputs[0]);
71   node->y(inputs[1]);
72
73   return node;
74 }
75
76 } // namespace luci