Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ResolveCustomOpMaxPoolWithArgmaxPass.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
18
19 #include <loco/IR/DataTypeTraits.h>
20
21 #include <luci/IR/CircleNodes.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
23
24 #include <loco.h>
25 #include <oops/InternalExn.h>
26
27 #include <flatbuffers/flexbuffers.h>
28
29 namespace
30 {
31
32 template <typename T> std::vector<T> to_vector(const flexbuffers::TypedVector &typed_vec)
33 {
34   std::vector<T> answer(typed_vec.size());
35
36   for (uint32_t i = 0; i < answer.size(); ++i)
37   {
38     answer[i] = typed_vec[i].As<T>();
39   }
40
41   return answer;
42 }
43
44 luci::Padding string_to_padding(const std::string &pad_str)
45 {
46   if (pad_str == "VALID")
47     return luci::Padding::VALID;
48   if (pad_str == "SAME")
49     return luci::Padding::SAME;
50
51   return luci::Padding::UNDEFINED;
52 }
53
54 template <typename NodeT> void set_stride(NodeT *node, const luci::Stride &stride)
55 {
56   node->stride()->h(stride.h());
57   node->stride()->w(stride.w());
58 }
59
60 template <typename NodeT> void set_filter(NodeT *node, const luci::Filter &filter)
61 {
62   node->filter()->h(filter.h());
63   node->filter()->w(filter.w());
64 }
65
66 void init_name_and_origin(luci::CircleNode *node, const std::string &name,
67                           const std::shared_ptr<luci::CircleNodeOrigin> &origin)
68 {
69   node->name(name);
70   luci::add_origin(node, origin);
71 }
72
73 template <typename NodeT> NodeT *none_act_func(NodeT *node)
74 {
75   node->fusedActivationFunction(luci::FusedActFunc::NONE);
76   return node;
77 }
78
79 luci::CircleCast *create_cast(luci::CircleNode *input, loco::DataType in_type,
80                               loco::DataType out_type)
81 {
82   auto cast = input->graph()->nodes()->create<luci::CircleCast>();
83
84   cast->in_data_type(in_type);
85   cast->out_data_type(out_type);
86   cast->dtype(out_type);
87
88   cast->x(input);
89
90   return cast;
91 }
92
93 template <loco::DataType DT> void fill_conv_weights(luci::CircleConst *weights)
94 {
95   assert(weights->rank() == 4);
96
97   auto const kn = weights->dim(0).value();
98   auto const kh = weights->dim(1).value();
99   auto const kw = weights->dim(2).value();
100
101   auto elements_size = kn * kh * kw * 1;
102   weights->size<DT>(elements_size);
103
104   for (uint32_t b = 0; b < kn; ++b)
105   {
106     for (uint32_t y = 0; y < kh; ++y)
107     {
108       for (uint32_t x = 0; x < kw; ++x)
109       {
110         auto const idx = (b * kh + y) * kw + x;
111         weights->at<DT>(idx) = (y * kw + x == b) ? 1 : 0;
112       }
113     }
114   }
115 }
116
117 luci::CircleConst *create_conv_filter(loco::Graph *graph, const uint32_t kh, const uint32_t kw,
118                                       const uint32_t kn)
119 {
120   auto weights = graph->nodes()->create<luci::CircleConst>();
121
122   weights->dtype(loco::DataType::FLOAT32);
123
124   weights->rank(4);
125   weights->dim(0).set(kn);
126   weights->dim(1).set(kh);
127   weights->dim(2).set(kw);
128   weights->dim(3).set(1);
129   weights->shape_status(luci::ShapeStatus::VALID);
130
131   fill_conv_weights<loco::DataType::FLOAT32>(weights);
132
133   return weights;
134 }
135
136 template <loco::DataType DT> void fill_zero_bias(luci::CircleConst *bias)
137 {
138   assert(bias->rank() == 1);
139
140   auto const depth = bias->dim(0).value();
141
142   bias->size<DT>(depth);
143
144   for (uint32_t i = 0; i < depth; ++i)
145   {
146     bias->at<DT>(i) = 0;
147   }
148 }
149
150 luci::CircleConst *create_zero_bias(loco::Graph *graph, uint32_t depth)
151 {
152   auto bias = graph->nodes()->create<luci::CircleConst>();
153
154   bias->dtype(loco::DataType::FLOAT32);
155
156   bias->rank(1);
157   bias->dim(0).set(depth);
158
159   fill_zero_bias<loco::DataType::FLOAT32>(bias);
160
161   return bias;
162 }
163
164 luci::CircleConst *create_padding_const(loco::Graph *graph, int32_t left_pad, int32_t right_pad,
165                                         int32_t top_pad, int32_t bottom_pad)
166 {
167   auto paddings = graph->nodes()->create<luci::CircleConst>();
168
169   paddings->dtype(loco::DataType::S32);
170
171   paddings->rank(2);
172   paddings->dim(0).set(4);
173   paddings->dim(1).set(2);
174   paddings->size<loco::DataType::S32>(8);
175   paddings->shape_status(luci::ShapeStatus::VALID);
176
177   paddings->at<loco::DataType::S32>(0) = 0;
178   paddings->at<loco::DataType::S32>(1) = 0;
179
180   paddings->at<loco::DataType::S32>(2) = left_pad;
181   paddings->at<loco::DataType::S32>(3) = right_pad;
182
183   paddings->at<loco::DataType::S32>(4) = top_pad;
184   paddings->at<loco::DataType::S32>(5) = bottom_pad;
185
186   paddings->at<loco::DataType::S32>(6) = 0;
187   paddings->at<loco::DataType::S32>(7) = 0;
188
189   return paddings;
190 }
191
192 template <loco::DataType DT, typename Numeric>
193 luci::CircleConst *create_scalar(loco::Graph *graph, Numeric value)
194 {
195   auto scalar = graph->nodes()->create<luci::CircleConst>();
196
197   scalar->dtype(DT);
198
199   scalar->rank(0);
200   scalar->size<DT>(1);
201   scalar->shape_status(luci::ShapeStatus::VALID);
202
203   scalar->scalar<DT>() = value;
204
205   return scalar;
206 }
207
208 luci::CircleConst *create_shape_tensor(loco::Graph *graph, const std::vector<uint32_t> &dims_vec)
209 {
210   auto shape = graph->nodes()->create<luci::CircleConst>();
211
212   shape->dtype(loco::DataType::S32);
213
214   shape->rank(1);
215   shape->dim(0).set(dims_vec.size());
216   shape->shape_status(luci::ShapeStatus::VALID);
217
218   shape->size<loco::DataType::S32>(dims_vec.size());
219
220   for (uint32_t i = 0; i < dims_vec.size(); ++i)
221   {
222     shape->at<loco::DataType::S32>(i) = dims_vec[i];
223   }
224
225   return shape;
226 }
227
228 int32_t compute_full_padding(int32_t input_size, int32_t output_size, int32_t stride,
229                              int32_t filter_size)
230 {
231   int32_t effective_input = (output_size - 1) * stride + filter_size;
232   int32_t full = effective_input - input_size;
233   // some extreme cases when part of input was not used in computations
234   if (full < 0)
235     full = 0;
236   return full;
237 }
238
239 template <loco::DataType DT>
240 void fill_coords_addition(luci::Padding padding, const luci::Stride &stride,
241                           const luci::Filter &filter, uint32_t input_height, uint32_t input_width,
242                           uint32_t depth, luci::CircleConst *cords)
243 {
244   assert(cords->rank() == 4);
245
246   auto const output_height = static_cast<int32_t>(cords->dim(1).value());
247   auto const output_width = static_cast<int32_t>(cords->dim(2).value());
248   {
249     auto const element_counts = 1 * output_height * output_width * 1;
250     cords->size<DT>(element_counts);
251   }
252
253   assert(padding != luci::Padding::UNDEFINED);
254
255   // For VALID padding:
256   int32_t start_y = 0;
257   int32_t start_x = 0;
258
259   // For SAME padding:
260   if (padding == luci::Padding::SAME)
261   {
262     start_y = -compute_full_padding(input_height, output_height, stride.h(), filter.h()) / 2;
263     start_x = -compute_full_padding(input_width, output_width, stride.w(), filter.w()) / 2;
264   }
265
266   auto const step_y = static_cast<int32_t>(stride.h());
267   auto const step_x = static_cast<int32_t>(stride.w());
268
269   for (int32_t y_o = 0, y_i = start_y; y_o < output_height; ++y_o, y_i += step_y)
270   {
271     for (int32_t x_o = 0, x_i = start_x; x_o < output_width; ++x_o, x_i += step_x)
272     {
273       auto const output_idx = y_o * output_width + x_o;
274       auto const input_idx = y_i * static_cast<int32_t>(input_width) + x_i;
275
276       // Add small adjustment value to fix cast operation result that follows "coord addition"
277       // in generated subgraph.
278       //
279       // Cast operation discards fractional part of value, so 1.9996 will be transformed to 1
280       // This is not a problem when working with float32, because it represents integers precisely,
281       // but leads to wrong results, when working with quantized numbers.
282       //
283       // This value is larger than quantization error,
284       // and small enough to not affect following computations
285       // (in particular multiplication with depth)
286       const float round_adjustment = 1.0f / (depth + 1);
287
288       cords->at<DT>(output_idx) = input_idx + round_adjustment;
289     }
290   }
291 }
292
293 luci::CircleConst *create_coords_addition(loco::Graph *graph, luci::Padding padding,
294                                           const luci::Stride &stride, const luci::Filter &filter,
295                                           uint32_t input_height, uint32_t input_width,
296                                           uint32_t depth, uint32_t output_height,
297                                           uint32_t output_width)
298 {
299   auto cords = graph->nodes()->create<luci::CircleConst>();
300
301   cords->dtype(loco::DataType::FLOAT32);
302
303   cords->rank(4);
304   cords->dim(0).set(1);
305   cords->dim(1).set(output_height);
306   cords->dim(2).set(output_width);
307   cords->dim(3).set(1);
308
309   fill_coords_addition<loco::DataType::FLOAT32>(padding, stride, filter, input_height, input_width,
310                                                 depth, cords);
311
312   return cords;
313 }
314
315 luci::CircleNode *get_custom_output(const luci::CircleCustom *cop, int32_t idx)
316 {
317   auto const outputs = loco::succs(cop);
318   assert(outputs.size() == 2);
319
320   auto output = loco::must_cast<luci::CircleCustomOut *>(*outputs.begin());
321   if (output->index() != idx)
322   {
323     output = loco::must_cast<luci::CircleCustomOut *>(*outputs.rbegin());
324   }
325
326   return output;
327 }
328
329 luci::CircleNode *max_pool_branch(luci::Padding padding, const luci::Stride &stride,
330                                   const luci::Filter filter, luci::CircleCustom *cop)
331 {
332   auto graph = cop->graph();
333   auto input = cop->inputs(0);
334
335   auto origin = luci::get_origin(cop);
336   auto name = cop->name() + "/Argmax";
337
338   // Create MaxPool
339   auto maxpool = none_act_func(graph->nodes()->create<luci::CircleMaxPool2D>());
340   {
341     init_name_and_origin(maxpool, name + "/MaxPool2D", origin);
342
343     set_stride(maxpool, stride);
344     set_filter(maxpool, filter);
345     maxpool->padding(padding);
346
347     maxpool->value(input);
348   }
349
350   return maxpool;
351 }
352
353 luci::CircleNode *window_flattened_coord(const std::string &name, luci::Padding padding,
354                                          const luci::Stride &stride, const luci::Filter filter,
355                                          int32_t input_height, int32_t input_width,
356                                          uint32_t output_height, uint32_t output_width,
357                                          luci::CircleNode *input)
358 {
359   auto const graph = input->graph();
360   auto const origin = luci::get_origin(input);
361
362   auto const depth_dimension = 3;
363
364   // Create pad in case of SAME padding
365   luci::CircleNode *conv_input = input;
366   if (padding == luci::Padding::SAME)
367   {
368     // Create redundant add to combine two nodes with special quantization restrictions:
369     // PadV2 and Split in this case
370     // TODO Introduce special requantize node and fix quantizer?
371     auto requantize = none_act_func(graph->nodes()->create<luci::CircleMul>());
372     init_name_and_origin(requantize, name + "/Requantize", origin);
373     auto zero_const = create_scalar<loco::DataType::FLOAT32>(graph, 1.0f);
374     init_name_and_origin(zero_const, name + "Requantize_const", origin);
375
376     requantize->x(input);
377     requantize->y(zero_const);
378
379     auto pad = graph->nodes()->create<luci::CirclePadV2>();
380     init_name_and_origin(pad, name + "/Pad", origin);
381
382     pad->input(requantize);
383
384     int32_t full_w_pad = compute_full_padding(input_width, output_width, stride.w(), filter.w());
385     int32_t full_h_pad = compute_full_padding(input_height, output_height, stride.h(), filter.h());
386     int32_t left_pad = full_w_pad / 2;
387     int32_t right_pad = full_w_pad - left_pad;
388     int32_t top_pad = full_h_pad / 2;
389     int32_t bottom_pad = full_h_pad - top_pad;
390     auto padding_const = create_padding_const(graph, left_pad, right_pad, top_pad, bottom_pad);
391     init_name_and_origin(padding_const, name + "/Pad_shape", origin);
392     pad->paddings(padding_const);
393
394     auto padding_value =
395       create_scalar<loco::DataType::FLOAT32, float>(graph, std::numeric_limits<float>::lowest());
396     init_name_and_origin(padding_value, name + "/Pad_value", origin);
397     pad->constant_values(padding_value);
398
399     conv_input = pad;
400   }
401   // Create Conv2D to move spatial dimensions to depth
402   auto conv = none_act_func(graph->nodes()->create<luci::CircleConv2D>());
403   {
404     init_name_and_origin(conv, name + "/Conv2D", origin);
405
406     // Padding, Stride and kernel size equal to MaxPool's
407     set_stride(conv, stride);
408     conv->padding(luci::Padding::VALID);
409
410     // depth of kernel is equal to square size
411     auto const kh = filter.h();
412     auto const kw = filter.w();
413     auto const kd = kh * kw;
414
415     // use zero bias
416     auto bias = create_zero_bias(graph, kd);
417     init_name_and_origin(bias, conv->name() + "/Bias", origin);
418
419     // create filter
420     // TODO make shared
421     auto weights = create_conv_filter(graph, kh, kw, kd);
422     init_name_and_origin(weights, conv->name() + "/Weights", origin);
423
424     conv->bias(bias);
425     conv->filter(weights);
426     conv->input(conv_input);
427   }
428
429   // Create ArgMax
430   auto argmax = graph->nodes()->create<luci::CircleArgMax>();
431   {
432     init_name_and_origin(argmax, name + "/ArgMax", origin);
433
434     argmax->output_type(loco::DataType::S32);
435
436     // Create argmax_dim
437     auto argmax_dim = create_scalar<loco::DataType::S32>(graph, depth_dimension);
438     init_name_and_origin(argmax_dim, argmax->name() + "/Dimension", origin);
439
440     argmax->dimension(argmax_dim);
441     argmax->input(conv);
442   }
443
444   // Create Reshape to 4-rank back, because argmax decrease rank of tensor by 1
445   auto reshape = graph->nodes()->create<luci::CircleReshape>();
446   {
447     init_name_and_origin(reshape, name + "/Reshape", origin);
448
449     auto shape = create_shape_tensor(graph, {1, output_height, output_width, 1});
450     init_name_and_origin(shape, reshape->name() + "/Shape", origin);
451
452     reshape->tensor(argmax);
453     reshape->shape(shape);
454   }
455
456   // Create Cast to use float32 instead int32
457   auto argmax_cast = create_cast(reshape, loco::DataType::S32, loco::DataType::FLOAT32);
458   init_name_and_origin(argmax_cast, argmax->name() + "/Cast", origin);
459
460   return argmax_cast;
461 }
462
463 // Creates "identity operation" after Floor
464 // to force circle-quantizer requantize output tensor with scale << 1.
465 //
466 // Dealing with values of extremely different scales
467 // in following binary operations hurts backend precision.
468 luci::CircleNode *create_post_floor_requantize_node(luci::CircleFloor *floor)
469 {
470   auto graph = floor->graph();
471   auto const origin = luci::get_origin(floor);
472   auto name = floor->name();
473
474   // Use DepthwiseConv2D with identity filter as an "identity operation".
475   //
476   // This operation do not change values, but forces circle-quantizer to use
477   // statistics to compute qparam scale instead of fixed scale == 1.0 after floor.
478   // DepthwiseConv2d is not eliminated by optimizations,
479   // so desired scale will reach backend.
480   auto requantizer = none_act_func(graph->nodes()->create<luci::CircleDepthwiseConv2D>());
481   init_name_and_origin(requantizer, name + "/Requantizer", origin);
482
483   requantizer->input(floor);
484
485   auto requantizer_filter = create_scalar<loco::DataType::FLOAT32>(graph, 1.0f);
486   init_name_and_origin(requantizer_filter, name + "/Requantizer/filter", origin);
487   requantizer_filter->rank(4);
488   for (uint32_t i = 0; i < 4; ++i)
489   {
490     requantizer_filter->dim(i) = 1;
491   }
492   requantizer->filter(requantizer_filter);
493
494   auto requantizer_bias = create_zero_bias(graph, 1);
495   init_name_and_origin(requantizer_bias, name + "/Requantizer/bias", origin);
496   requantizer->bias(requantizer_bias);
497
498   requantizer->padding(luci::Padding::VALID);
499   requantizer->stride()->w(1);
500   requantizer->stride()->h(1);
501   requantizer->depthMultiplier(1);
502   requantizer->dilation()->w(1);
503   requantizer->dilation()->h(1);
504
505   return requantizer;
506 }
507
508 luci::CircleNode *window_y_coord(const std::string &name, const luci::Filter &filter,
509                                  luci::CircleNode *flattened)
510 {
511   auto const graph = flattened->graph();
512   auto const origin = luci::get_origin(flattened);
513
514   auto div = none_act_func(graph->nodes()->create<luci::CircleMul>());
515   {
516     init_name_and_origin(div, name + "/Div", origin);
517
518     // Adjustment_coeff is needed to fix computation of quantized tensors
519     //
520     // For example float32 value 2.0 could be quantized to 1.996
521     // after floor it will be transformed to 1.0, but desired answer is still something close to 2.0
522     //
523     // rounding_adjustment is chosen so it is small enough to not affect float32 computations,
524     // but "Div" change is larger then potential quantization error.
525     //
526     // This computation exploits the fact that div is an x coord in maxpool window,
527     // and lies in defined range [0, filter.h())
528     const float rounding_adjustment = 1.0f / (filter.w() * filter.h());
529     const float divider_value = filter.w() - rounding_adjustment;
530     auto divider = create_scalar<loco::DataType::FLOAT32>(graph, 1.0f / divider_value);
531     init_name_and_origin(divider, div->name() + "/Divider", origin);
532
533     div->x(flattened);
534     div->y(divider);
535   }
536
537   auto floor = graph->nodes()->create<luci::CircleFloor>();
538   {
539     init_name_and_origin(floor, name + "/Floor", origin);
540     floor->x(div);
541   }
542
543   auto requantizer = create_post_floor_requantize_node(floor);
544
545   return requantizer;
546 }
547
548 luci::CircleNode *window_x_coord(const std::string &name, float filter_width,
549                                  luci::CircleNode *flattened, luci::CircleNode *y_coord)
550 {
551   auto const graph = flattened->graph();
552   auto const origin = luci::get_origin(flattened);
553
554   auto mod = none_act_func(graph->nodes()->create<luci::CircleAdd>());
555   {
556     init_name_and_origin(mod, name + "/Mod", origin);
557
558     auto neg = graph->nodes()->create<luci::CircleNeg>();
559     {
560       init_name_and_origin(neg, mod->name() + "/Neg", origin);
561
562       auto mul = none_act_func(graph->nodes()->create<luci::CircleMul>());
563       {
564         init_name_and_origin(mul, neg->name() + "/Neg", origin);
565
566         auto multipler = create_scalar<loco::DataType::FLOAT32>(graph, filter_width);
567         init_name_and_origin(multipler, mul->name() + "/Multipler", origin);
568
569         mul->x(y_coord);
570         mul->y(multipler);
571       }
572
573       neg->x(mul);
574     }
575
576     mod->x(flattened);
577     mod->y(neg);
578   }
579
580   return mod;
581 }
582
583 luci::CircleNode *plane_flattened_coord(const std::string &name, uint32_t input_width,
584                                         luci::CircleNode *y_coord, luci::CircleNode *x_coord,
585                                         luci::CircleNode *corners)
586 {
587   auto const graph = corners->graph();
588   auto const origin = luci::get_origin(corners);
589
590   auto add = none_act_func(graph->nodes()->create<luci::CircleAdd>());
591   {
592     init_name_and_origin(add, name + "/Add", origin);
593
594     auto addition = none_act_func(graph->nodes()->create<luci::CircleAdd>());
595     {
596       init_name_and_origin(addition, add->name() + "/Add", origin);
597
598       auto y_addition = none_act_func(graph->nodes()->create<luci::CircleMul>());
599       {
600         init_name_and_origin(y_addition, addition->name() + "/Mul", origin);
601
602         auto width_scalar = create_scalar<loco::DataType::FLOAT32>(graph, input_width);
603         init_name_and_origin(width_scalar, y_addition->name() + "/Const", origin);
604
605         y_addition->x(y_coord);
606         y_addition->y(width_scalar);
607       }
608
609       addition->x(x_coord);
610       addition->y(y_addition);
611     }
612
613     add->x(addition);
614     add->y(corners);
615   }
616
617   return add;
618 }
619
620 luci::CircleNode *volume_flattened_coords(const std::string &name, uint32_t channel,
621                                           uint32_t input_depth, luci::CircleNode *plane)
622 {
623   auto const graph = plane->graph();
624   auto const origin = luci::get_origin(plane);
625
626   // Create Mul
627   auto mul = none_act_func(graph->nodes()->create<luci::CircleMul>());
628   {
629     init_name_and_origin(mul, name + "/Mul", origin);
630
631     auto depth_scalar = create_scalar<loco::DataType::FLOAT32>(graph, input_depth);
632     init_name_and_origin(depth_scalar, mul->name() + "/Const", origin);
633
634     mul->x(plane);
635     mul->y(depth_scalar);
636   }
637
638   luci::CircleNode *volume = mul;
639
640   // Add channel number to output
641   if (channel > 0)
642   {
643     // Create Add
644     auto add_ch = none_act_func(graph->nodes()->create<luci::CircleAdd>());
645     init_name_and_origin(add_ch, name + "/Add_Channel", origin);
646
647     auto channel_scalar = create_scalar<loco::DataType::FLOAT32>(graph, channel);
648     init_name_and_origin(channel_scalar, add_ch->name() + "/Const", origin);
649
650     add_ch->x(mul);
651     add_ch->y(channel_scalar);
652
653     volume = add_ch;
654   }
655
656   return volume;
657 }
658
659 luci::CircleNode *argmax_branch(luci::Padding padding, const luci::Stride &stride,
660                                 const luci::Filter filter, luci::CircleCustom *cop)
661 {
662   auto graph = cop->graph();
663   auto input = loco::must_cast<luci::CircleNode *>(cop->inputs(0));
664   auto output = get_custom_output(cop, 1);
665
666   auto const depth_dimension = 3;
667   auto const input_depth = input->dim(depth_dimension).value();
668   auto const input_height = input->dim(1).value();
669   auto const input_width = input->dim(2).value();
670
671   assert(output->rank() == 4);
672   auto const output_height = output->dim(1).value();
673   auto const output_width = output->dim(2).value();
674
675   auto origin = luci::get_origin(cop);
676   auto name = cop->name() + "/Argmax";
677
678   // Create Split
679   auto split = graph->nodes()->create<luci::CircleSplit>();
680   {
681     init_name_and_origin(split, name + "/Split", origin);
682
683     // Create split_dim
684     auto split_dim = create_scalar<loco::DataType::S32>(graph, depth_dimension);
685     init_name_and_origin(split_dim, split->name() + "/Dim", origin);
686
687     split->num_split(int32_t(input_depth));
688
689     split->split_dim(split_dim);
690     split->input(input);
691   }
692
693   /**
694    * Note: we need define idx from input_tensor of maximum element in MaxPool's sliding window.
695    * For this we split input tensor by channels, define idx in sliding window and convert this idx
696    * to idx from source input_tensor using FloorDiv, Mul and Add operations with constant tensors.
697    */
698   std::vector<luci::CircleNode *> branch_outputs(input_depth);
699
700   for (uint32_t br_n = 0; br_n < input_depth; ++br_n)
701   {
702     auto const branch_name = name + "/depth_" + std::to_string(br_n);
703
704     // Create CircleSplitOut
705     auto split_out = graph->nodes()->create<luci::CircleSplitOut>();
706     init_name_and_origin(split_out, branch_name + "/SplitOut", origin);
707     split_out->index(int32_t(br_n));
708     split_out->input(split);
709
710     // Define idx of max element in Window:
711     auto window_coords =
712       window_flattened_coord(branch_name + "/WindowFlat", padding, stride, filter, input_height,
713                              input_width, output_height, output_width, split_out);
714
715     auto const window_y = window_y_coord(branch_name + "/WindowY", filter, window_coords);
716     auto const window_x =
717       window_x_coord(branch_name + "/WindowX", filter.w(), window_coords, window_y);
718
719     // Define idx of max element in Plane
720     // This tensor contains coords of left top corners for each window from input tensor
721     auto corners = create_coords_addition(graph, padding, stride, filter, input_height, input_width,
722                                           input_depth, output_height, output_width);
723     init_name_and_origin(corners, branch_name + "/Const", origin);
724
725     auto plane_coord =
726       plane_flattened_coord(branch_name + "/PlaneFlat", input_width, window_y, window_x, corners);
727
728     // Define volume coords as final value
729     branch_outputs[br_n] =
730       volume_flattened_coords(branch_name + "/VolumeFlat", br_n, input_depth, plane_coord);
731   }
732
733   // Create Concatenation
734   auto concat = none_act_func(graph->nodes()->create<luci::CircleConcatenation>(input_depth));
735   {
736     init_name_and_origin(concat, name + "/Concatenation", origin);
737     concat->axis(depth_dimension);
738
739     for (uint32_t i = 0; i < input_depth; ++i)
740     {
741       concat->values(i, branch_outputs[i]);
742     }
743   }
744
745   // Output of argmax_with_maxpool should be S64 or S32
746   loco::DataType output_dtype = get_custom_output(cop, 1)->dtype();
747   auto output_cast = create_cast(concat, loco::DataType::FLOAT32, output_dtype);
748   init_name_and_origin(output_cast, name + "/Cast", origin);
749
750   return output_cast;
751 }
752
753 bool resolve_max_pool_with_argmax(luci::CircleCustom *cop)
754 {
755 #define CHECK_OR_FALSE(condition) \
756   if (not(condition))             \
757     return false;
758
759   const std::vector<uint8_t> custom_options = cop->custom_options();
760   auto map = flexbuffers::GetRoot(custom_options).AsMap();
761
762   // Define params
763   // Note: Only `Targmax` equal to DT_INT64 is supported by tflite converter
764   // Note: Only `data_format` equal to "NHWC" is supported by tflite converter
765   // TODO add support of `include_batch_in_index` param
766   auto ksize_param = to_vector<uint32_t>(map["ksize"].AsTypedVector());
767   auto strides_param = to_vector<uint32_t>(map["strides"].AsTypedVector());
768   auto padding_param = map["padding"].As<std::string>();
769
770   // Batch size and depth of ksize more than 1 is not supported.
771   CHECK_OR_FALSE(ksize_param.size() == 4);
772   CHECK_OR_FALSE(ksize_param[0] == 1 && ksize_param[3] == 1);
773
774   CHECK_OR_FALSE(strides_param.size() == 4);
775   CHECK_OR_FALSE(strides_param[0] == 1 && strides_param[3] == 1);
776
777   // define Padding
778   auto padding = string_to_padding(padding_param);
779
780   // define Filter
781   luci::Filter filter;
782   filter.h(ksize_param[1]);
783   filter.w(ksize_param[2]);
784
785   // define Stride
786   luci::Stride stride;
787   stride.h(strides_param[1]);
788   stride.w(strides_param[2]);
789
790   // input node
791   auto const input = loco::must_cast<luci::CircleNode *>(cop->inputs(0));
792   CHECK_OR_FALSE(input->dtype() == loco::DataType::FLOAT32);
793   CHECK_OR_FALSE(input->rank() == 4);
794
795   // TODO support batch size > 1 and `include_batch_in_index` option
796   CHECK_OR_FALSE(input->dim(0).value() == 1);
797
798   // output nodes
799   auto const outputs = loco::succs(cop);
800   CHECK_OR_FALSE(outputs.size() == 2);
801   assert(outputs.size() == cop->numOutputs());
802
803   auto output0 = get_custom_output(cop, 0);
804   auto output1 = get_custom_output(cop, 1);
805
806   // From TF documentation: output of maxpool must has same type as input
807   assert(output0->dtype() == input->dtype());
808   assert(output1->dtype() == loco::DataType::S64 || output1->dtype() == loco::DataType::S32);
809
810   // Create MaxPool
811   auto maxpool = max_pool_branch(padding, stride, filter, cop);
812   auto argmax = argmax_branch(padding, stride, filter, cop);
813
814   // last argmax branch op is cast, it should have dtype initialized
815   assert(argmax->dtype() == output1->dtype());
816
817   // replace old node with new subgraph
818   cop->inputs(0, nullptr);
819   loco::replace(output0).with(maxpool);
820   loco::replace(output1).with(argmax);
821
822   return true;
823 }
824
825 } // namespace
826
827 namespace luci
828 {
829
830 /**
831  * BEFORE
832  *                 |
833  *            [CircleNode]
834  *                 |
835  *     [CUSTOM(MaxPoolWithArgmax)]
836  *         |              |
837  *  [MaxPool output]  [Argmax output]
838  *
839  * AFTER
840  *                         |
841  *                    [CircleNode]
842  *                    /          \
843  *       [Split over channels]  [MaxPool2D]
844  *         /       |      \              \
845  *   [Requantize] ...     ...      [MaxPool output]
846  *         |
847  *      [PadV2]
848  *         |
849  *      [Conv2D]
850  *         |
851  *      [ArgMax]
852  *         |
853  *    [Reshape to 4d]
854  *         |
855  *  [Cast to float32]
856  *    /        |
857  *   |  [Mul 1/<window width>]
858  *   |                \
859  *   |              [Floor]
860  *   |                 |
861  *   |    [DepthwiseConv2D for requantize]
862  *   |              /     \
863  *   | [Mul window width] |
864  *   \       /           /
865  *    \   [Neg] [Mul input width]
866  *     \   /    /
867  *     [Add]   /
868  *         \  /
869  *        [Add]
870  *          |
871  *     [Add const]
872  *           |
873  * [Mul number of channels]
874  *             \
875  * [Optional Add with channels id]   ...  ...
876  *                            \      |     /
877  *                           [Concatenation]
878  *                                 |
879  *                           [Cast to int]
880  *                                 |
881  *                          [Argmax output]
882  */
883 bool ResolveCustomOpMaxPoolWithArgmaxPass::run(loco::Graph *g)
884 {
885   bool changed = false;
886   for (auto node : loco::active_nodes(loco::output_nodes(g)))
887   {
888     auto cop = dynamic_cast<luci::CircleCustom *>(node);
889     if (not cop)
890       continue;
891
892     if (cop->custom_code() != "MaxPoolWithArgmax")
893       continue;
894
895     if (!resolve_max_pool_with_argmax(cop))
896       continue;
897
898     changed = true;
899   }
900
901   return changed;
902 }
903
904 } // namespace luci