Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / ShapeValidator.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ShapeValidator.h"
18
19 #include <typeinfo>
20
21 #include "ir/Graph.h"
22 #include "ir/operation/LowerInfo.h"
23
24 #include "util/logging.h"
25 #include "util/Utils.h"
26
27 #define OP_REQUIRES(EXP)                                                                     \
28   do                                                                                         \
29   {                                                                                          \
30     if (!(EXP))                                                                              \
31       throw std::runtime_error("ShapeValidator failed at line " + std::to_string(__LINE__)); \
32   } while (0)
33
34 namespace onert
35 {
36 namespace compiler
37 {
38
39 ShapeValidator::ShapeValidator(const ir::Graph &graph)
40     : _graph{graph}, _ctx{graph.operands()}, _current_layout{ir::Layout::UNKNOWN}
41 {
42 }
43
44 void ShapeValidator::checkUnaryOp(const ir::Operation &node)
45 {
46   const auto output_index{node.getOutputs().at(0)};
47   const auto input_index{node.getInputs().at(0)};
48
49   if (_ctx.at(output_index).info().isDynamic())
50     return;
51
52   // Check if I/O shapes match
53   OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
54 }
55
56 void ShapeValidator::operator()()
57 {
58   // There is no reason for each subgraph to have subgraphs since compiler has subgraphs when
59   // creating Compiler
60   assert(_graph.subgraphs() == nullptr);
61
62   _current_layout = _graph.layout();
63
64   _graph.operations().iterate(
65       [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
66 }
67
68 void ShapeValidator::visit(const ir::operation::BatchMatMul &node)
69 {
70   const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
71   const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
72   const auto out_index{node.getOutputs().at(0)};
73
74   if (_ctx.at(out_index).info().isDynamic())
75     return;
76
77   OP_REQUIRES(_ctx.at(lhs_index).shape().rank() <= 4);
78   OP_REQUIRES(_ctx.at(rhs_index).shape().rank() <= 4);
79   OP_REQUIRES(_ctx.at(lhs_index).shape().rank() >= 2);
80   OP_REQUIRES(_ctx.at(rhs_index).shape().rank() >= 2);
81 }
82
83 void ShapeValidator::visit(const ir::operation::BatchToSpaceND &node)
84 {
85   const auto ofm_index{node.getOutputs().at(0)};
86   if (_ctx.at(ofm_index).info().isDynamic())
87     return;
88
89   const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
90   const auto block_size_index{
91       node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
92
93   const auto frontend_layout = _current_layout;
94   const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
95   const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
96
97   // All requirement as per NNAPI specification.
98   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
99   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
100   OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
101
102   OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
103
104   if (node.getInputs().size() != 2)
105   {
106     const auto crops_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::CROPS_DATA)};
107     OP_REQUIRES(_ctx.at(crops_index).shape().rank() == 2);
108     OP_REQUIRES(_ctx.at(crops_index).shape().dim(0) == (_ctx.at(ifm_index).shape().rank() - 2));
109     OP_REQUIRES(_ctx.at(crops_index).shape().dim(1) == 2);
110   }
111
112   OP_REQUIRES(input_shape.C == output_shape.C);
113 }
114
115 void ShapeValidator::visit(const ir::operation::BCQFullyConnected &node)
116 {
117   const auto ofm_index{node.getOutputs().at(0)};
118   if (_ctx.at(ofm_index).info().isDynamic())
119     return;
120
121   const auto ifm_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
122   const auto weight_scales_index{
123       node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_SCALES)};
124   const auto weight_binary_index{
125       node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_BINARY)};
126   const auto weight_cluster_index{
127       node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
128   // const auto bias_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::BIAS)};
129
130   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 2);
131   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 2);
132   OP_REQUIRES(_ctx.at(weight_scales_index).shape().rank() == 1);
133   OP_REQUIRES(_ctx.at(weight_binary_index).shape().rank() == 2);
134   OP_REQUIRES(_ctx.at(weight_cluster_index).shape().rank() == 2);
135
136   OP_REQUIRES(_ctx.at(ifm_index).shape().dim(1) == _ctx.at(ofm_index).shape().dim(1));
137
138   OP_REQUIRES(_ctx.at(weight_cluster_index).shape().dim(0) > 0);
139   OP_REQUIRES(_ctx.at(weight_cluster_index).shape().dim(1) == 2);
140
141   // more shape validation will be done inside kernel.
142
143   // TODO Check bias dimension (can be null tensor)
144 }
145
146 void ShapeValidator::visit(const ir::operation::BCQGather &node)
147 {
148   const auto ofm_index{node.getOutputs().at(0)};
149   if (_ctx.at(ofm_index).info().isDynamic())
150     return;
151
152   const auto indices_index{node.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
153   const auto input_binary_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
154   const auto input_scales_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_SCALES)};
155   const auto input_clusters_index{
156       node.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
157
158   OP_REQUIRES(_ctx.at(indices_index).shape().rank() <= 2); // TODO : support rank up to 4 or more
159   OP_REQUIRES(_ctx.at(input_binary_index).shape().rank() == 2);
160   OP_REQUIRES(_ctx.at(input_scales_index).shape().rank() == 1);
161   OP_REQUIRES(_ctx.at(input_clusters_index).shape().rank() == 2);
162
163   OP_REQUIRES(_ctx.at(input_clusters_index).shape().dim(0) > 0);
164   OP_REQUIRES(_ctx.at(input_clusters_index).shape().dim(1) == 2);
165
166   // more shape validation will be done inside kernel.
167 }
168
169 void ShapeValidator::visit(const ir::operation::Comparison &)
170 {
171   // TODO Shape validation of comparison
172 }
173
174 void ShapeValidator::visit(const ir::operation::Softmax &node)
175 {
176   const auto output_index{node.getOutputs().at(0)};
177   if (_ctx.at(output_index).info().isDynamic())
178     return;
179
180   const auto input_index{node.getInputs().at(0)};
181
182   OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
183 }
184
185 void ShapeValidator::visit(const ir::operation::InstanceNorm &node)
186 {
187   const auto ofm_index{node.getOutputs().at(0)};
188   if (_ctx.at(ofm_index).info().isDynamic())
189     return;
190
191   const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
192   const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
193   const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
194
195   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
196   OP_REQUIRES(_ctx.at(ifm_index).shape() == _ctx.at(ofm_index).shape());
197   OP_REQUIRES(_ctx.at(gamma_index).shape().rank() == 1);
198   OP_REQUIRES(_ctx.at(beta_index).shape().rank() == 1);
199 }
200
201 void ShapeValidator::visit(const ir::operation::Pool2D &node)
202 {
203   const auto ofm_index{node.getOutputs().at(0)};
204   if (_ctx.at(ofm_index).info().isDynamic())
205     return;
206
207   const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
208
209   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
210 }
211
212 void ShapeValidator::visit(const ir::operation::Permute &node)
213 {
214   const auto output_index{node.getOutputs().at(0)};
215   if (_ctx.at(output_index).info().isDynamic())
216     return;
217
218   const auto input_index{node.getInputs().at(0)};
219
220   OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
221 }
222
223 void ShapeValidator::visit(const ir::operation::Reduce &node)
224 {
225   const auto output_index{node.getOutputs().at(0)};
226   if (_ctx.at(output_index).info().isDynamic())
227     return;
228
229   const auto input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
230   const auto input_shape = _ctx.at(input_index).shape();
231   const auto output_shape = _ctx.at(output_index).shape();
232
233   OP_REQUIRES(input_shape.rank() <= 4);
234   OP_REQUIRES(output_shape.rank() <= input_shape.rank());
235
236   // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only
237   // supports cases reducing height and width or reducing depth.
238   // TODO We have to support all cases of dimensions up to 4.
239   // For correct permuting, we have to set output's shape to be equal in dimension position of the
240   // input. But the positions of the same dimensions in the input and output may be set differently.
241   // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
242   // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to
243   // extend it in 4 dimensions, it should be {1,1,3,5}.
244   // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of
245   // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the
246   // next operation is not desired.
247   if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
248   {
249     if (output_shape.rank() == 2)
250     {
251       // Reducing HW
252       OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
253                   input_shape.dim(3) == output_shape.dim(1));
254     }
255     else if (output_shape.rank() == 3)
256     {
257       // Reducing C or
258       // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1)
259       OP_REQUIRES((input_shape.dim(0) == output_shape.dim(0) &&
260                    input_shape.dim(1) == output_shape.dim(1) &&
261                    input_shape.dim(2) == output_shape.dim(2)) ||
262                   (input_shape.dim(0) == output_shape.dim(0) &&
263                    (input_shape.dim(1) == output_shape.dim(1) ||
264                     input_shape.dim(2) == output_shape.dim(1)) &&
265                    input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
266     }
267   }
268 }
269
270 void ShapeValidator::visit(const ir::operation::Transpose &node)
271 {
272   const auto output_index{node.getOutputs().at(0)};
273   if (_ctx.at(output_index).info().isDynamic())
274     return;
275
276   const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
277   const auto perm_index{node.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
278
279   const auto &output_shape = _ctx.at(output_index).shape();
280   const auto &input_shape = _ctx.at(input_index).shape();
281
282   OP_REQUIRES(_ctx.at(perm_index).shape().num_elements() == 0 ||
283               input_shape.rank() == static_cast<int>(_ctx.at(perm_index).shape().num_elements()));
284   OP_REQUIRES(input_shape.rank() == output_shape.rank());
285 }
286
287 void ShapeValidator::visit(const ir::operation::RNN &node)
288 {
289   // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
290   // TODO Support dynamic rnn
291   const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
292   if (_ctx.at(output_index).info().isDynamic())
293     return;
294
295   const auto hidden_state_out_index{
296       node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)};
297
298   const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
299   const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
300   const auto recurrent_weights_index{
301       node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)};
302   const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
303   const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
304
305   const auto batch_size = _ctx.at(output_index).shape().dim(0);
306   const auto num_units = _ctx.at(output_index).shape().dim(1);
307
308   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 2 &&
309               _ctx.at(hidden_state_out_index).shape().rank() == 2 &&
310               _ctx.at(input_index).shape().rank() == 2 &&
311               _ctx.at(weights_index).shape().rank() == 2 &&
312               _ctx.at(recurrent_weights_index).shape().rank() == 2 &&
313               _ctx.at(hidden_state_in_index).shape().rank() == 2);
314   OP_REQUIRES(_ctx.at(bias_index).shape().rank() == 1);
315
316   OP_REQUIRES(batch_size == _ctx.at(input_index).shape().dim(0) &&
317               batch_size == _ctx.at(hidden_state_in_index).shape().dim(0) &&
318               batch_size == _ctx.at(hidden_state_out_index).shape().dim(0));
319   OP_REQUIRES(_ctx.at(input_index).shape().dim(1) == _ctx.at(weights_index).shape().dim(1));
320
321   OP_REQUIRES(num_units == _ctx.at(weights_index).shape().dim(0) &&
322               num_units == _ctx.at(recurrent_weights_index).shape().dim(0) &&
323               num_units == _ctx.at(bias_index).shape().dim(0));
324   OP_REQUIRES(num_units == _ctx.at(output_index).shape().dim(1) &&
325               num_units == _ctx.at(recurrent_weights_index).shape().dim(1) &&
326               num_units == _ctx.at(hidden_state_in_index).shape().dim(1) &&
327               num_units == _ctx.at(hidden_state_out_index).shape().dim(1));
328 }
329
330 void ShapeValidator::visit(const ir::operation::SpaceToBatchND &node)
331 {
332   const auto ofm_index{node.getOutputs().at(0)};
333   if (_ctx.at(ofm_index).info().isDynamic())
334     return;
335
336   const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
337   const auto block_size_index{
338       node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
339   const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
340
341   const auto frontend_layout = _current_layout;
342   const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
343   const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
344
345   // All requirement as per NNAPI specification.
346   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
347   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
348   OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
349   OP_REQUIRES(_ctx.at(paddings_index).shape().rank() == 2);
350
351   OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
352   OP_REQUIRES(_ctx.at(paddings_index).shape().dim(0) == 2);
353   OP_REQUIRES(_ctx.at(paddings_index).shape().dim(1) == 2);
354
355   OP_REQUIRES(input_shape.C == output_shape.C);
356 }
357
358 void ShapeValidator::visit(const ir::operation::SpaceToDepth &node)
359 {
360   const auto ofm_index{node.getOutputs().at(0)};
361   if (_ctx.at(ofm_index).info().isDynamic())
362     return;
363
364   const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
365
366   const auto frontend_layout = _current_layout;
367   const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
368   const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
369   const auto block_size = node.param().block_size;
370
371   // All assertions as per NNAPI specification.
372   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
373   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
374   OP_REQUIRES((input_shape.H % block_size == 0) && (input_shape.W % block_size == 0));
375   OP_REQUIRES(input_shape.N == output_shape.N);
376   OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
377 }
378
379 void ShapeValidator::visit(const ir::operation::ElementwiseActivation &node) { checkUnaryOp(node); }
380
381 void ShapeValidator::visit(const ir::operation::ElementwiseBinary &)
382 {
383   // TODO Shape validation of ElementwiseBinary
384 }
385
386 void ShapeValidator::visit(const ir::operation::ElementwiseUnary &node)
387 {
388   const auto output_index{node.getOutputs().at(0)};
389   const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)};
390
391   if (_ctx.at(output_index).info().isDynamic())
392     return;
393
394   OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
395 }
396
397 void ShapeValidator::visit(const ir::operation::EmbeddingLookup &node)
398 {
399   const auto output_index{node.getOutputs().at(0)};
400   const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
401   const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
402
403   const auto &output_obj = _ctx.at(output_index);
404   const auto &lookups_obj = _ctx.at(lookups_index);
405   const auto &values_obj = _ctx.at(values_index);
406
407   // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying
408   // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
409   {
410     if (_ctx.at(output_index).info().isDynamic())
411       return;
412
413     const auto &output_shape = output_obj.shape();
414     const auto &lookups_shape = lookups_obj.shape();
415     const auto &values_shape = values_obj.shape();
416
417     OP_REQUIRES(lookups_shape.rank() == 1);
418     OP_REQUIRES(values_shape.rank() >= 2);
419
420     // output should be a n-D tensor with the same rank and shape as the values tensor, except for
421     // the first dimension which has the same size as lookups' only dimension.
422     OP_REQUIRES(output_shape.rank() == values_shape.rank());
423     OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
424     for (int n = 1; n < output_shape.rank(); ++n)
425     {
426       OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
427     }
428   }
429 }
430
431 void ShapeValidator::visit(const ir::operation::ExpandDims &node)
432 {
433   const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
434
435   if (_ctx.at(axis_index).info().isDynamic())
436     return;
437   OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1);
438 }
439
440 void ShapeValidator::visit(const ir::operation::HashtableLookup &node)
441 {
442   const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
443   const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
444   const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
445   const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
446
447   const auto &output_obj = _ctx.at(output_index);
448   const auto &lookups_obj = _ctx.at(lookups_index);
449   const auto &keys_obj = _ctx.at(keys_index);
450   const auto &values_obj = _ctx.at(values_index);
451
452   if (_ctx.at(output_index).info().isDynamic())
453     return;
454
455   const auto &output_shape = output_obj.shape();
456   const auto &lookups_shape = lookups_obj.shape();
457   const auto &keys_shape = keys_obj.shape();
458   const auto &values_shape = values_obj.shape();
459
460   OP_REQUIRES(values_shape.rank() == output_shape.rank());
461   OP_REQUIRES(lookups_shape.rank() == 1);
462   OP_REQUIRES(keys_shape.rank() == 1);
463   OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
464   OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
465 }
466
467 void ShapeValidator::visit(const ir::operation::TransposeConv &node)
468 {
469   // shape check
470   const auto ofm_index{node.getOutputs().at(0)};
471   if (_ctx.at(ofm_index).info().isDynamic())
472     return;
473
474   const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
475   const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
476
477   // Only 4D tensors are supported
478   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
479   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ifm_index).shape().rank());
480   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ker_index).shape().rank());
481
482   const auto frontend_layout = _current_layout;
483   const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
484   const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
485   // The kernel has only IHWO layout on frontend
486   // So ker_shape is treated here below
487   // I -> N
488   // H -> H
489   // W -> W
490   // O -> C
491   const auto ker_shape = _ctx.at(ker_index).shape().asFeature(ir::Layout::NHWC);
492
493   OP_REQUIRES(ifm_shape.N == ofm_shape.N);
494   OP_REQUIRES(ifm_shape.C == ker_shape.C);
495   OP_REQUIRES(ker_shape.N == ofm_shape.C);
496 }
497
498 void ShapeValidator::visit(const ir::operation::Gather &node)
499 {
500   const auto ofm_index{node.getOutputs().at(0)};
501   if (_ctx.at(ofm_index).info().isDynamic())
502     return;
503
504   const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
505   const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
506
507   const auto ifm_shape = _ctx.at(ifm_index).shape();
508   const auto indices_shape = _ctx.at(indices_index).shape();
509   const auto ofm_shape = _ctx.at(ofm_index).shape();
510
511   OP_REQUIRES(ifm_shape.rank() <= 4);
512   OP_REQUIRES(indices_shape.rank() <= 3);
513   OP_REQUIRES(ofm_shape.rank() <= 4);
514 }
515
516 void ShapeValidator::visit(const ir::operation::DepthToSpace &node)
517 {
518   int32_t block_size = node.param().block_size;
519
520   // shape check
521   const auto output_index{node.getOutputs().at(0)};
522   if (_ctx.at(output_index).info().isDynamic())
523     return;
524
525   const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
526
527   const auto frontend_layout = _current_layout;
528   const auto output_shape = _ctx.at(output_index).shape().asFeature(frontend_layout);
529   const auto input_shape = _ctx.at(input_index).shape().asFeature(frontend_layout);
530
531   OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
532   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
533
534   {
535     OP_REQUIRES(output_shape.N == input_shape.N);
536     OP_REQUIRES(output_shape.H == input_shape.H * block_size);
537     OP_REQUIRES(output_shape.W == input_shape.W * block_size);
538     OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
539     OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
540   }
541 }
542
543 void ShapeValidator::visit(const ir::operation::Pack &node)
544 {
545   const auto axis{node.param().axis};
546   const auto output_index{node.getOutputs().at(0)};
547   if (_ctx.at(output_index).info().isDynamic())
548     return;
549
550   // shape check
551   const auto &output_shape = _ctx.at(output_index).shape();
552   const auto output_rank = static_cast<int32_t>(output_shape.rank());
553
554   const auto input1_index{node.getInputs().at(0)};
555   const auto input_shape = _ctx.at(input1_index).shape();
556
557   OP_REQUIRES(axis >= -output_rank && axis < output_rank);
558   for (const auto &index : node.getInputs())
559   {
560     OP_REQUIRES(input_shape == _ctx.at(index).shape());
561   }
562 }
563
564 void ShapeValidator::visit(const ir::operation::LSTM &node)
565 {
566   // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
567   // TODO Support dynamic rnn
568   const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
569   if (_ctx.at(output_index).info().isDynamic())
570     return;
571
572   const auto scratch_buffer_index{
573       node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; // Optional
574   const auto output_state_out_index{
575       node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; // Optional
576   const auto cell_state_out_index{
577       node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; // Optional
578
579   const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
580   const auto input_to_input_weights_index{
581       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // Optional
582   const auto input_to_forget_weights_index{
583       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
584   const auto input_to_cell_weights_index{
585       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
586   const auto input_to_output_weights_index{
587       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
588   const auto recurrent_to_input_weights_index{
589       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // Optional
590   const auto recurrent_to_forget_weights_index{
591       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
592   const auto recurrent_to_cell_weights_index{
593       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
594   const auto recurrent_to_output_weights_index{
595       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
596   const auto cell_to_input_weights_index{
597       node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // Optional
598   const auto cell_to_forget_weights_index{
599       node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // Optional
600   const auto cell_to_output_weights_index{
601       node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // Optional
602   const auto input_gate_bias_index{
603       node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; // Optional
604   const auto forget_gate_bias_index{
605       node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
606   const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
607   const auto output_gate_bias_index{
608       node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
609   const auto projection_weights_index{
610       node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // Optional
611   const auto projection_bias_index{
612       node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // Optional
613   const auto output_state_in_index{
614       node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
615   const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
616
617   OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
618   for (int i = 0; i < _ctx.at(input_index).shape().rank() - 1; ++i)
619   {
620     OP_REQUIRES(_ctx.at(input_index).shape().dim(i) == _ctx.at(output_index).shape().dim(i));
621   }
622   OP_REQUIRES(
623       (_ctx.at(output_index).shape().rank() == 2 || _ctx.at(output_index).shape().rank() == 3) &&
624       (_ctx.at(input_index).shape().rank() == 2 || _ctx.at(input_index).shape().rank() == 3) &&
625       (!_ctx.exist(input_to_input_weights_index) ||
626        _ctx.at(input_to_input_weights_index).shape().rank() == 2) &&
627       _ctx.at(input_to_forget_weights_index).shape().rank() == 2 &&
628       _ctx.at(input_to_cell_weights_index).shape().rank() == 2 &&
629       _ctx.at(input_to_output_weights_index).shape().rank() == 2 &&
630       (!_ctx.exist(recurrent_to_input_weights_index) ||
631        _ctx.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
632       _ctx.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
633       _ctx.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
634       _ctx.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
635       (!_ctx.exist(projection_weights_index) ||
636        _ctx.at(projection_weights_index).shape().rank() == 2) &&
637       _ctx.at(output_state_in_index).shape().rank() == 2 &&
638       _ctx.at(cell_state_in_index).shape().rank() == 2);
639
640   OP_REQUIRES(
641       (!_ctx.exist(cell_to_input_weights_index) ||
642        _ctx.at(cell_to_input_weights_index).shape().rank() == 1) &&
643       (!_ctx.exist(cell_to_forget_weights_index) ||
644        _ctx.at(cell_to_forget_weights_index).shape().rank() == 1) &&
645       (!_ctx.exist(cell_to_output_weights_index) ||
646        _ctx.at(cell_to_output_weights_index).shape().rank() == 1) &&
647       (!_ctx.exist(input_gate_bias_index) || _ctx.at(input_gate_bias_index).shape().rank() == 1) &&
648       _ctx.at(forget_gate_bias_index).shape().rank() == 1 &&
649       _ctx.at(cell_bias_index).shape().rank() == 1 &&
650       _ctx.at(output_gate_bias_index).shape().rank() == 1 &&
651       (!_ctx.exist(projection_bias_index) || _ctx.at(projection_bias_index).shape().rank() == 1));
652
653   // CIFG assertion
654   OP_REQUIRES(
655       ((!_ctx.exist(input_to_input_weights_index) ||
656         (_ctx.at(input_to_input_weights_index).shape().dim(0) == 0 &&
657          _ctx.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
658        (!_ctx.exist(recurrent_to_input_weights_index) ||
659         (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
660          _ctx.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
661        (!_ctx.exist(input_gate_bias_index) || _ctx.at(input_gate_bias_index).shape().dim(0) == 0) &&
662        (!_ctx.exist(cell_to_input_weights_index) ||
663         _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
664       ((_ctx.exist(input_to_input_weights_index) &&
665         (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
666          _ctx.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
667        (_ctx.exist(recurrent_to_input_weights_index) &&
668         (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
669          _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
670        (_ctx.exist(input_gate_bias_index) && _ctx.at(input_gate_bias_index).shape().dim(0) != 0)));
671
672   // Peephole assertion
673   OP_REQUIRES(((!_ctx.exist(cell_to_forget_weights_index) ||
674                 _ctx.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
675                (!_ctx.exist(cell_to_output_weights_index) ||
676                 _ctx.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
677               ((_ctx.exist(cell_to_forget_weights_index) &&
678                 _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
679                (_ctx.exist(cell_to_output_weights_index) &&
680                 _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0)));
681
682   bool has_input_to_input_weights = _ctx.exist(input_to_input_weights_index) &&
683                                     (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
684                                      _ctx.at(input_to_input_weights_index).shape().dim(1) != 0);
685   bool has_recurrent_to_input_weights =
686       _ctx.exist(recurrent_to_input_weights_index) &&
687       (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
688        _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
689   bool has_input_gate_bias =
690       _ctx.exist(input_gate_bias_index) && _ctx.at(input_gate_bias_index).shape().dim(0) != 0;
691   bool has_cell_to_input_weights = _ctx.exist(cell_to_input_weights_index) &&
692                                    _ctx.at(cell_to_input_weights_index).shape().dim(0) != 0;
693   bool has_cell_to_forget_weights = _ctx.exist(cell_to_forget_weights_index) &&
694                                     _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0;
695   bool has_cell_to_output_weights = _ctx.exist(cell_to_output_weights_index) &&
696                                     _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0;
697   bool has_projection_weights = _ctx.exist(projection_weights_index) &&
698                                 (_ctx.at(projection_weights_index).shape().dim(0) != 0 &&
699                                  _ctx.at(projection_weights_index).shape().dim(1) != 0);
700   bool has_projection_bias =
701       _ctx.exist(projection_bias_index) && _ctx.at(projection_bias_index).shape().dim(0) != 0;
702
703   // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
704   // true: no CIFG
705   // false: CIFG
706   bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
707
708   // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole.
709   // true: peephole
710   // false: no peephole
711   bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
712
713   // NOTE The projection weights may have data but the projection bias may not.
714   bool has_projection_param = has_projection_weights;
715
716   const auto batch_size = (_ctx.at(input_index).shape().rank() == 3 && node.param().time_major)
717                               ? _ctx.at(input_index).shape().dim(1)
718                               : _ctx.at(input_index).shape().dim(0);
719   OP_REQUIRES(batch_size == _ctx.at(output_state_in_index).shape().dim(0) &&
720               batch_size == _ctx.at(cell_state_in_index).shape().dim(0));
721
722   const auto input_size = _ctx.at(input_index).shape().dim(_ctx.at(input_index).shape().rank() - 1);
723   OP_REQUIRES(input_size == _ctx.at(input_to_forget_weights_index).shape().dim(1) &&
724               input_size == _ctx.at(input_to_cell_weights_index).shape().dim(1) &&
725               input_size == _ctx.at(input_to_output_weights_index).shape().dim(1));
726
727   const auto num_units = _ctx.at(input_to_output_weights_index).shape().dim(0);
728   OP_REQUIRES(num_units == _ctx.at(input_to_cell_weights_index).shape().dim(0) &&
729               num_units == _ctx.at(input_to_output_weights_index).shape().dim(0) &&
730               num_units == _ctx.at(recurrent_to_forget_weights_index).shape().dim(0) &&
731               num_units == _ctx.at(recurrent_to_cell_weights_index).shape().dim(0) &&
732               num_units == _ctx.at(recurrent_to_output_weights_index).shape().dim(0) &&
733               num_units == _ctx.at(forget_gate_bias_index).shape().dim(0) &&
734               num_units == _ctx.at(cell_bias_index).shape().dim(0) &&
735               num_units == _ctx.at(output_gate_bias_index).shape().dim(0) &&
736               num_units == _ctx.at(cell_state_in_index).shape().dim(1));
737
738   const auto output_size =
739       _ctx.at(output_index).shape().dim(_ctx.at(output_index).shape().rank() - 1);
740   OP_REQUIRES(output_size == _ctx.at(recurrent_to_forget_weights_index).shape().dim(1) &&
741               output_size == _ctx.at(recurrent_to_cell_weights_index).shape().dim(1) &&
742               output_size == _ctx.at(recurrent_to_output_weights_index).shape().dim(1) &&
743               output_size == _ctx.at(output_state_in_index).shape().dim(1));
744
745   if (has_cifg_param)
746   {
747     OP_REQUIRES(input_size == _ctx.at(input_to_input_weights_index).shape().dim(1));
748     OP_REQUIRES(num_units == _ctx.at(input_to_input_weights_index).shape().dim(0) &&
749                 num_units == _ctx.at(recurrent_to_input_weights_index).shape().dim(0) &&
750                 ((_ctx.exist(cell_to_input_weights_index) &&
751                   num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0)) ||
752                  (!_ctx.exist(cell_to_input_weights_index) ||
753                   _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) &&
754                 num_units == _ctx.at(input_gate_bias_index).shape().dim(0));
755     OP_REQUIRES(output_size == _ctx.at(recurrent_to_input_weights_index).shape().dim(1));
756     OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
757                 has_input_gate_bias);
758     if (has_cell_to_input_weights)
759     {
760       // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
761       OP_REQUIRES(has_peephole_param);
762     }
763     if (_ctx.exist(scratch_buffer_index))
764       OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
765   }
766   else
767   {
768     if (_ctx.exist(scratch_buffer_index))
769       OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
770   }
771
772   if (has_peephole_param)
773   {
774     OP_REQUIRES(num_units == _ctx.at(cell_to_forget_weights_index).shape().dim(0) &&
775                 num_units == _ctx.at(cell_to_output_weights_index).shape().dim(0) &&
776                 (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) ||
777                  _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
778   }
779
780   if (has_projection_param)
781   {
782     OP_REQUIRES(num_units == _ctx.at(projection_weights_index).shape().dim(1));
783     OP_REQUIRES(output_size == _ctx.at(projection_weights_index).shape().dim(0));
784     if (has_projection_bias)
785     {
786       OP_REQUIRES(output_size == _ctx.at(projection_bias_index).shape().dim(0));
787     }
788   }
789
790   if (_ctx.exist(scratch_buffer_index))
791   {
792     OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().rank() == 2);
793     OP_REQUIRES(batch_size == _ctx.at(scratch_buffer_index).shape().dim(0));
794   }
795
796   if (_ctx.exist(output_state_out_index))
797   {
798     OP_REQUIRES(_ctx.at(output_state_out_index).shape().rank() == 2);
799     OP_REQUIRES(batch_size == _ctx.at(output_state_out_index).shape().dim(0));
800     OP_REQUIRES(output_size == _ctx.at(output_state_out_index).shape().dim(1));
801   }
802
803   if (_ctx.exist(cell_state_out_index))
804   {
805     OP_REQUIRES(_ctx.at(cell_state_out_index).shape().rank() == 2);
806     OP_REQUIRES(batch_size == _ctx.at(cell_state_out_index).shape().dim(0));
807     OP_REQUIRES(num_units == _ctx.at(cell_state_out_index).shape().dim(1));
808   }
809 }
810
811 void ShapeValidator::visit(const ir::operation::L2Normalization &node)
812 {
813   const auto ofm_index{node.getOutputs().at(0)};
814   if (_ctx.at(ofm_index).info().isDynamic())
815     return;
816
817   const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
818
819   auto ifm_shape = _ctx.at(ifm_index).shape();
820   auto ofm_shape = _ctx.at(ofm_index).shape();
821
822   OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
823
824   for (auto i = 0; i < ifm_shape.rank(); i++)
825   {
826     OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
827   }
828 }
829
830 void ShapeValidator::visit(const ir::operation::Unpack &node)
831 {
832   const auto axis{node.param().axis};
833   const auto output_index{node.getInputs().at(0)};
834   if (_ctx.at(output_index).info().isDynamic())
835     return;
836
837   const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
838
839   const auto &input_shape = _ctx.at(input_index).shape();
840   const auto input_rank = static_cast<int32_t>(input_shape.rank());
841
842   OP_REQUIRES(axis >= -input_rank && axis < input_rank);
843 }
844
845 void ShapeValidator::visit(const ir::operation::Pad &node)
846 {
847   const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
848   OP_REQUIRES(_ctx.at(pad_index).typeInfo().type() == ir::DataType::INT32);
849
850   const auto output_index{node.getInputs().at(0)};
851   if (_ctx.at(output_index).info().isDynamic())
852     return;
853
854   const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
855
856   const auto &pad_shape = _ctx.at(pad_index).shape();
857   const auto input_rank = static_cast<int32_t>(_ctx.at(input_index).shape().rank());
858
859   OP_REQUIRES(pad_shape.rank() == 2);
860   OP_REQUIRES(pad_shape.dim(0) == input_rank);
861   OP_REQUIRES(pad_shape.dim(1) == 2);
862   OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
863 }
864
865 void ShapeValidator::visit(const ir::operation::Select &)
866 {
867   // TODO Shape validation of select
868 }
869
870 void ShapeValidator::visit(const ir::operation::StridedSlice &node)
871 {
872   const auto output_index{node.getOutputs().at(0)};
873   const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
874
875   if (_ctx.at(output_index).info().isDynamic())
876     return;
877
878   OP_REQUIRES(_ctx.at(input_index).shape().rank() <= 4);
879 }
880
881 void ShapeValidator::visit(const ir::operation::Split &node)
882 {
883   const auto output_index{node.getOutputs().at(0)};
884   if (_ctx.at(output_index).info().isDynamic())
885     return;
886
887   const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
888   const auto axis_index{node.getInputs().at(ir::operation::Split::Input::AXIS)};
889
890   const auto num_splits = node.param().num_splits;
891   const auto input_rank = _ctx.at(input_index).shape().rank();
892   auto axis = *reinterpret_cast<const int32_t *>(_ctx.at(axis_index).data()->base());
893   axis = axis < 0 ? axis + input_rank : axis;
894
895   OP_REQUIRES(axis >= 0 && axis < input_rank);
896   OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0);
897 }
898
899 void ShapeValidator::visit(const ir::operation::Shape &node)
900 {
901   const auto output_index{node.getOutputs().at(0)};
902   if (_ctx.at(output_index).info().isDynamic())
903     return;
904
905   const auto input_index{node.getInputs().at(0)};
906   UNUSED_RELEASE(input_index);
907   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1);
908 }
909
910 void ShapeValidator::visit(const ir::operation::ResizeBilinear &node)
911 {
912   const auto output_index{node.getOutputs().at(0)};
913   const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
914
915   if (_ctx.at(output_index).info().isDynamic())
916   {
917     return;
918   }
919   OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
920   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
921 }
922
923 void ShapeValidator::visit(const ir::operation::Reverse &node)
924 {
925   const auto output_index{node.getOutputs().at(0)};
926   const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
927
928   if (_ctx.at(output_index).info().isDynamic())
929     return;
930   OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
931 }
932
933 void ShapeValidator::visit(const ir::operation::If &)
934 {
935   // TODO Add to validate with subgraphs
936 }
937
938 void ShapeValidator::visit(const ir::operation::While &)
939 {
940   // This validator does not check shape. So checking isDynamic() is skipped.
941   // TODO Add to validate with subgraphs
942 }
943
944 void ShapeValidator::visit(const ir::operation::SquaredDifference &node)
945 {
946   const auto output_index{node.getOutputs().at(0)};
947   const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
948   const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
949
950   // Check for dimension constraints
951   if (_ctx.at(output_index).info().isDynamic())
952     return;
953
954   auto output_shape = _ctx.at(output_index).shape();
955   auto lhs_shape = _ctx.at(lhs_index).shape();
956   auto rhs_shape = _ctx.at(rhs_index).shape();
957   // Check for output rank
958   OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
959   auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());
960
961   for (int idx = 1; idx <= min_rank; idx++)
962   {
963     int l_idx = lhs_shape.rank() - idx;
964     int r_idx = rhs_shape.rank() - idx;
965     int out_idx = output_shape.rank() - idx;
966
967     OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));
968
969     auto l_dims = lhs_shape.dim(l_idx);
970     auto r_dims = rhs_shape.dim(r_idx);
971     auto out_dims = output_shape.dim(out_idx);
972
973     OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
974                 ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));
975   }
976   auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
977   for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
978   {
979     int out_idx = output_shape.rank() - idx;
980     int tmp_idx = tmp_shape.rank() - idx;
981
982     OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
983                 (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
984   }
985 }
986 void ShapeValidator::visit(const ir::operation::Tile &node)
987 {
988   const auto output_index{node.getOutputs().at(0)};
989   if (_ctx.at(output_index).info().isDynamic())
990     return;
991
992   const auto input_index{node.getInputs().at(0)};
993   const auto multiple_index{node.getInputs().at(1)};
994
995   OP_REQUIRES(_ctx.at(multiple_index).shape().rank() == 1);
996   OP_REQUIRES(_ctx.at(multiple_index).shape().dim(0) == _ctx.at(input_index).shape().rank());
997   OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
998 }
999
1000 void ShapeValidator::visit(const ir::operation::Range &node)
1001 {
1002   const auto output_index{node.getOutputs().at(0)};
1003   const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
1004   const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
1005   const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
1006
1007   // Check for dimension constraints
1008   if (_ctx.at(output_index).info().isDynamic())
1009     return;
1010
1011   OP_REQUIRES(_ctx.at(start_index).shape().rank() == 0);
1012   OP_REQUIRES(_ctx.at(limit_index).shape().rank() == 0);
1013   OP_REQUIRES(_ctx.at(delta_index).shape().rank() == 0);
1014 }
1015
1016 void ShapeValidator::visit(const ir::operation::MatrixBandPart &node)
1017 {
1018   const auto output_index{node.getOutputs().at(0)};
1019   const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)};
1020   const auto num_lower_index{
1021       node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)};
1022   const auto num_upper_index{
1023       node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)};
1024
1025   // Check for dimension constraints
1026   if (_ctx.at(output_index).info().isDynamic())
1027     return;
1028
1029   OP_REQUIRES(_ctx.at(input_index).shape().rank() >= 2);     // input must be more than 2 dim matrix
1030   OP_REQUIRES(_ctx.at(num_upper_index).shape().rank() == 0); // num_lower must be scalar
1031   OP_REQUIRES(_ctx.at(num_lower_index).shape().rank() == 0); // num_upper must be scalar
1032 }
1033
1034 void ShapeValidator::visit(const ir::operation::LogSoftmax &node)
1035 {
1036   const auto output_index{node.getOutputs().at(0)};
1037   if (_ctx.at(output_index).info().isDynamic())
1038     return;
1039
1040   const auto input_index{node.getInputs().at(0)};
1041
1042   OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
1043 }
1044
1045 } // namespace compiler
1046 } // namespace onert