8c6421744c27a10edf1641c3cbe5caa72b3591cc
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / ShapeValidator.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ShapeValidator.h"
18
19 #include <typeinfo>
20
21 #include "ir/Graph.h"
22 #include "util/logging.h"
23 #include "util/Utils.h"
24
25 #define OP_REQUIRES(EXP)                                                                     \
26   do                                                                                         \
27   {                                                                                          \
28     if (!(EXP))                                                                              \
29       throw std::runtime_error("ShapeValidator failed at line " + std::to_string(__LINE__)); \
30   } while (0)
31
32 namespace onert
33 {
34 namespace compiler
35 {
36
37 ShapeValidator::ShapeValidator(const ir::Graph &graph) : _graph{graph} {}
38
39 void ShapeValidator::checkUnaryOp(const ir::Operation &node)
40 {
41   const auto &operands = _graph.operands();
42   const auto output_index{node.getOutputs().at(0)};
43   const auto input_index{node.getInputs().at(0)};
44
45   if (operands.at(output_index).info().isDynamic())
46     return;
47
48   // Check if I/O shapes match
49   OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
50 }
51
52 void ShapeValidator::operator()()
53 {
54   _graph.operations().iterate(
55     [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
56 }
57
58 void ShapeValidator::visit(const ir::operation::BatchMatMul &node)
59 {
60   const auto &operands = _graph.operands();
61   const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
62   const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
63   const auto out_index{node.getOutputs().at(0)};
64
65   if (operands.at(out_index).info().isDynamic())
66     return;
67
68   OP_REQUIRES(operands.at(lhs_index).shape().rank() <= 4);
69   OP_REQUIRES(operands.at(rhs_index).shape().rank() <= 4);
70   OP_REQUIRES(operands.at(lhs_index).shape().rank() >= 2);
71   OP_REQUIRES(operands.at(rhs_index).shape().rank() >= 2);
72 }
73
74 void ShapeValidator::visit(const ir::operation::BatchToSpaceND &node)
75 {
76   const auto &operands = _graph.operands();
77   const auto ofm_index{node.getOutputs().at(0)};
78   if (operands.at(ofm_index).info().isDynamic())
79     return;
80
81   const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
82   const auto block_size_index{
83     node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
84
85   const auto frontend_layout = _graph.layout();
86   const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
87   const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
88
89   // All requirement as per NNAPI specification.
90   OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
91   OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
92   OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
93
94   OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
95
96   if (node.getInputs().size() != 2)
97   {
98     const auto crops_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::CROPS_DATA)};
99     OP_REQUIRES(operands.at(crops_index).shape().rank() == 2);
100     OP_REQUIRES(operands.at(crops_index).shape().dim(0) ==
101                 (operands.at(ifm_index).shape().rank() - 2));
102     OP_REQUIRES(operands.at(crops_index).shape().dim(1) == 2);
103   }
104
105   OP_REQUIRES(input_shape.C == output_shape.C);
106 }
107
108 void ShapeValidator::visit(const ir::operation::BCQFullyConnected &node)
109 {
110   const auto &operands = _graph.operands();
111   const auto ofm_index{node.getOutputs().at(0)};
112   if (operands.at(ofm_index).info().isDynamic())
113     return;
114
115   const auto ifm_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
116   const auto weight_scales_index{
117     node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_SCALES)};
118   const auto weight_binary_index{
119     node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_BINARY)};
120   const auto weight_cluster_index{
121     node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
122   // const auto bias_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::BIAS)};
123
124   OP_REQUIRES(operands.at(ifm_index).shape().rank() == 2);
125   OP_REQUIRES(operands.at(ofm_index).shape().rank() == 2);
126   OP_REQUIRES(operands.at(weight_scales_index).shape().rank() == 1);
127   OP_REQUIRES(operands.at(weight_binary_index).shape().rank() == 2);
128   OP_REQUIRES(operands.at(weight_cluster_index).shape().rank() == 2);
129
130   OP_REQUIRES(operands.at(ifm_index).shape().dim(1) == operands.at(ofm_index).shape().dim(1));
131
132   OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(0) > 0);
133   OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(1) == 2);
134
135   // more shape validation will be done inside kernel.
136
137   // TODO Check bias dimension (can be null tensor)
138 }
139
140 void ShapeValidator::visit(const ir::operation::BCQGather &node)
141 {
142   const auto &operands = _graph.operands();
143   const auto ofm_index{node.getOutputs().at(0)};
144   if (operands.at(ofm_index).info().isDynamic())
145     return;
146
147   const auto indices_index{node.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
148   const auto input_binary_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
149   const auto input_scales_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_SCALES)};
150   const auto input_clusters_index{
151     node.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
152
153   OP_REQUIRES(operands.at(indices_index).shape().rank() <=
154               2); // TODO : support rank up to 4 or more
155   OP_REQUIRES(operands.at(input_binary_index).shape().rank() == 2);
156   OP_REQUIRES(operands.at(input_scales_index).shape().rank() == 1);
157   OP_REQUIRES(operands.at(input_clusters_index).shape().rank() == 2);
158
159   OP_REQUIRES(operands.at(input_clusters_index).shape().dim(0) > 0);
160   OP_REQUIRES(operands.at(input_clusters_index).shape().dim(1) == 2);
161
162   // more shape validation will be done inside kernel.
163 }
164
165 void ShapeValidator::visit(const ir::operation::Comparison &)
166 {
167   // TODO Shape validation of comparison
168 }
169
170 void ShapeValidator::visit(const ir::operation::Softmax &node)
171 {
172   const auto &operands = _graph.operands();
173   const auto output_index{node.getOutputs().at(0)};
174   if (operands.at(output_index).info().isDynamic())
175     return;
176
177   const auto input_index{node.getInputs().at(0)};
178
179   OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
180 }
181
182 void ShapeValidator::visit(const ir::operation::InstanceNorm &node)
183 {
184   const auto &operands = _graph.operands();
185   const auto ofm_index{node.getOutputs().at(0)};
186   if (operands.at(ofm_index).info().isDynamic())
187     return;
188
189   const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
190   const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
191   const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
192
193   OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
194   OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape());
195   OP_REQUIRES(operands.at(gamma_index).shape().rank() == 1);
196   OP_REQUIRES(operands.at(beta_index).shape().rank() == 1);
197 }
198
199 void ShapeValidator::visit(const ir::operation::Pool2D &node)
200 {
201   const auto &operands = _graph.operands();
202   const auto ofm_index{node.getOutputs().at(0)};
203   if (operands.at(ofm_index).info().isDynamic())
204     return;
205
206   const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
207
208   OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
209 }
210
211 void ShapeValidator::visit(const ir::operation::Permute &node)
212 {
213   const auto &operands = _graph.operands();
214   const auto output_index{node.getOutputs().at(0)};
215   if (operands.at(output_index).info().isDynamic())
216     return;
217
218   const auto input_index{node.getInputs().at(0)};
219
220   OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
221 }
222
223 void ShapeValidator::visit(const ir::operation::Reduce &node)
224 {
225   const auto &operands = _graph.operands();
226   const auto output_index{node.getOutputs().at(0)};
227   if (operands.at(output_index).info().isDynamic())
228     return;
229
230   const auto input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
231   const auto input_shape = operands.at(input_index).shape();
232   const auto output_shape = operands.at(output_index).shape();
233
234   OP_REQUIRES(input_shape.rank() <= 4);
235   OP_REQUIRES(output_shape.rank() <= input_shape.rank());
236
237   // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only
238   // supports cases reducing height and width or reducing depth.
239   // TODO We have to support all cases of dimensions up to 4.
240   // For correct permuting, we have to set output's shape to be equal in dimension position of the
241   // input. But the positions of the same dimensions in the input and output may be set differently.
242   // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
243   // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to
244   // extend it in 4 dimensions, it should be {1,1,3,5}.
245   // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of
246   // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the
247   // next operation is not desired.
248   if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
249   {
250     if (output_shape.rank() == 2)
251     {
252       // Reducing HW
253       OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
254                   input_shape.dim(3) == output_shape.dim(1));
255     }
256     else if (output_shape.rank() == 3)
257     {
258       // Reducing C or
259       // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1)
260       OP_REQUIRES(
261         (input_shape.dim(0) == output_shape.dim(0) && input_shape.dim(1) == output_shape.dim(1) &&
262          input_shape.dim(2) == output_shape.dim(2)) ||
263         (input_shape.dim(0) == output_shape.dim(0) &&
264          (input_shape.dim(1) == output_shape.dim(1) || input_shape.dim(2) == output_shape.dim(1)) &&
265          input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
266     }
267   }
268 }
269
270 void ShapeValidator::visit(const ir::operation::Transpose &node)
271 {
272   const auto &operands = _graph.operands();
273   const auto output_index{node.getOutputs().at(0)};
274   if (operands.at(output_index).info().isDynamic())
275     return;
276
277   const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
278   const auto perm_index{node.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
279
280   const auto &output_shape = operands.at(output_index).shape();
281   const auto &input_shape = operands.at(input_index).shape();
282
283   OP_REQUIRES(operands.at(perm_index).shape().num_elements() == 0 ||
284               input_shape.rank() ==
285                 static_cast<int>(operands.at(perm_index).shape().num_elements()));
286   OP_REQUIRES(input_shape.rank() == output_shape.rank());
287 }
288
289 void ShapeValidator::visit(const ir::operation::RNN &node)
290 {
291   // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
292   // TODO Support dynamic rnn
293   const auto &operands = _graph.operands();
294   const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
295   if (operands.at(output_index).info().isDynamic())
296     return;
297
298   const auto hidden_state_out_index{
299     node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)};
300
301   const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
302   const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
303   const auto recurrent_weights_index{
304     node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)};
305   const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
306   const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
307
308   const auto batch_size = operands.at(output_index).shape().dim(0);
309   const auto num_units = operands.at(output_index).shape().dim(1);
310
311   OP_REQUIRES(operands.at(output_index).shape().rank() == 2 &&
312               operands.at(hidden_state_out_index).shape().rank() == 2 &&
313               operands.at(input_index).shape().rank() == 2 &&
314               operands.at(weights_index).shape().rank() == 2 &&
315               operands.at(recurrent_weights_index).shape().rank() == 2 &&
316               operands.at(hidden_state_in_index).shape().rank() == 2);
317   OP_REQUIRES(operands.at(bias_index).shape().rank() == 1);
318
319   OP_REQUIRES(batch_size == operands.at(input_index).shape().dim(0) &&
320               batch_size == operands.at(hidden_state_in_index).shape().dim(0) &&
321               batch_size == operands.at(hidden_state_out_index).shape().dim(0));
322   OP_REQUIRES(operands.at(input_index).shape().dim(1) == operands.at(weights_index).shape().dim(1));
323
324   OP_REQUIRES(num_units == operands.at(weights_index).shape().dim(0) &&
325               num_units == operands.at(recurrent_weights_index).shape().dim(0) &&
326               num_units == operands.at(bias_index).shape().dim(0));
327   OP_REQUIRES(num_units == operands.at(output_index).shape().dim(1) &&
328               num_units == operands.at(recurrent_weights_index).shape().dim(1) &&
329               num_units == operands.at(hidden_state_in_index).shape().dim(1) &&
330               num_units == operands.at(hidden_state_out_index).shape().dim(1));
331 }
332
333 void ShapeValidator::visit(const ir::operation::SpaceToBatchND &node)
334 {
335   const auto &operands = _graph.operands();
336   const auto ofm_index{node.getOutputs().at(0)};
337   if (operands.at(ofm_index).info().isDynamic())
338     return;
339
340   const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
341   const auto block_size_index{
342     node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
343   const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
344
345   const auto frontend_layout = _graph.layout();
346   const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
347   const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
348
349   // All requirement as per NNAPI specification.
350   OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
351   OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
352   OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
353   OP_REQUIRES(operands.at(paddings_index).shape().rank() == 2);
354
355   OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
356   OP_REQUIRES(operands.at(paddings_index).shape().dim(0) == 2);
357   OP_REQUIRES(operands.at(paddings_index).shape().dim(1) == 2);
358
359   OP_REQUIRES(input_shape.C == output_shape.C);
360 }
361
362 void ShapeValidator::visit(const ir::operation::SpaceToDepth &node)
363 {
364   const auto &operands = _graph.operands();
365   const auto ofm_index{node.getOutputs().at(0)};
366   if (operands.at(ofm_index).info().isDynamic())
367     return;
368
369   const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
370
371   const auto frontend_layout = _graph.layout();
372   const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
373   const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
374   const auto block_size = node.param().block_size;
375
376   // All assertions as per NNAPI specification.
377   OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
378   OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
379   OP_REQUIRES((input_shape.H % block_size == 0) && (input_shape.W % block_size == 0));
380   OP_REQUIRES(input_shape.N == output_shape.N);
381   OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
382 }
383
384 void ShapeValidator::visit(const ir::operation::ElementwiseActivation &node) { checkUnaryOp(node); }
385
386 void ShapeValidator::visit(const ir::operation::ElementwiseBinary &)
387 {
388   // TODO Shape validation of ElementwiseBinary
389 }
390
391 void ShapeValidator::visit(const ir::operation::ElementwiseUnary &node)
392 {
393   const auto &operands = _graph.operands();
394   const auto output_index{node.getOutputs().at(0)};
395   const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)};
396
397   if (operands.at(output_index).info().isDynamic())
398     return;
399
400   OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
401 }
402
403 void ShapeValidator::visit(const ir::operation::EmbeddingLookup &node)
404 {
405   const auto &operands = _graph.operands();
406   const auto output_index{node.getOutputs().at(0)};
407   const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
408   const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
409
410   const auto &output_obj = operands.at(output_index);
411   const auto &lookups_obj = operands.at(lookups_index);
412   const auto &values_obj = operands.at(values_index);
413
414   // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying
415   // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
416   {
417     if (operands.at(output_index).info().isDynamic())
418       return;
419
420     const auto &output_shape = output_obj.shape();
421     const auto &lookups_shape = lookups_obj.shape();
422     const auto &values_shape = values_obj.shape();
423
424     OP_REQUIRES(lookups_shape.rank() == 1);
425     OP_REQUIRES(values_shape.rank() >= 2);
426
427     // output should be a n-D tensor with the same rank and shape as the values tensor, except for
428     // the first dimension which has the same size as lookups' only dimension.
429     OP_REQUIRES(output_shape.rank() == values_shape.rank());
430     OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
431     for (int n = 1; n < output_shape.rank(); ++n)
432     {
433       OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
434     }
435   }
436 }
437
438 void ShapeValidator::visit(const ir::operation::ExpandDims &node)
439 {
440   const auto &operands = _graph.operands();
441   const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
442
443   if (operands.at(axis_index).info().isDynamic())
444     return;
445   OP_REQUIRES(operands.at(axis_index).shape().rank() <= 1);
446 }
447
448 void ShapeValidator::visit(const ir::operation::HashtableLookup &node)
449 {
450   const auto &operands = _graph.operands();
451   const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
452   const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
453   const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
454   const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
455
456   const auto &output_obj = operands.at(output_index);
457   const auto &lookups_obj = operands.at(lookups_index);
458   const auto &keys_obj = operands.at(keys_index);
459   const auto &values_obj = operands.at(values_index);
460
461   if (operands.at(output_index).info().isDynamic())
462     return;
463
464   const auto &output_shape = output_obj.shape();
465   const auto &lookups_shape = lookups_obj.shape();
466   const auto &keys_shape = keys_obj.shape();
467   const auto &values_shape = values_obj.shape();
468
469   OP_REQUIRES(values_shape.rank() == output_shape.rank());
470   OP_REQUIRES(lookups_shape.rank() == 1);
471   OP_REQUIRES(keys_shape.rank() == 1);
472   OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
473   OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
474 }
475
476 void ShapeValidator::visit(const ir::operation::TransposeConv &node)
477 {
478   // shape check
479   const auto &operands = _graph.operands();
480   const auto ofm_index{node.getOutputs().at(0)};
481
482   if (operands.at(ofm_index).info().isDynamic())
483     return;
484
485   const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
486   const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
487
488   // Only 4D tensors are supported
489   OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
490   OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ifm_index).shape().rank());
491   OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ker_index).shape().rank());
492
493   const auto frontend_layout = _graph.layout();
494   const auto ofm_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
495   const auto ifm_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
496   // The kernel has only IHWO layout on frontend
497   // So ker_shape is treated here below
498   // I -> N
499   // H -> H
500   // W -> W
501   // O -> C
502   const auto ker_shape = operands.at(ker_index).shape().asFeature(ir::Layout::NHWC);
503
504   OP_REQUIRES(ifm_shape.N == ofm_shape.N);
505   OP_REQUIRES(ifm_shape.C == ker_shape.C);
506   OP_REQUIRES(ker_shape.N == ofm_shape.C);
507 }
508
509 void ShapeValidator::visit(const ir::operation::Gather &node)
510 {
511   const auto &operands = _graph.operands();
512   const auto ofm_index{node.getOutputs().at(0)};
513   if (operands.at(ofm_index).info().isDynamic())
514     return;
515
516   const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
517   const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
518
519   const auto ifm_shape = operands.at(ifm_index).shape();
520   const auto indices_shape = operands.at(indices_index).shape();
521   const auto ofm_shape = operands.at(ofm_index).shape();
522
523   OP_REQUIRES(ifm_shape.rank() <= 4);
524   OP_REQUIRES(indices_shape.rank() <= 3);
525   OP_REQUIRES(ofm_shape.rank() <= 4);
526 }
527
528 void ShapeValidator::visit(const ir::operation::DepthToSpace &node)
529 {
530   const auto &operands = _graph.operands();
531   int32_t block_size = node.param().block_size;
532
533   // shape check
534   const auto output_index{node.getOutputs().at(0)};
535   if (operands.at(output_index).info().isDynamic())
536     return;
537
538   const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
539
540   const auto frontend_layout = _graph.layout();
541   const auto output_shape = operands.at(output_index).shape().asFeature(frontend_layout);
542   const auto input_shape = operands.at(input_index).shape().asFeature(frontend_layout);
543
544   OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
545   OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
546
547   {
548     OP_REQUIRES(output_shape.N == input_shape.N);
549     OP_REQUIRES(output_shape.H == input_shape.H * block_size);
550     OP_REQUIRES(output_shape.W == input_shape.W * block_size);
551     OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
552     OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
553   }
554 }
555
556 void ShapeValidator::visit(const ir::operation::Pack &node)
557 {
558   const auto &operands = _graph.operands();
559   const auto axis{node.param().axis};
560   const auto output_index{node.getOutputs().at(0)};
561   if (operands.at(output_index).info().isDynamic())
562     return;
563
564   // shape check
565   const auto &output_shape = operands.at(output_index).shape();
566   const auto output_rank = static_cast<int32_t>(output_shape.rank());
567
568   const auto input1_index{node.getInputs().at(0)};
569   const auto input_shape = operands.at(input1_index).shape();
570
571   OP_REQUIRES(axis >= -output_rank && axis < output_rank);
572   for (const auto &index : node.getInputs())
573   {
574     OP_REQUIRES(input_shape == operands.at(index).shape());
575   }
576 }
577
578 void ShapeValidator::visit(const ir::operation::LSTM &node)
579 {
580   // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
581   // TODO Support dynamic rnn
582   const auto &operands = _graph.operands();
583   const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
584   if (operands.at(output_index).info().isDynamic())
585     return;
586
587   const auto scratch_buffer_index{
588     node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; // Optional
589   const auto output_state_out_index{
590     node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; // Optional
591   const auto cell_state_out_index{
592     node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; // Optional
593
594   const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
595   const auto input_to_input_weights_index{
596     node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // Optional
597   const auto input_to_forget_weights_index{
598     node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
599   const auto input_to_cell_weights_index{
600     node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
601   const auto input_to_output_weights_index{
602     node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
603   const auto recurrent_to_input_weights_index{
604     node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // Optional
605   const auto recurrent_to_forget_weights_index{
606     node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
607   const auto recurrent_to_cell_weights_index{
608     node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
609   const auto recurrent_to_output_weights_index{
610     node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
611   const auto cell_to_input_weights_index{
612     node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // Optional
613   const auto cell_to_forget_weights_index{
614     node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // Optional
615   const auto cell_to_output_weights_index{
616     node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // Optional
617   const auto input_gate_bias_index{
618     node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; // Optional
619   const auto forget_gate_bias_index{
620     node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
621   const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
622   const auto output_gate_bias_index{
623     node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
624   const auto projection_weights_index{
625     node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // Optional
626   const auto projection_bias_index{
627     node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // Optional
628   const auto output_state_in_index{
629     node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
630   const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
631
632   OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
633   for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i)
634   {
635     OP_REQUIRES(operands.at(input_index).shape().dim(i) ==
636                 operands.at(output_index).shape().dim(i));
637   }
638   OP_REQUIRES((operands.at(output_index).shape().rank() == 2 ||
639                operands.at(output_index).shape().rank() == 3) &&
640               (operands.at(input_index).shape().rank() == 2 ||
641                operands.at(input_index).shape().rank() == 3) &&
642               (!operands.exist(input_to_input_weights_index) ||
643                operands.at(input_to_input_weights_index).shape().rank() == 2) &&
644               operands.at(input_to_forget_weights_index).shape().rank() == 2 &&
645               operands.at(input_to_cell_weights_index).shape().rank() == 2 &&
646               operands.at(input_to_output_weights_index).shape().rank() == 2 &&
647               (!operands.exist(recurrent_to_input_weights_index) ||
648                operands.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
649               operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
650               operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
651               operands.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
652               (!operands.exist(projection_weights_index) ||
653                operands.at(projection_weights_index).shape().rank() == 2) &&
654               operands.at(output_state_in_index).shape().rank() == 2 &&
655               operands.at(cell_state_in_index).shape().rank() == 2);
656
657   OP_REQUIRES((!operands.exist(cell_to_input_weights_index) ||
658                operands.at(cell_to_input_weights_index).shape().rank() == 1) &&
659               (!operands.exist(cell_to_forget_weights_index) ||
660                operands.at(cell_to_forget_weights_index).shape().rank() == 1) &&
661               (!operands.exist(cell_to_output_weights_index) ||
662                operands.at(cell_to_output_weights_index).shape().rank() == 1) &&
663               (!operands.exist(input_gate_bias_index) ||
664                operands.at(input_gate_bias_index).shape().rank() == 1) &&
665               operands.at(forget_gate_bias_index).shape().rank() == 1 &&
666               operands.at(cell_bias_index).shape().rank() == 1 &&
667               operands.at(output_gate_bias_index).shape().rank() == 1 &&
668               (!operands.exist(projection_bias_index) ||
669                operands.at(projection_bias_index).shape().rank() == 1));
670
671   // CIFG assertion
672   OP_REQUIRES(((!operands.exist(input_to_input_weights_index) ||
673                 (operands.at(input_to_input_weights_index).shape().dim(0) == 0 &&
674                  operands.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
675                (!operands.exist(recurrent_to_input_weights_index) ||
676                 (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
677                  operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
678                (!operands.exist(input_gate_bias_index) ||
679                 operands.at(input_gate_bias_index).shape().dim(0) == 0) &&
680                (!operands.exist(cell_to_input_weights_index) ||
681                 operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
682               ((operands.exist(input_to_input_weights_index) &&
683                 (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
684                  operands.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
685                (operands.exist(recurrent_to_input_weights_index) &&
686                 (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
687                  operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
688                (operands.exist(input_gate_bias_index) &&
689                 operands.at(input_gate_bias_index).shape().dim(0) != 0)));
690
691   // Peephole assertion
692   OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) ||
693                 operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
694                (!operands.exist(cell_to_output_weights_index) ||
695                 operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
696               ((operands.exist(cell_to_forget_weights_index) &&
697                 operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
698                (operands.exist(cell_to_output_weights_index) &&
699                 operands.at(cell_to_output_weights_index).shape().dim(0) != 0)));
700
701   bool has_input_to_input_weights =
702     operands.exist(input_to_input_weights_index) &&
703     (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
704      operands.at(input_to_input_weights_index).shape().dim(1) != 0);
705   bool has_recurrent_to_input_weights =
706     operands.exist(recurrent_to_input_weights_index) &&
707     (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
708      operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
709   bool has_input_gate_bias =
710     operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0;
711   bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) &&
712                                    operands.at(cell_to_input_weights_index).shape().dim(0) != 0;
713   bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) &&
714                                     operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
715   bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) &&
716                                     operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
717   bool has_projection_weights = operands.exist(projection_weights_index) &&
718                                 (operands.at(projection_weights_index).shape().dim(0) != 0 &&
719                                  operands.at(projection_weights_index).shape().dim(1) != 0);
720   bool has_projection_bias =
721     operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0;
722
723   // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
724   // true: no CIFG
725   // false: CIFG
726   bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
727
728   // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole.
729   // true: peephole
730   // false: no peephole
731   bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
732
733   // NOTE The projection weights may have data but the projection bias may not.
734   bool has_projection_param = has_projection_weights;
735
736   const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major)
737                             ? operands.at(input_index).shape().dim(1)
738                             : operands.at(input_index).shape().dim(0);
739   OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) &&
740               batch_size == operands.at(cell_state_in_index).shape().dim(0));
741
742   const auto input_size =
743     operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1);
744   OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) &&
745               input_size == operands.at(input_to_cell_weights_index).shape().dim(1) &&
746               input_size == operands.at(input_to_output_weights_index).shape().dim(1));
747
748   const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0);
749   OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) &&
750               num_units == operands.at(input_to_output_weights_index).shape().dim(0) &&
751               num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) &&
752               num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) &&
753               num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) &&
754               num_units == operands.at(forget_gate_bias_index).shape().dim(0) &&
755               num_units == operands.at(cell_bias_index).shape().dim(0) &&
756               num_units == operands.at(output_gate_bias_index).shape().dim(0) &&
757               num_units == operands.at(cell_state_in_index).shape().dim(1));
758
759   const auto output_size =
760     operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
761   OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) &&
762               output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) &&
763               output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) &&
764               output_size == operands.at(output_state_in_index).shape().dim(1));
765
766   if (has_cifg_param)
767   {
768     OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1));
769     OP_REQUIRES(
770       num_units == operands.at(input_to_input_weights_index).shape().dim(0) &&
771       num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) &&
772       ((operands.exist(cell_to_input_weights_index) &&
773         num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) ||
774        (!operands.exist(cell_to_input_weights_index) ||
775         operands.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) &&
776       num_units == operands.at(input_gate_bias_index).shape().dim(0));
777     OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1));
778     OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
779                 has_input_gate_bias);
780     if (has_cell_to_input_weights)
781     {
782       // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
783       OP_REQUIRES(has_peephole_param);
784     }
785     if (operands.exist(scratch_buffer_index))
786       OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
787   }
788   else
789   {
790     if (operands.exist(scratch_buffer_index))
791       OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
792   }
793
794   if (has_peephole_param)
795   {
796     OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) &&
797                 num_units == operands.at(cell_to_output_weights_index).shape().dim(0) &&
798                 (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) ||
799                  operands.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
800   }
801
802   if (has_projection_param)
803   {
804     OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1));
805     OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0));
806     if (has_projection_bias)
807     {
808       OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0));
809     }
810   }
811
812   if (operands.exist(scratch_buffer_index))
813   {
814     OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2);
815     OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0));
816   }
817
818   if (operands.exist(output_state_out_index))
819   {
820     OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2);
821     OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0));
822     OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1));
823   }
824
825   if (operands.exist(cell_state_out_index))
826   {
827     OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2);
828     OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0));
829     OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1));
830   }
831 }
832
833 void ShapeValidator::visit(const ir::operation::L2Normalization &node)
834 {
835   const auto &operands = _graph.operands();
836   const auto ofm_index{node.getOutputs().at(0)};
837   if (operands.at(ofm_index).info().isDynamic())
838     return;
839
840   const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
841
842   auto ifm_shape = operands.at(ifm_index).shape();
843   auto ofm_shape = operands.at(ofm_index).shape();
844
845   OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
846
847   for (auto i = 0; i < ifm_shape.rank(); i++)
848   {
849     OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
850   }
851 }
852
853 void ShapeValidator::visit(const ir::operation::Unpack &node)
854 {
855   const auto &operands = _graph.operands();
856   const auto axis{node.param().axis};
857   const auto output_index{node.getInputs().at(0)};
858   if (operands.at(output_index).info().isDynamic())
859     return;
860
861   const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
862
863   const auto &input_shape = operands.at(input_index).shape();
864   const auto input_rank = static_cast<int32_t>(input_shape.rank());
865
866   OP_REQUIRES(axis >= -input_rank && axis < input_rank);
867 }
868
869 void ShapeValidator::visit(const ir::operation::Pad &node)
870 {
871   const auto &operands = _graph.operands();
872   const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
873   OP_REQUIRES(operands.at(pad_index).typeInfo().type() == ir::DataType::INT32);
874
875   const auto output_index{node.getInputs().at(0)};
876   if (operands.at(output_index).info().isDynamic())
877     return;
878
879   const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
880
881   const auto &pad_shape = operands.at(pad_index).shape();
882   const auto input_rank = static_cast<int32_t>(operands.at(input_index).shape().rank());
883
884   OP_REQUIRES(pad_shape.rank() == 2);
885   OP_REQUIRES(pad_shape.dim(0) == input_rank);
886   OP_REQUIRES(pad_shape.dim(1) == 2);
887   OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
888 }
889
890 void ShapeValidator::visit(const ir::operation::Select &)
891 {
892   // TODO Shape validation of select
893 }
894
895 void ShapeValidator::visit(const ir::operation::StridedSlice &node)
896 {
897   const auto &operands = _graph.operands();
898   const auto output_index{node.getOutputs().at(0)};
899   const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
900
901   if (operands.at(output_index).info().isDynamic())
902     return;
903
904   OP_REQUIRES(operands.at(input_index).shape().rank() <= 4);
905 }
906
907 void ShapeValidator::visit(const ir::operation::Split &node)
908 {
909   const auto &operands = _graph.operands();
910   const auto output_index{node.getOutputs().at(0)};
911   if (operands.at(output_index).info().isDynamic())
912     return;
913
914   const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
915   const auto axis_index{node.getInputs().at(ir::operation::Split::Input::AXIS)};
916
917   const auto num_splits = node.param().num_splits;
918   const auto input_rank = operands.at(input_index).shape().rank();
919   auto axis = *reinterpret_cast<const int32_t *>(operands.at(axis_index).data()->base());
920   axis = axis < 0 ? axis + input_rank : axis;
921
922   OP_REQUIRES(axis >= 0 && axis < input_rank);
923   OP_REQUIRES(operands.at(input_index).shape().dim(axis) % num_splits == 0);
924 }
925
926 void ShapeValidator::visit(const ir::operation::Shape &node)
927 {
928   const auto &operands = _graph.operands();
929   const auto output_index{node.getOutputs().at(0)};
930   if (operands.at(output_index).info().isDynamic())
931     return;
932
933   const auto input_index{node.getInputs().at(0)};
934   UNUSED_RELEASE(input_index);
935   OP_REQUIRES(operands.at(output_index).shape().rank() == 1);
936 }
937
938 void ShapeValidator::visit(const ir::operation::ResizeBilinear &node)
939 {
940   const auto &operands = _graph.operands();
941   const auto output_index{node.getOutputs().at(0)};
942   const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
943
944   if (operands.at(output_index).info().isDynamic())
945   {
946     return;
947   }
948   OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
949   OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
950 }
951
952 void ShapeValidator::visit(const ir::operation::Reverse &node)
953 {
954   const auto &operands = _graph.operands();
955   const auto output_index{node.getOutputs().at(0)};
956   const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
957
958   if (operands.at(output_index).info().isDynamic())
959     return;
960   OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
961 }
962
963 void ShapeValidator::visit(const ir::operation::If &)
964 {
965   // TODO Add to validate with subgraphs
966 }
967
968 void ShapeValidator::visit(const ir::operation::While &)
969 {
970   // This validator does not check shape. So checking isDynamic() is skipped.
971   // TODO Add to validate with subgraphs
972 }
973
974 void ShapeValidator::visit(const ir::operation::SquaredDifference &node)
975 {
976   const auto &operands = _graph.operands();
977   const auto output_index{node.getOutputs().at(0)};
978   const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
979   const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
980
981   // Check for dimension constraints
982   if (operands.at(output_index).info().isDynamic())
983     return;
984
985   auto output_shape = operands.at(output_index).shape();
986   auto lhs_shape = operands.at(lhs_index).shape();
987   auto rhs_shape = operands.at(rhs_index).shape();
988   // Check for output rank
989   OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
990   auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());
991
992   for (int idx = 1; idx <= min_rank; idx++)
993   {
994     int l_idx = lhs_shape.rank() - idx;
995     int r_idx = rhs_shape.rank() - idx;
996     int out_idx = output_shape.rank() - idx;
997
998     OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));
999
1000     auto l_dims = lhs_shape.dim(l_idx);
1001     auto r_dims = rhs_shape.dim(r_idx);
1002     auto out_dims = output_shape.dim(out_idx);
1003
1004     OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
1005                 ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));
1006   }
1007   auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
1008   for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
1009   {
1010     int out_idx = output_shape.rank() - idx;
1011     int tmp_idx = tmp_shape.rank() - idx;
1012
1013     OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
1014                 (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
1015   }
1016 }
1017 void ShapeValidator::visit(const ir::operation::Tile &node)
1018 {
1019   const auto &operands = _graph.operands();
1020   const auto output_index{node.getOutputs().at(0)};
1021   if (operands.at(output_index).info().isDynamic())
1022     return;
1023
1024   const auto input_index{node.getInputs().at(0)};
1025   const auto multiple_index{node.getInputs().at(1)};
1026
1027   OP_REQUIRES(operands.at(multiple_index).shape().rank() == 1);
1028   OP_REQUIRES(operands.at(multiple_index).shape().dim(0) ==
1029               operands.at(input_index).shape().rank());
1030   OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
1031 }
1032
1033 void ShapeValidator::visit(const ir::operation::Range &node)
1034 {
1035   const auto &operands = _graph.operands();
1036   const auto output_index{node.getOutputs().at(0)};
1037   const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
1038   const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
1039   const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
1040
1041   // Check for dimension constraints
1042   if (operands.at(output_index).info().isDynamic())
1043     return;
1044
1045   OP_REQUIRES(operands.at(start_index).shape().rank() == 0);
1046   OP_REQUIRES(operands.at(limit_index).shape().rank() == 0);
1047   OP_REQUIRES(operands.at(delta_index).shape().rank() == 0);
1048 }
1049
1050 void ShapeValidator::visit(const ir::operation::MatrixBandPart &node)
1051 {
1052   const auto &operands = _graph.operands();
1053   const auto output_index{node.getOutputs().at(0)};
1054   const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)};
1055   const auto num_lower_index{
1056     node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)};
1057   const auto num_upper_index{
1058     node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)};
1059
1060   // Check for dimension constraints
1061   if (operands.at(output_index).info().isDynamic())
1062     return;
1063
1064   OP_REQUIRES(operands.at(input_index).shape().rank() >= 2); // input must be more than 2 dim matrix
1065   OP_REQUIRES(operands.at(num_upper_index).shape().rank() == 0); // num_lower must be scalar
1066   OP_REQUIRES(operands.at(num_lower_index).shape().rank() == 0); // num_upper must be scalar
1067 }
1068
1069 void ShapeValidator::visit(const ir::operation::LogSoftmax &node)
1070 {
1071   const auto &operands = _graph.operands();
1072   const auto output_index{node.getOutputs().at(0)};
1073   if (operands.at(output_index).info().isDynamic())
1074     return;
1075
1076   const auto input_index{node.getInputs().at(0)};
1077
1078   OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
1079 }
1080
1081 } // namespace compiler
1082 } // namespace onert