c18178da9c1fd0b0336883efee69c697b596dc4c
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / ShapeValidator.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ShapeValidator.h"
18
19 #include <typeinfo>
20
21 #include "ir/Graph.h"
22 #include "ir/operation/LowerInfo.h"
23
24 #include "util/logging.h"
25 #include "util/Utils.h"
26
27 #define OP_REQUIRES(EXP)                                                                     \
28   do                                                                                         \
29   {                                                                                          \
30     if (!(EXP))                                                                              \
31       throw std::runtime_error("ShapeValidator failed at line " + std::to_string(__LINE__)); \
32   } while (0)
33
34 namespace onert
35 {
36 namespace compiler
37 {
38
39 ShapeValidator::ShapeValidator(const ir::Graph &graph)
40     : _graph{graph}, _ctx{graph.operands()}, _current_op_seq_layout{ir::Layout::UNKNOWN}
41 {
42 }
43
44 void ShapeValidator::checkUnaryOp(const ir::Operation &node)
45 {
46   const auto output_index{node.getOutputs().at(0)};
47   const auto input_index{node.getInputs().at(0)};
48
49   if (_ctx.at(output_index).info().isDynamic())
50     return;
51
52   // Check if I/O shapes match
53   OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
54 }
55
56 void ShapeValidator::operator()()
57 {
58   // There is no reason for each subgraph to have subgraphs since compiler has subgraphs when
59   // creating Compiler
60   assert(_graph.subgraphs() == nullptr);
61
62   _current_op_seq_layout = _graph.layout();
63
64   _graph.operations().iterate(
65       [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
66 }
67
68 void ShapeValidator::visit(const ir::operation::BatchMatMul &node)
69 {
70   const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
71   const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
72   const auto out_index{node.getOutputs().at(0)};
73
74   if (_ctx.at(out_index).info().isDynamic())
75     return;
76
77   OP_REQUIRES(_ctx.at(lhs_index).shape().rank() <= 4);
78   OP_REQUIRES(_ctx.at(rhs_index).shape().rank() <= 4);
79   OP_REQUIRES(_ctx.at(lhs_index).shape().rank() >= 2);
80   OP_REQUIRES(_ctx.at(rhs_index).shape().rank() >= 2);
81 }
82
83 void ShapeValidator::visit(const ir::operation::BatchToSpaceND &node)
84 {
85   const auto ofm_index{node.getOutputs().at(0)};
86   if (_ctx.at(ofm_index).info().isDynamic())
87     return;
88
89   const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
90   const auto block_size_index{
91       node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
92
93   const auto frontend_layout = _current_op_seq_layout;
94   const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
95   const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
96
97   // All requirement as per NNAPI specification.
98   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
99   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
100   OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
101
102   OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
103
104   OP_REQUIRES(input_shape.C == output_shape.C);
105 }
106
107 void ShapeValidator::visit(const ir::operation::BCQFullyConnected &node)
108 {
109   const auto ofm_index{node.getOutputs().at(0)};
110   if (_ctx.at(ofm_index).info().isDynamic())
111     return;
112
113   const auto ifm_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
114   const auto weight_scales_index{
115       node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_SCALES)};
116   const auto weight_binary_index{
117       node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_BINARY)};
118   const auto weight_cluster_index{
119       node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
120   // const auto bias_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::BIAS)};
121
122   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 2);
123   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 2);
124   OP_REQUIRES(_ctx.at(weight_scales_index).shape().rank() == 1);
125   OP_REQUIRES(_ctx.at(weight_binary_index).shape().rank() == 2);
126   OP_REQUIRES(_ctx.at(weight_cluster_index).shape().rank() == 2);
127
128   OP_REQUIRES(_ctx.at(ifm_index).shape().dim(1) == _ctx.at(ofm_index).shape().dim(1));
129
130   OP_REQUIRES(_ctx.at(weight_cluster_index).shape().dim(0) > 0);
131   OP_REQUIRES(_ctx.at(weight_cluster_index).shape().dim(1) == 2);
132
133   // more shape validation will be done inside kernel.
134
135   // TODO Check bias dimension (can be null tensor)
136 }
137
138 void ShapeValidator::visit(const ir::operation::BCQGather &node)
139 {
140   const auto ofm_index{node.getOutputs().at(0)};
141   if (_ctx.at(ofm_index).info().isDynamic())
142     return;
143
144   const auto indices_index{node.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
145   const auto input_binary_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
146   const auto input_scales_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_SCALES)};
147   const auto input_clusters_index{
148       node.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
149
150   OP_REQUIRES(_ctx.at(indices_index).shape().rank() <= 2); // TODO : support rank up to 4 or more
151   OP_REQUIRES(_ctx.at(input_binary_index).shape().rank() == 2);
152   OP_REQUIRES(_ctx.at(input_scales_index).shape().rank() == 1);
153   OP_REQUIRES(_ctx.at(input_clusters_index).shape().rank() == 2);
154
155   OP_REQUIRES(_ctx.at(input_clusters_index).shape().dim(0) > 0);
156   OP_REQUIRES(_ctx.at(input_clusters_index).shape().dim(1) == 2);
157
158   // more shape validation will be done inside kernel.
159 }
160
161 void ShapeValidator::visit(const ir::operation::Comparison &)
162 {
163   // TODO Shape validation of comparison
164 }
165
166 void ShapeValidator::visit(const ir::operation::Softmax &node)
167 {
168   const auto output_index{node.getOutputs().at(0)};
169   if (_ctx.at(output_index).info().isDynamic())
170     return;
171
172   const auto input_index{node.getInputs().at(0)};
173
174   OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
175 }
176
177 void ShapeValidator::visit(const ir::operation::InstanceNorm &node)
178 {
179   const auto ofm_index{node.getOutputs().at(0)};
180   if (_ctx.at(ofm_index).info().isDynamic())
181     return;
182
183   const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
184   const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
185   const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
186
187   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
188   OP_REQUIRES(_ctx.at(ifm_index).shape() == _ctx.at(ofm_index).shape());
189   OP_REQUIRES(_ctx.at(gamma_index).shape().rank() == 1);
190   OP_REQUIRES(_ctx.at(beta_index).shape().rank() == 1);
191 }
192
193 void ShapeValidator::visit(const ir::operation::Pool2D &node)
194 {
195   const auto ofm_index{node.getOutputs().at(0)};
196   if (_ctx.at(ofm_index).info().isDynamic())
197     return;
198
199   const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
200
201   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
202 }
203
204 void ShapeValidator::visit(const ir::operation::Permute &node)
205 {
206   const auto output_index{node.getOutputs().at(0)};
207   if (_ctx.at(output_index).info().isDynamic())
208     return;
209
210   const auto input_index{node.getInputs().at(0)};
211
212   OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
213 }
214
215 void ShapeValidator::visit(const ir::operation::Reduce &node)
216 {
217   const auto output_index{node.getOutputs().at(0)};
218   if (_ctx.at(output_index).info().isDynamic())
219     return;
220
221   const auto input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
222   const auto input_shape = _ctx.at(input_index).shape();
223   const auto output_shape = _ctx.at(output_index).shape();
224
225   OP_REQUIRES(input_shape.rank() <= 4);
226   OP_REQUIRES(output_shape.rank() <= input_shape.rank());
227
228   // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only
229   // supports cases reducing height and width or reducing depth.
230   // TODO We have to support all cases of dimensions up to 4.
231   // For correct permuting, we have to set output's shape to be equal in dimension position of the
232   // input. But the positions of the same dimensions in the input and output may be set differently.
233   // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
234   // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to
235   // extend it in 4 dimensions, it should be {1,1,3,5}.
236   // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of
237   // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the
238   // next operation is not desired.
239   if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
240   {
241     if (output_shape.rank() == 2)
242     {
243       // Reducing HW
244       OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
245                   input_shape.dim(3) == output_shape.dim(1));
246     }
247     else if (output_shape.rank() == 3)
248     {
249       // Reducing C or
250       // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1)
251       OP_REQUIRES((input_shape.dim(0) == output_shape.dim(0) &&
252                    input_shape.dim(1) == output_shape.dim(1) &&
253                    input_shape.dim(2) == output_shape.dim(2)) ||
254                   (input_shape.dim(0) == output_shape.dim(0) &&
255                    (input_shape.dim(1) == output_shape.dim(1) ||
256                     input_shape.dim(2) == output_shape.dim(1)) &&
257                    input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
258     }
259   }
260 }
261
262 void ShapeValidator::visit(const ir::operation::Transpose &node)
263 {
264   const auto output_index{node.getOutputs().at(0)};
265   if (_ctx.at(output_index).info().isDynamic())
266     return;
267
268   const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
269   const auto perm_index{node.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
270
271   const auto &output_shape = _ctx.at(output_index).shape();
272   const auto &input_shape = _ctx.at(input_index).shape();
273
274   OP_REQUIRES(_ctx.at(perm_index).shape().num_elements() == 0 ||
275               input_shape.rank() == static_cast<int>(_ctx.at(perm_index).shape().num_elements()));
276   OP_REQUIRES(input_shape.rank() == output_shape.rank());
277 }
278
279 void ShapeValidator::visit(const ir::operation::RNN &node)
280 {
281   // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
282   // TODO Support dynamic rnn
283   const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
284   if (_ctx.at(output_index).info().isDynamic())
285     return;
286
287   const auto hidden_state_out_index{
288       node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)};
289
290   const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
291   const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
292   const auto recurrent_weights_index{
293       node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)};
294   const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
295   const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
296
297   const auto batch_size = _ctx.at(output_index).shape().dim(0);
298   const auto num_units = _ctx.at(output_index).shape().dim(1);
299
300   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 2 &&
301               _ctx.at(hidden_state_out_index).shape().rank() == 2 &&
302               _ctx.at(input_index).shape().rank() == 2 &&
303               _ctx.at(weights_index).shape().rank() == 2 &&
304               _ctx.at(recurrent_weights_index).shape().rank() == 2 &&
305               _ctx.at(hidden_state_in_index).shape().rank() == 2);
306   OP_REQUIRES(_ctx.at(bias_index).shape().rank() == 1);
307
308   OP_REQUIRES(batch_size == _ctx.at(input_index).shape().dim(0) &&
309               batch_size == _ctx.at(hidden_state_in_index).shape().dim(0) &&
310               batch_size == _ctx.at(hidden_state_out_index).shape().dim(0));
311   OP_REQUIRES(_ctx.at(input_index).shape().dim(1) == _ctx.at(weights_index).shape().dim(1));
312
313   OP_REQUIRES(num_units == _ctx.at(weights_index).shape().dim(0) &&
314               num_units == _ctx.at(recurrent_weights_index).shape().dim(0) &&
315               num_units == _ctx.at(bias_index).shape().dim(0));
316   OP_REQUIRES(num_units == _ctx.at(output_index).shape().dim(1) &&
317               num_units == _ctx.at(recurrent_weights_index).shape().dim(1) &&
318               num_units == _ctx.at(hidden_state_in_index).shape().dim(1) &&
319               num_units == _ctx.at(hidden_state_out_index).shape().dim(1));
320 }
321
322 void ShapeValidator::visit(const ir::operation::SpaceToBatchND &node)
323 {
324   const auto ofm_index{node.getOutputs().at(0)};
325   if (_ctx.at(ofm_index).info().isDynamic())
326     return;
327
328   const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
329   const auto block_size_index{
330       node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
331   const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
332
333   const auto frontend_layout = _current_op_seq_layout;
334   const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
335   const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
336
337   // All requirement as per NNAPI specification.
338   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
339   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
340   OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
341   OP_REQUIRES(_ctx.at(paddings_index).shape().rank() == 2);
342
343   OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
344   OP_REQUIRES(_ctx.at(paddings_index).shape().dim(0) == 2);
345   OP_REQUIRES(_ctx.at(paddings_index).shape().dim(1) == 2);
346
347   OP_REQUIRES(input_shape.C == output_shape.C);
348 }
349
350 void ShapeValidator::visit(const ir::operation::SpaceToDepth &node)
351 {
352   const auto ofm_index{node.getOutputs().at(0)};
353   if (_ctx.at(ofm_index).info().isDynamic())
354     return;
355
356   const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
357
358   const auto frontend_layout = _current_op_seq_layout;
359   const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
360   const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
361   const auto block_size = node.param().block_size;
362
363   // All assertions as per NNAPI specification.
364   OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
365   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
366   OP_REQUIRES((input_shape.H % block_size == 0) && (input_shape.W % block_size == 0));
367   OP_REQUIRES(input_shape.N == output_shape.N);
368   OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
369 }
370
371 void ShapeValidator::visit(const ir::operation::ElementwiseActivation &node) { checkUnaryOp(node); }
372
373 void ShapeValidator::visit(const ir::operation::ElementwiseBinary &)
374 {
375   // TODO Shape validation of ElementwiseBinary
376 }
377
378 void ShapeValidator::visit(const ir::operation::ElementwiseUnary &node)
379 {
380   const auto output_index{node.getOutputs().at(0)};
381   const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)};
382
383   if (_ctx.at(output_index).info().isDynamic())
384     return;
385
386   OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
387 }
388
389 void ShapeValidator::visit(const ir::operation::EmbeddingLookup &node)
390 {
391   const auto output_index{node.getOutputs().at(0)};
392   const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
393   const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
394
395   const auto &output_obj = _ctx.at(output_index);
396   const auto &lookups_obj = _ctx.at(lookups_index);
397   const auto &values_obj = _ctx.at(values_index);
398
399   // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying
400   // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
401   {
402     if (_ctx.at(output_index).info().isDynamic())
403       return;
404
405     const auto &output_shape = output_obj.shape();
406     const auto &lookups_shape = lookups_obj.shape();
407     const auto &values_shape = values_obj.shape();
408
409     OP_REQUIRES(lookups_shape.rank() == 1);
410     OP_REQUIRES(values_shape.rank() >= 2);
411
412     // output should be a n-D tensor with the same rank and shape as the values tensor, except for
413     // the first dimension which has the same size as lookups' only dimension.
414     OP_REQUIRES(output_shape.rank() == values_shape.rank());
415     OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
416     for (int n = 1; n < output_shape.rank(); ++n)
417     {
418       OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
419     }
420   }
421 }
422
423 void ShapeValidator::visit(const ir::operation::ExpandDims &node)
424 {
425   const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
426
427   if (_ctx.at(axis_index).info().isDynamic())
428     return;
429   OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1);
430 }
431
432 void ShapeValidator::visit(const ir::operation::HashtableLookup &node)
433 {
434   const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
435   const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
436   const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
437   const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
438
439   const auto &output_obj = _ctx.at(output_index);
440   const auto &lookups_obj = _ctx.at(lookups_index);
441   const auto &keys_obj = _ctx.at(keys_index);
442   const auto &values_obj = _ctx.at(values_index);
443
444   if (_ctx.at(output_index).info().isDynamic())
445     return;
446
447   const auto &output_shape = output_obj.shape();
448   const auto &lookups_shape = lookups_obj.shape();
449   const auto &keys_shape = keys_obj.shape();
450   const auto &values_shape = values_obj.shape();
451
452   OP_REQUIRES(values_shape.rank() == output_shape.rank());
453   OP_REQUIRES(lookups_shape.rank() == 1);
454   OP_REQUIRES(keys_shape.rank() == 1);
455   OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
456   OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
457 }
458
459 void ShapeValidator::visit(const ir::operation::TransposeConv &node)
460 {
461   // shape check
462   const auto ofm_index{node.getOutputs().at(0)};
463   if (_ctx.at(ofm_index).info().isDynamic())
464     return;
465
466   const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
467   const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
468
469   // Only 4D tensors are supported
470   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
471   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ifm_index).shape().rank());
472   OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ker_index).shape().rank());
473
474   const auto frontend_layout = _current_op_seq_layout;
475   const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
476   const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
477   // The kernel has only IHWO layout on frontend
478   // So ker_shape is treated here below
479   // I -> N
480   // H -> H
481   // W -> W
482   // O -> C
483   const auto ker_shape = _ctx.at(ker_index).shape().asFeature(ir::Layout::NHWC);
484
485   OP_REQUIRES(ifm_shape.N == ofm_shape.N);
486   OP_REQUIRES(ifm_shape.C == ker_shape.C);
487   OP_REQUIRES(ker_shape.N == ofm_shape.C);
488 }
489
490 void ShapeValidator::visit(const ir::operation::Gather &node)
491 {
492   const auto ofm_index{node.getOutputs().at(0)};
493   if (_ctx.at(ofm_index).info().isDynamic())
494     return;
495
496   const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
497   const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
498
499   const auto ifm_shape = _ctx.at(ifm_index).shape();
500   const auto indices_shape = _ctx.at(indices_index).shape();
501   const auto ofm_shape = _ctx.at(ofm_index).shape();
502
503   OP_REQUIRES(ifm_shape.rank() <= 4);
504   OP_REQUIRES(indices_shape.rank() <= 3);
505   OP_REQUIRES(ofm_shape.rank() <= 4);
506 }
507
508 void ShapeValidator::visit(const ir::operation::DepthToSpace &node)
509 {
510   int32_t block_size = node.param().block_size;
511
512   // shape check
513   const auto output_index{node.getOutputs().at(0)};
514   if (_ctx.at(output_index).info().isDynamic())
515     return;
516
517   const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
518
519   const auto frontend_layout = _current_op_seq_layout;
520   const auto output_shape = _ctx.at(output_index).shape().asFeature(frontend_layout);
521   const auto input_shape = _ctx.at(input_index).shape().asFeature(frontend_layout);
522
523   OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
524   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
525
526   {
527     OP_REQUIRES(output_shape.N == input_shape.N);
528     OP_REQUIRES(output_shape.H == input_shape.H * block_size);
529     OP_REQUIRES(output_shape.W == input_shape.W * block_size);
530     OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
531     OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
532   }
533 }
534
535 void ShapeValidator::visit(const ir::operation::Pack &node)
536 {
537   const auto axis{node.param().axis};
538   const auto output_index{node.getOutputs().at(0)};
539   if (_ctx.at(output_index).info().isDynamic())
540     return;
541
542   // shape check
543   const auto &output_shape = _ctx.at(output_index).shape();
544   const auto output_rank = static_cast<int32_t>(output_shape.rank());
545
546   const auto input1_index{node.getInputs().at(0)};
547   const auto input_shape = _ctx.at(input1_index).shape();
548
549   OP_REQUIRES(axis >= -output_rank && axis < output_rank);
550   for (const auto &index : node.getInputs())
551   {
552     OP_REQUIRES(input_shape == _ctx.at(index).shape());
553   }
554 }
555
556 void ShapeValidator::visit(const ir::operation::LSTM &node)
557 {
558   // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
559   // TODO Support dynamic rnn
560   const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
561   if (_ctx.at(output_index).info().isDynamic())
562     return;
563
564   const auto scratch_buffer_index{
565       node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; // Optional
566   const auto output_state_out_index{
567       node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; // Optional
568   const auto cell_state_out_index{
569       node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; // Optional
570
571   const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
572   const auto input_to_input_weights_index{
573       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // Optional
574   const auto input_to_forget_weights_index{
575       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
576   const auto input_to_cell_weights_index{
577       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
578   const auto input_to_output_weights_index{
579       node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
580   const auto recurrent_to_input_weights_index{
581       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // Optional
582   const auto recurrent_to_forget_weights_index{
583       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
584   const auto recurrent_to_cell_weights_index{
585       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
586   const auto recurrent_to_output_weights_index{
587       node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
588   const auto cell_to_input_weights_index{
589       node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // Optional
590   const auto cell_to_forget_weights_index{
591       node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // Optional
592   const auto cell_to_output_weights_index{
593       node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // Optional
594   const auto input_gate_bias_index{
595       node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; // Optional
596   const auto forget_gate_bias_index{
597       node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
598   const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
599   const auto output_gate_bias_index{
600       node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
601   const auto projection_weights_index{
602       node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // Optional
603   const auto projection_bias_index{
604       node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // Optional
605   const auto output_state_in_index{
606       node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
607   const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
608
609   OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
610   for (int i = 0; i < _ctx.at(input_index).shape().rank() - 1; ++i)
611   {
612     OP_REQUIRES(_ctx.at(input_index).shape().dim(i) == _ctx.at(output_index).shape().dim(i));
613   }
614   OP_REQUIRES(
615       (_ctx.at(output_index).shape().rank() == 2 || _ctx.at(output_index).shape().rank() == 3) &&
616       (_ctx.at(input_index).shape().rank() == 2 || _ctx.at(input_index).shape().rank() == 3) &&
617       (!_ctx.exist(input_to_input_weights_index) ||
618        _ctx.at(input_to_input_weights_index).shape().rank() == 2) &&
619       _ctx.at(input_to_forget_weights_index).shape().rank() == 2 &&
620       _ctx.at(input_to_cell_weights_index).shape().rank() == 2 &&
621       _ctx.at(input_to_output_weights_index).shape().rank() == 2 &&
622       (!_ctx.exist(recurrent_to_input_weights_index) ||
623        _ctx.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
624       _ctx.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
625       _ctx.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
626       _ctx.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
627       (!_ctx.exist(projection_weights_index) ||
628        _ctx.at(projection_weights_index).shape().rank() == 2) &&
629       _ctx.at(output_state_in_index).shape().rank() == 2 &&
630       _ctx.at(cell_state_in_index).shape().rank() == 2);
631
632   OP_REQUIRES(
633       (!_ctx.exist(cell_to_input_weights_index) ||
634        _ctx.at(cell_to_input_weights_index).shape().rank() == 1) &&
635       (!_ctx.exist(cell_to_forget_weights_index) ||
636        _ctx.at(cell_to_forget_weights_index).shape().rank() == 1) &&
637       (!_ctx.exist(cell_to_output_weights_index) ||
638        _ctx.at(cell_to_output_weights_index).shape().rank() == 1) &&
639       (!_ctx.exist(input_gate_bias_index) || _ctx.at(input_gate_bias_index).shape().rank() == 1) &&
640       _ctx.at(forget_gate_bias_index).shape().rank() == 1 &&
641       _ctx.at(cell_bias_index).shape().rank() == 1 &&
642       _ctx.at(output_gate_bias_index).shape().rank() == 1 &&
643       (!_ctx.exist(projection_bias_index) || _ctx.at(projection_bias_index).shape().rank() == 1));
644
645   // CIFG assertion
646   OP_REQUIRES(
647       ((!_ctx.exist(input_to_input_weights_index) ||
648         (_ctx.at(input_to_input_weights_index).shape().dim(0) == 0 &&
649          _ctx.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
650        (!_ctx.exist(recurrent_to_input_weights_index) ||
651         (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
652          _ctx.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
653        (!_ctx.exist(input_gate_bias_index) || _ctx.at(input_gate_bias_index).shape().dim(0) == 0) &&
654        (!_ctx.exist(cell_to_input_weights_index) ||
655         _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
656       ((_ctx.exist(input_to_input_weights_index) &&
657         (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
658          _ctx.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
659        (_ctx.exist(recurrent_to_input_weights_index) &&
660         (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
661          _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
662        (_ctx.exist(input_gate_bias_index) && _ctx.at(input_gate_bias_index).shape().dim(0) != 0)));
663
664   // Peephole assertion
665   OP_REQUIRES(((!_ctx.exist(cell_to_forget_weights_index) ||
666                 _ctx.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
667                (!_ctx.exist(cell_to_output_weights_index) ||
668                 _ctx.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
669               ((_ctx.exist(cell_to_forget_weights_index) &&
670                 _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
671                (_ctx.exist(cell_to_output_weights_index) &&
672                 _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0)));
673
674   bool has_input_to_input_weights = _ctx.exist(input_to_input_weights_index) &&
675                                     (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
676                                      _ctx.at(input_to_input_weights_index).shape().dim(1) != 0);
677   bool has_recurrent_to_input_weights =
678       _ctx.exist(recurrent_to_input_weights_index) &&
679       (_ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
680        _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
681   bool has_input_gate_bias =
682       _ctx.exist(input_gate_bias_index) && _ctx.at(input_gate_bias_index).shape().dim(0) != 0;
683   bool has_cell_to_input_weights = _ctx.exist(cell_to_input_weights_index) &&
684                                    _ctx.at(cell_to_input_weights_index).shape().dim(0) != 0;
685   bool has_cell_to_forget_weights = _ctx.exist(cell_to_forget_weights_index) &&
686                                     _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0;
687   bool has_cell_to_output_weights = _ctx.exist(cell_to_output_weights_index) &&
688                                     _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0;
689   bool has_projection_weights = _ctx.exist(projection_weights_index) &&
690                                 (_ctx.at(projection_weights_index).shape().dim(0) != 0 &&
691                                  _ctx.at(projection_weights_index).shape().dim(1) != 0);
692   bool has_projection_bias =
693       _ctx.exist(projection_bias_index) && _ctx.at(projection_bias_index).shape().dim(0) != 0;
694
695   // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
696   // true: no CIFG
697   // false: CIFG
698   bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
699
700   // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole.
701   // true: peephole
702   // false: no peephole
703   bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
704
705   // NOTE The projection weights may have data but the projection bias may not.
706   bool has_projection_param = has_projection_weights;
707
708   const auto batch_size = (_ctx.at(input_index).shape().rank() == 3 && node.param().time_major)
709                               ? _ctx.at(input_index).shape().dim(1)
710                               : _ctx.at(input_index).shape().dim(0);
711   OP_REQUIRES(batch_size == _ctx.at(output_state_in_index).shape().dim(0) &&
712               batch_size == _ctx.at(cell_state_in_index).shape().dim(0));
713
714   const auto input_size = _ctx.at(input_index).shape().dim(_ctx.at(input_index).shape().rank() - 1);
715   OP_REQUIRES(input_size == _ctx.at(input_to_forget_weights_index).shape().dim(1) &&
716               input_size == _ctx.at(input_to_cell_weights_index).shape().dim(1) &&
717               input_size == _ctx.at(input_to_output_weights_index).shape().dim(1));
718
719   const auto num_units = _ctx.at(input_to_output_weights_index).shape().dim(0);
720   OP_REQUIRES(num_units == _ctx.at(input_to_cell_weights_index).shape().dim(0) &&
721               num_units == _ctx.at(input_to_output_weights_index).shape().dim(0) &&
722               num_units == _ctx.at(recurrent_to_forget_weights_index).shape().dim(0) &&
723               num_units == _ctx.at(recurrent_to_cell_weights_index).shape().dim(0) &&
724               num_units == _ctx.at(recurrent_to_output_weights_index).shape().dim(0) &&
725               num_units == _ctx.at(forget_gate_bias_index).shape().dim(0) &&
726               num_units == _ctx.at(cell_bias_index).shape().dim(0) &&
727               num_units == _ctx.at(output_gate_bias_index).shape().dim(0) &&
728               num_units == _ctx.at(cell_state_in_index).shape().dim(1));
729
730   const auto output_size =
731       _ctx.at(output_index).shape().dim(_ctx.at(output_index).shape().rank() - 1);
732   OP_REQUIRES(output_size == _ctx.at(recurrent_to_forget_weights_index).shape().dim(1) &&
733               output_size == _ctx.at(recurrent_to_cell_weights_index).shape().dim(1) &&
734               output_size == _ctx.at(recurrent_to_output_weights_index).shape().dim(1) &&
735               output_size == _ctx.at(output_state_in_index).shape().dim(1));
736
737   if (has_cifg_param)
738   {
739     OP_REQUIRES(input_size == _ctx.at(input_to_input_weights_index).shape().dim(1));
740     OP_REQUIRES(num_units == _ctx.at(input_to_input_weights_index).shape().dim(0) &&
741                 num_units == _ctx.at(recurrent_to_input_weights_index).shape().dim(0) &&
742                 ((_ctx.exist(cell_to_input_weights_index) &&
743                   num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0)) ||
744                  (!_ctx.exist(cell_to_input_weights_index) ||
745                   _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) &&
746                 num_units == _ctx.at(input_gate_bias_index).shape().dim(0));
747     OP_REQUIRES(output_size == _ctx.at(recurrent_to_input_weights_index).shape().dim(1));
748     OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
749                 has_input_gate_bias);
750     if (has_cell_to_input_weights)
751     {
752       // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
753       OP_REQUIRES(has_peephole_param);
754     }
755     if (_ctx.exist(scratch_buffer_index))
756       OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
757   }
758   else
759   {
760     if (_ctx.exist(scratch_buffer_index))
761       OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
762   }
763
764   if (has_peephole_param)
765   {
766     OP_REQUIRES(num_units == _ctx.at(cell_to_forget_weights_index).shape().dim(0) &&
767                 num_units == _ctx.at(cell_to_output_weights_index).shape().dim(0) &&
768                 (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) ||
769                  _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
770   }
771
772   if (has_projection_param)
773   {
774     OP_REQUIRES(num_units == _ctx.at(projection_weights_index).shape().dim(1));
775     OP_REQUIRES(output_size == _ctx.at(projection_weights_index).shape().dim(0));
776     if (has_projection_bias)
777     {
778       OP_REQUIRES(output_size == _ctx.at(projection_bias_index).shape().dim(0));
779     }
780   }
781
782   if (_ctx.exist(scratch_buffer_index))
783   {
784     OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().rank() == 2);
785     OP_REQUIRES(batch_size == _ctx.at(scratch_buffer_index).shape().dim(0));
786   }
787
788   if (_ctx.exist(output_state_out_index))
789   {
790     OP_REQUIRES(_ctx.at(output_state_out_index).shape().rank() == 2);
791     OP_REQUIRES(batch_size == _ctx.at(output_state_out_index).shape().dim(0));
792     OP_REQUIRES(output_size == _ctx.at(output_state_out_index).shape().dim(1));
793   }
794
795   if (_ctx.exist(cell_state_out_index))
796   {
797     OP_REQUIRES(_ctx.at(cell_state_out_index).shape().rank() == 2);
798     OP_REQUIRES(batch_size == _ctx.at(cell_state_out_index).shape().dim(0));
799     OP_REQUIRES(num_units == _ctx.at(cell_state_out_index).shape().dim(1));
800   }
801 }
802
803 void ShapeValidator::visit(const ir::operation::L2Normalization &node)
804 {
805   const auto ofm_index{node.getOutputs().at(0)};
806   if (_ctx.at(ofm_index).info().isDynamic())
807     return;
808
809   const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
810
811   auto ifm_shape = _ctx.at(ifm_index).shape();
812   auto ofm_shape = _ctx.at(ofm_index).shape();
813
814   OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
815
816   for (auto i = 0; i < ifm_shape.rank(); i++)
817   {
818     OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
819   }
820 }
821
822 void ShapeValidator::visit(const ir::operation::Unpack &node)
823 {
824   const auto axis{node.param().axis};
825   const auto output_index{node.getInputs().at(0)};
826   if (_ctx.at(output_index).info().isDynamic())
827     return;
828
829   const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
830
831   const auto &input_shape = _ctx.at(input_index).shape();
832   const auto input_rank = static_cast<int32_t>(input_shape.rank());
833
834   OP_REQUIRES(axis >= -input_rank && axis < input_rank);
835 }
836
837 void ShapeValidator::visit(const ir::operation::Pad &node)
838 {
839   const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
840   OP_REQUIRES(_ctx.at(pad_index).typeInfo().type() == ir::DataType::INT32);
841
842   const auto output_index{node.getInputs().at(0)};
843   if (_ctx.at(output_index).info().isDynamic())
844     return;
845
846   const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
847
848   const auto &pad_shape = _ctx.at(pad_index).shape();
849   const auto input_rank = static_cast<int32_t>(_ctx.at(input_index).shape().rank());
850
851   OP_REQUIRES(pad_shape.rank() == 2);
852   OP_REQUIRES(pad_shape.dim(0) == input_rank);
853   OP_REQUIRES(pad_shape.dim(1) == 2);
854   OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
855 }
856
857 void ShapeValidator::visit(const ir::operation::Select &)
858 {
859   // TODO Shape validation of select
860 }
861
862 void ShapeValidator::visit(const ir::operation::StridedSlice &node)
863 {
864   const auto output_index{node.getOutputs().at(0)};
865   const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
866
867   if (_ctx.at(output_index).info().isDynamic())
868     return;
869
870   OP_REQUIRES(_ctx.at(input_index).shape().rank() <= 4);
871 }
872
873 void ShapeValidator::visit(const ir::operation::Split &node)
874 {
875   const auto output_index{node.getOutputs().at(0)};
876   if (_ctx.at(output_index).info().isDynamic())
877     return;
878
879   const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
880   const auto axis_index{node.getInputs().at(ir::operation::Split::Input::AXIS)};
881
882   const auto num_splits = node.param().num_splits;
883   const auto input_rank = _ctx.at(input_index).shape().rank();
884   auto axis = *reinterpret_cast<const int32_t *>(_ctx.at(axis_index).data()->base());
885   axis = axis < 0 ? axis + input_rank : axis;
886
887   OP_REQUIRES(axis >= 0 && axis < input_rank);
888   OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0);
889 }
890
891 void ShapeValidator::visit(const ir::operation::Shape &node)
892 {
893   const auto output_index{node.getOutputs().at(0)};
894   if (_ctx.at(output_index).info().isDynamic())
895     return;
896
897   const auto input_index{node.getInputs().at(0)};
898   UNUSED_RELEASE(input_index);
899   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1);
900 }
901
902 void ShapeValidator::visit(const ir::operation::ResizeBilinear &node)
903 {
904   const auto output_index{node.getOutputs().at(0)};
905   const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
906
907   if (_ctx.at(output_index).info().isDynamic())
908   {
909     return;
910   }
911   OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
912   OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
913 }
914
915 void ShapeValidator::visit(const ir::operation::Reverse &node)
916 {
917   const auto output_index{node.getOutputs().at(0)};
918   const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
919
920   if (_ctx.at(output_index).info().isDynamic())
921     return;
922   OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
923 }
924
925 void ShapeValidator::visit(const ir::operation::If &)
926 {
927   // TODO Add to validate with subgraphs
928 }
929
930 void ShapeValidator::visit(const ir::operation::While &)
931 {
932   // This validator does not check shape. So checking isDynamic() is skipped.
933   // TODO Add to validate with subgraphs
934 }
935
936 void ShapeValidator::visit(const ir::operation::SquaredDifference &node)
937 {
938   const auto output_index{node.getOutputs().at(0)};
939   const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
940   const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
941
942   // Check for dimension constraints
943   if (_ctx.at(output_index).info().isDynamic())
944     return;
945
946   auto output_shape = _ctx.at(output_index).shape();
947   auto lhs_shape = _ctx.at(lhs_index).shape();
948   auto rhs_shape = _ctx.at(rhs_index).shape();
949   // Check for output rank
950   OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
951   auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());
952
953   for (int idx = 1; idx <= min_rank; idx++)
954   {
955     int l_idx = lhs_shape.rank() - idx;
956     int r_idx = rhs_shape.rank() - idx;
957     int out_idx = output_shape.rank() - idx;
958
959     OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));
960
961     auto l_dims = lhs_shape.dim(l_idx);
962     auto r_dims = rhs_shape.dim(r_idx);
963     auto out_dims = output_shape.dim(out_idx);
964
965     OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
966                 ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));
967   }
968   auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
969   for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
970   {
971     int out_idx = output_shape.rank() - idx;
972     int tmp_idx = tmp_shape.rank() - idx;
973
974     OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
975                 (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
976   }
977 }
978 void ShapeValidator::visit(const ir::operation::Tile &node)
979 {
980   const auto output_index{node.getOutputs().at(0)};
981   if (_ctx.at(output_index).info().isDynamic())
982     return;
983
984   const auto input_index{node.getInputs().at(0)};
985   const auto multiple_index{node.getInputs().at(1)};
986
987   OP_REQUIRES(_ctx.at(multiple_index).shape().rank() == 1);
988   OP_REQUIRES(_ctx.at(multiple_index).shape().dim(0) == _ctx.at(input_index).shape().rank());
989   OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
990 }
991
992 void ShapeValidator::visit(const ir::operation::Range &node)
993 {
994   const auto output_index{node.getOutputs().at(0)};
995   const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
996   const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
997   const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
998
999   // Check for dimension constraints
1000   if (_ctx.at(output_index).info().isDynamic())
1001     return;
1002
1003   OP_REQUIRES(_ctx.at(start_index).shape().rank() == 0);
1004   OP_REQUIRES(_ctx.at(limit_index).shape().rank() == 0);
1005   OP_REQUIRES(_ctx.at(delta_index).shape().rank() == 0);
1006 }
1007
1008 void ShapeValidator::visit(const ir::operation::MatrixBandPart &node)
1009 {
1010   const auto output_index{node.getOutputs().at(0)};
1011   const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)};
1012   const auto num_lower_index{
1013       node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)};
1014   const auto num_upper_index{
1015       node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)};
1016
1017   // Check for dimension constraints
1018   if (_ctx.at(output_index).info().isDynamic())
1019     return;
1020
1021   OP_REQUIRES(_ctx.at(input_index).shape().rank() >= 2);     // input must be more than 2 dim matrix
1022   OP_REQUIRES(_ctx.at(num_upper_index).shape().rank() == 0); // num_lower must be scalar
1023   OP_REQUIRES(_ctx.at(num_lower_index).shape().rank() == 0); // num_upper must be scalar
1024 }
1025
1026 void ShapeValidator::visit(const ir::operation::LogSoftmax &node)
1027 {
1028   const auto output_index{node.getOutputs().at(0)};
1029   if (_ctx.at(output_index).info().isDynamic())
1030     return;
1031
1032   const auto input_index{node.getInputs().at(0)};
1033
1034   OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
1035 }
1036
1037 } // namespace compiler
1038 } // namespace onert