Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / luci / import / src / ValidateHelpers.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ValidateHelpers.h"
18
19 namespace luci
20 {
21
22 bool validate_batch_space_nd(const GraphBuilderBase::ValidateArgs &args)
23 {
24   const auto &inputs = args.op.inputs;
25   if (inputs.size() != 3)
26     return false;
27
28   // input 1 and 2 should have INT32/INT64 type
29   const auto &tensors = args.reader.tensors();
30   const auto &tensor_1 = tensors.at(inputs.at(1));
31   switch (tensor_1->type)
32   {
33     case circle::TensorType_INT32:
34     case circle::TensorType_INT64:
35       break;
36     default:
37       return false;
38   }
39   const auto &tensor_2 = tensors.at(inputs.at(2));
40   switch (tensor_2->type)
41   {
42     case circle::TensorType_INT32:
43     case circle::TensorType_INT64:
44       break;
45     default:
46       return false;
47   }
48
49   // Only support input shape dimension 3 and 4 only
50   const auto &tensor_0 = tensors.at(inputs.at(0));
51   const auto t_0_s = tensor_0->shape.size();
52   if (t_0_s != 3 && t_0_s != 4)
53     return false;
54
55   // TODO check input shape
56
57   return true;
58 }
59
60 bool validate_minmax(const GraphBuilderBase::ValidateArgs &args)
61 {
62   const auto &inputs = args.op.inputs;
63   const auto &outputs = args.op.outputs;
64
65   if (inputs.size() != 2)
66     return false;
67
68   if (outputs.size() != 1)
69     return false;
70
71   const auto &tensors = args.reader.tensors();
72   const auto &tensor = tensors.at(inputs.at(0));
73
74   switch (tensor->type)
75   {
76     case circle::TensorType_FLOAT16:
77     case circle::TensorType_FLOAT32:
78     case circle::TensorType_FLOAT64:
79     case circle::TensorType_INT32:
80     case circle::TensorType_INT64:
81       break;
82     default:
83       return false;
84   }
85
86   if (tensors[inputs.at(1)]->type != tensor->type)
87     return false;
88
89   if (tensors[outputs[0]]->type != tensor->type)
90     return false;
91
92   return true;
93 }
94
95 bool validate_reduce_minmax(const GraphBuilderBase::ValidateArgs &args)
96 {
97   const auto &inputs = args.op.inputs;
98   const auto &outputs = args.op.outputs;
99
100   if (inputs.size() != 2)
101     return false;
102
103   if (outputs.size() != 1)
104     return false;
105
106   const auto &tensors = args.reader.tensors();
107   const auto &tensor_axis = tensors.at(inputs.at(1));
108
109   switch (tensor_axis->type)
110   {
111     case circle::TensorType_INT32:
112     case circle::TensorType_INT64:
113       break;
114     default:
115       return false;
116   }
117
118   return true;
119 }
120
121 } // namespace luci