2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
19 #include <loco/IR/DataTypeTraits.h>
21 #include <luci/IR/CircleNodes.h>
22 #include <luci/Profile/CircleNodeOrigin.h>
25 #include <oops/InternalExn.h>
27 #include <flatbuffers/flexbuffers.h>
// Copies the elements of a flexbuffers typed vector into a std::vector<T>.
//
// The destination is pre-sized to typed_vec.size() and each element is
// converted with As<T>().
// NOTE(review): the statement that hands `answer` back to the caller is not
// visible in this view; presumably the filled vector is returned - confirm.
32 template <typename T> std::vector<T> to_vector(const flexbuffers::TypedVector &typed_vec)
34 std::vector<T> answer(typed_vec.size());
36 for (uint32_t i = 0; i < answer.size(); ++i)
38 answer[i] = typed_vec[i].As<T>();
44 luci::Padding string_to_padding(const std::string &pad_str)
46 if (pad_str == "VALID")
47 return luci::Padding::VALID;
48 if (pad_str == "SAME")
49 return luci::Padding::SAME;
51 return luci::Padding::UNDEFINED;
54 template <typename NodeT> void set_stride(NodeT *node, const luci::Stride &stride)
56 node->stride()->h(stride.h());
57 node->stride()->w(stride.w());
60 template <typename NodeT> void set_filter(NodeT *node, const luci::Filter &filter)
62 node->filter()->h(filter.h());
63 node->filter()->w(filter.w());
// Attaches bookkeeping information to a node created by this pass.
//
// Visible here: `origin` is registered on `node` through luci::add_origin.
// NOTE(review): `name` is not used in the lines shown; presumably the node name
// is assigned from it in a statement outside this view - confirm.
66 void init_name_and_origin(luci::CircleNode *node, const std::string &name,
67 const std::shared_ptr<luci::CircleNodeOrigin> &origin)
70 luci::add_origin(node, origin);
// Sets the fused activation function of `node` to luci::FusedActFunc::NONE.
//
// Declared to return NodeT*; call sites in this file wrap node creation with
// it and keep using the result as the node.
// NOTE(review): the return statement is not visible in this view; presumably
// `node` itself is returned - confirm.
73 template <typename NodeT> NodeT *none_act_func(NodeT *node)
75 node->fusedActivationFunction(luci::FusedActFunc::NONE);
// Creates a CircleCast node in the graph that owns `input`.
//
// The cast is configured with `in_type` as its input data type and `out_type`
// as both its output data type and its node dtype.
// NOTE(review): connecting `input` to the cast and returning the new node are
// not visible in this view; presumably both happen - confirm.
79 luci::CircleCast *create_cast(luci::CircleNode *input, loco::DataType in_type,
80 loco::DataType out_type)
82 auto cast = input->graph()->nodes()->create<luci::CircleCast>();
84 cast->in_data_type(in_type);
85 cast->out_data_type(out_type);
86 cast->dtype(out_type);
93 template <loco::DataType DT> void fill_conv_weights(luci::CircleConst *weights)
95 assert(weights->rank() == 4);
97 auto const kn = weights->dim(0).value();
98 auto const kh = weights->dim(1).value();
99 auto const kw = weights->dim(2).value();
101 auto elements_size = kn * kh * kw * 1;
102 weights->size<DT>(elements_size);
104 for (uint32_t b = 0; b < kn; ++b)
106 for (uint32_t y = 0; y < kh; ++y)
108 for (uint32_t x = 0; x < kw; ++x)
110 auto const idx = (b * kh + y) * kw + x;
111 weights->at<DT>(idx) = (y * kw + x == b) ? 1 : 0;
// Creates the FLOAT32 filter constant used by the helper Conv2D.
//
// The constant gets dimensions [kn, kh, kw, 1], a VALID shape status and its
// contents are produced by fill_conv_weights.
// NOTE(review): the end of the signature (where `kn` must be declared), the
// rank assignment and the return statement are not visible in this view.
117 luci::CircleConst *create_conv_filter(loco::Graph *graph, const uint32_t kh, const uint32_t kw,
120 auto weights = graph->nodes()->create<luci::CircleConst>();
122 weights->dtype(loco::DataType::FLOAT32);
125 weights->dim(0).set(kn);
126 weights->dim(1).set(kh);
127 weights->dim(2).set(kw);
128 weights->dim(3).set(1);
129 weights->shape_status(luci::ShapeStatus::VALID);
// fill_conv_weights also allocates the element buffer (it calls size<DT>()).
131 fill_conv_weights<loco::DataType::FLOAT32>(weights);
// Allocates the element buffer of a rank-1 bias constant and walks over it.
//
// The buffer is sized to dim(0) elements.
// NOTE(review): the loop body is not visible in this view; judging by the
// function name each element is presumably set to zero - confirm.
136 template <loco::DataType DT> void fill_zero_bias(luci::CircleConst *bias)
138 assert(bias->rank() == 1);
140 auto const depth = bias->dim(0).value();
142 bias->size<DT>(depth);
144 for (uint32_t i = 0; i < depth; ++i)
// Creates a FLOAT32 bias constant with `depth` elements in its first
// dimension and fills it through fill_zero_bias.
// NOTE(review): the rank assignment (fill_zero_bias asserts rank == 1), any
// shape-status update and the return statement are not visible in this view.
150 luci::CircleConst *create_zero_bias(loco::Graph *graph, uint32_t depth)
152 auto bias = graph->nodes()->create<luci::CircleConst>();
154 bias->dtype(loco::DataType::FLOAT32);
157 bias->dim(0).set(depth);
159 fill_zero_bias<loco::DataType::FLOAT32>(bias);
164 luci::CircleConst *create_padding_const(loco::Graph *graph, int32_t left_pad, int32_t right_pad,
165 int32_t top_pad, int32_t bottom_pad)
167 auto paddings = graph->nodes()->create<luci::CircleConst>();
169 paddings->dtype(loco::DataType::S32);
172 paddings->dim(0).set(4);
173 paddings->dim(1).set(2);
174 paddings->size<loco::DataType::S32>(8);
175 paddings->shape_status(luci::ShapeStatus::VALID);
177 paddings->at<loco::DataType::S32>(0) = 0;
178 paddings->at<loco::DataType::S32>(1) = 0;
180 paddings->at<loco::DataType::S32>(2) = left_pad;
181 paddings->at<loco::DataType::S32>(3) = right_pad;
183 paddings->at<loco::DataType::S32>(4) = top_pad;
184 paddings->at<loco::DataType::S32>(5) = bottom_pad;
186 paddings->at<loco::DataType::S32>(6) = 0;
187 paddings->at<loco::DataType::S32>(7) = 0;
// Creates a constant node holding the single value `value`, stored with data
// type DT, and marks its shape status as VALID.
// NOTE(review): the dtype/rank/size set-up that must precede scalar<DT>() and
// the return statement are not visible in this view.
192 template <loco::DataType DT, typename Numeric>
193 luci::CircleConst *create_scalar(loco::Graph *graph, Numeric value)
195 auto scalar = graph->nodes()->create<luci::CircleConst>();
201 scalar->shape_status(luci::ShapeStatus::VALID);
// `value` is implicitly converted from Numeric to the native type of DT here.
203 scalar->scalar<DT>() = value;
// Creates an S32 constant whose elements are the entries of `dims_vec`, in
// order. The first dimension and the element count both equal dims_vec.size().
// NOTE(review): the rank assignment and the return statement are not visible
// in this view.
208 luci::CircleConst *create_shape_tensor(loco::Graph *graph, const std::vector<uint32_t> &dims_vec)
210 auto shape = graph->nodes()->create<luci::CircleConst>();
212 shape->dtype(loco::DataType::S32);
215 shape->dim(0).set(dims_vec.size());
216 shape->shape_status(luci::ShapeStatus::VALID);
218 shape->size<loco::DataType::S32>(dims_vec.size());
220 for (uint32_t i = 0; i < dims_vec.size(); ++i)
// uint32_t entries are stored into S32 elements (implicit narrowing).
222 shape->at<loco::DataType::S32>(i) = dims_vec[i];
// Computes the total padding (both sides together) needed along one axis so
// that `output_size` windows of `filter_size` elements, placed `stride` apart,
// fit into the input.
//
// @param input_size   extent of the input along the axis
// @param output_size  number of window positions along the axis
// @param stride       distance between neighbouring window origins
// @param filter_size  window extent along the axis
// @return total padding, never negative
int32_t compute_full_padding(int32_t input_size, int32_t output_size, int32_t stride,
                             int32_t filter_size)
{
  // Extent covered by all windows, from the first origin to the end of the last window.
  int32_t const effective_input = (output_size - 1) * stride + filter_size;
  int32_t const full = effective_input - input_size;

  // some extreme cases when part of input was not used in computations:
  // the windows cover less than the input, so no padding is required
  return (full < 0) ? 0 : full;
}
// Fills `cords` with one value per output position (y_o, x_o).
//
// The value is the flattened input index y_i * input_width + x_i of the
// position reached by starting at (start_y, start_x) and advancing by the
// stride for every output step, plus a small rounding adjustment of
// 1 / (depth + 1). The element buffer is sized to
// output_height * output_width, both taken from dims 1 and 2 of `cords`.
//
// For SAME padding the start position is minus half of the full padding along
// each axis, so it can be negative.
// NOTE(review): the declarations/initial values of start_y and start_x are not
// visible in this view; presumably both start at 0 for VALID padding - confirm.
239 template <loco::DataType DT>
240 void fill_coords_addition(luci::Padding padding, const luci::Stride &stride,
241 const luci::Filter &filter, uint32_t input_height, uint32_t input_width,
242 uint32_t depth, luci::CircleConst *cords)
244 assert(cords->rank() == 4);
246 auto const output_height = static_cast<int32_t>(cords->dim(1).value());
247 auto const output_width = static_cast<int32_t>(cords->dim(2).value());
249 auto const element_counts = 1 * output_height * output_width * 1;
250 cords->size<DT>(element_counts);
// UNDEFINED padding is only guarded by this assert, i.e. not in release builds.
253 assert(padding != luci::Padding::UNDEFINED);
255 // For VALID padding:
260 if (padding == luci::Padding::SAME)
// The unsigned sizes/strides are implicitly converted to int32_t at these calls.
262 start_y = -compute_full_padding(input_height, output_height, stride.h(), filter.h()) / 2;
263 start_x = -compute_full_padding(input_width, output_width, stride.w(), filter.w()) / 2;
266 auto const step_y = static_cast<int32_t>(stride.h());
267 auto const step_x = static_cast<int32_t>(stride.w());
269 for (int32_t y_o = 0, y_i = start_y; y_o < output_height; ++y_o, y_i += step_y)
271 for (int32_t x_o = 0, x_i = start_x; x_o < output_width; ++x_o, x_i += step_x)
273 auto const output_idx = y_o * output_width + x_o;
274 auto const input_idx = y_i * static_cast<int32_t>(input_width) + x_i;
276 // Add small adjustment value to fix cast operation result that follows "coord addition"
277 // in generated subgraph.
279 // Cast operation discards fractional part of value, so 1.9996 will be transformed to 1
280 // This is not a problem when working with float32, because it represents integers precisely,
281 // but leads to wrong results, when working with quantized numbers.
283 // This value is larger than quantization error,
284 // and small enough to not affect following computations
285 // (in particular multiplication with depth)
286 const float round_adjustment = 1.0f / (depth + 1);
288 cords->at<DT>(output_idx) = input_idx + round_adjustment;
// Creates the FLOAT32 constant of dimensions [1, output_height, output_width, 1]
// whose contents are produced by fill_coords_addition.
// NOTE(review): the rank assignment, the remaining arguments of the
// fill_coords_addition call and the return statement are not visible in this
// view.
293 luci::CircleConst *create_coords_addition(loco::Graph *graph, luci::Padding padding,
294 const luci::Stride &stride, const luci::Filter &filter,
295 uint32_t input_height, uint32_t input_width,
296 uint32_t depth, uint32_t output_height,
297 uint32_t output_width)
299 auto cords = graph->nodes()->create<luci::CircleConst>();
301 cords->dtype(loco::DataType::FLOAT32);
304 cords->dim(0).set(1);
305 cords->dim(1).set(output_height);
306 cords->dim(2).set(output_width);
307 cords->dim(3).set(1);
309 fill_coords_addition<loco::DataType::FLOAT32>(padding, stride, filter, input_height, input_width,
// Looks up the CircleCustomOut successor of `cop` that carries index `idx`.
//
// `cop` is expected to have exactly two successors (asserted). The first
// successor is tried; when its index differs from `idx` the last one is taken
// instead, without re-checking its index.
// NOTE(review): the return statement is not visible in this view; presumably
// `output` is returned - confirm.
315 luci::CircleNode *get_custom_output(const luci::CircleCustom *cop, int32_t idx)
317 auto const outputs = loco::succs(cop);
318 assert(outputs.size() == 2);
320 auto output = loco::must_cast<luci::CircleCustomOut *>(*outputs.begin());
321 if (output->index() != idx)
323 output = loco::must_cast<luci::CircleCustomOut *>(*outputs.rbegin());
// Builds the plain max-pooling half of the replacement subgraph.
//
// A CircleMaxPool2D without fused activation is created in the graph of `cop`,
// configured with the given stride, filter and padding, and fed from the first
// input of `cop`.
// NOTE(review): `filter` is taken by value (`const luci::Filter`) while sibling
// helpers take it by const reference - consider aligning.
// NOTE(review): the return statement is not visible in this view; presumably
// the MaxPool2D node is returned - confirm.
329 luci::CircleNode *max_pool_branch(luci::Padding padding, const luci::Stride &stride,
330 const luci::Filter filter, luci::CircleCustom *cop)
332 auto graph = cop->graph();
333 auto input = cop->inputs(0);
335 auto origin = luci::get_origin(cop);
// The "/Argmax" prefix is shared with argmax_branch, which builds the same name.
336 auto name = cop->name() + "/Argmax";
339 auto maxpool = none_act_func(graph->nodes()->create<luci::CircleMaxPool2D>());
341 init_name_and_origin(maxpool, name + "/MaxPool2D", origin);
343 set_stride(maxpool, stride);
344 set_filter(maxpool, filter);
345 maxpool->padding(padding);
347 maxpool->value(input);
// Builds the subgraph that yields, for every pooling window, the flattened
// position of the maximum element inside that window.
//
// Visible stages:
//   1. for SAME padding: Mul by 1.0 ("Requantize") followed by PadV2 with
//      explicit paddings and the lowest float as padding value,
//   2. Conv2D (VALID padding, pooling stride, kh * kw filters created by
//      create_conv_filter, zero bias) that moves window positions to depth,
//   3. ArgMax over the depth dimension with S32 output,
//   4. Reshape to {1, output_height, output_width, 1},
//   5. Cast from S32 to FLOAT32.
// NOTE(review): several statements (e.g. redirecting conv_input to the pad,
// attaching the bias, ArgMax input wiring, the return) are not visible in this
// view.
// NOTE(review): `filter` is taken by value (`const luci::Filter`) while other
// helpers take it by const reference - consider aligning.
353 luci::CircleNode *window_flattened_coord(const std::string &name, luci::Padding padding,
354 const luci::Stride &stride, const luci::Filter filter,
355 int32_t input_height, int32_t input_width,
356 uint32_t output_height, uint32_t output_width,
357 luci::CircleNode *input)
359 auto const graph = input->graph();
360 auto const origin = luci::get_origin(input);
362 auto const depth_dimension = 3;
364 // Create pad in case of SAME padding
365 luci::CircleNode *conv_input = input;
366 if (padding == luci::Padding::SAME)
368 // Create redundant multiplication by 1 to combine two nodes with special quantization
369 // restrictions: PadV2 and Split in this case
370 // TODO Introduce special requantize node and fix quantizer?
371 auto requantize = none_act_func(graph->nodes()->create<luci::CircleMul>());
372 init_name_and_origin(requantize, name + "/Requantize", origin);
// NOTE(review): despite its name `zero_const` holds 1.0f (the neutral element of
// Mul), and its node name is built without a '/' separator - confirm intent.
373 auto zero_const = create_scalar<loco::DataType::FLOAT32>(graph, 1.0f);
374 init_name_and_origin(zero_const, name + "Requantize_const", origin);
376 requantize->x(input);
377 requantize->y(zero_const);
379 auto pad = graph->nodes()->create<luci::CirclePadV2>();
380 init_name_and_origin(pad, name + "/Pad", origin);
382 pad->input(requantize);
// The full padding along each axis is split in two; the odd unit, if any, goes
// to the right/bottom side.
384 int32_t full_w_pad = compute_full_padding(input_width, output_width, stride.w(), filter.w());
385 int32_t full_h_pad = compute_full_padding(input_height, output_height, stride.h(), filter.h());
386 int32_t left_pad = full_w_pad / 2;
387 int32_t right_pad = full_w_pad - left_pad;
388 int32_t top_pad = full_h_pad / 2;
389 int32_t bottom_pad = full_h_pad - top_pad;
390 auto padding_const = create_padding_const(graph, left_pad, right_pad, top_pad, bottom_pad);
391 init_name_and_origin(padding_const, name + "/Pad_shape", origin);
392 pad->paddings(padding_const);
// Padding with the lowest finite float means a padded cell can only win the
// ArgMax on a tie with real input.
395 create_scalar<loco::DataType::FLOAT32, float>(graph, std::numeric_limits<float>::lowest());
396 init_name_and_origin(padding_value, name + "/Pad_value", origin);
397 pad->constant_values(padding_value);
401 // Create Conv2D to move spatial dimensions to depth
402 auto conv = none_act_func(graph->nodes()->create<luci::CircleConv2D>());
404 init_name_and_origin(conv, name + "/Conv2D", origin);
406 // Stride equals MaxPool's; padding is VALID because SAME padding was applied explicitly above
407 set_stride(conv, stride);
408 conv->padding(luci::Padding::VALID);
410 // depth of kernel is equal to square size
411 auto const kh = filter.h();
412 auto const kw = filter.w();
413 auto const kd = kh * kw;
416 auto bias = create_zero_bias(graph, kd);
417 init_name_and_origin(bias, conv->name() + "/Bias", origin);
421 auto weights = create_conv_filter(graph, kh, kw, kd);
422 init_name_and_origin(weights, conv->name() + "/Weights", origin);
425 conv->filter(weights);
426 conv->input(conv_input);
430 auto argmax = graph->nodes()->create<luci::CircleArgMax>();
432 init_name_and_origin(argmax, name + "/ArgMax", origin);
434 argmax->output_type(loco::DataType::S32);
437 auto argmax_dim = create_scalar<loco::DataType::S32>(graph, depth_dimension);
438 init_name_and_origin(argmax_dim, argmax->name() + "/Dimension", origin);
440 argmax->dimension(argmax_dim);
444 // Create Reshape to 4-rank back, because argmax decrease rank of tensor by 1
445 auto reshape = graph->nodes()->create<luci::CircleReshape>();
447 init_name_and_origin(reshape, name + "/Reshape", origin);
449 auto shape = create_shape_tensor(graph, {1, output_height, output_width, 1});
450 init_name_and_origin(shape, reshape->name() + "/Shape", origin);
452 reshape->tensor(argmax);
453 reshape->shape(shape);
456 // Create Cast to use float32 instead int32
457 auto argmax_cast = create_cast(reshape, loco::DataType::S32, loco::DataType::FLOAT32);
458 init_name_and_origin(argmax_cast, argmax->name() + "/Cast", origin);
463 // Creates "identity operation" after Floor
464 // to force circle-quantizer requantize output tensor with scale << 1.
466 // Dealing with values of extremely different scales
467 // in following binary operations hurts backend precision.
//
// The node is a DepthwiseConv2D fed by `floor`, with a filter of dimensions
// [1, 1, 1, 1] holding 1.0, a one-element zero bias, VALID padding, unit
// stride, unit dilation, depth multiplier 1 and no fused activation.
// NOTE(review): the return statement is not visible in this view; presumably
// `requantizer` is returned - confirm.
468 luci::CircleNode *create_post_floor_requantize_node(luci::CircleFloor *floor)
470 auto graph = floor->graph();
471 auto const origin = luci::get_origin(floor);
472 auto name = floor->name();
474 // Use DepthwiseConv2D with identity filter as an "identity operation".
476 // This operation does not change values, but forces circle-quantizer to use
477 // statistics to compute qparam scale instead of fixed scale == 1.0 after floor.
478 // DepthwiseConv2d is not eliminated by optimizations,
479 // so desired scale will reach backend.
480 auto requantizer = none_act_func(graph->nodes()->create<luci::CircleDepthwiseConv2D>());
481 init_name_and_origin(requantizer, name + "/Requantizer", origin);
483 requantizer->input(floor);
// The filter starts as a scalar constant and is then re-shaped in place to
// rank 4 with every dimension equal to 1; the single element is unchanged.
485 auto requantizer_filter = create_scalar<loco::DataType::FLOAT32>(graph, 1.0f);
486 init_name_and_origin(requantizer_filter, name + "/Requantizer/filter", origin);
487 requantizer_filter->rank(4);
488 for (uint32_t i = 0; i < 4; ++i)
490 requantizer_filter->dim(i) = 1;
492 requantizer->filter(requantizer_filter);
494 auto requantizer_bias = create_zero_bias(graph, 1);
495 init_name_and_origin(requantizer_bias, name + "/Requantizer/bias", origin);
496 requantizer->bias(requantizer_bias);
498 requantizer->padding(luci::Padding::VALID);
499 requantizer->stride()->w(1);
500 requantizer->stride()->h(1);
501 requantizer->depthMultiplier(1);
502 requantizer->dilation()->w(1);
503 requantizer->dilation()->h(1);
// Builds the subgraph that derives the row (y) coordinate inside the pooling
// window from the flattened in-window position.
//
// The division by the window width is expressed as a Mul with the constant
// 1 / (filter.w() - rounding_adjustment), followed by Floor and by the
// requantize node from create_post_floor_requantize_node.
// NOTE(review): the operand wiring of `div` and `floor` and the return
// statement are not visible in this view.
508 luci::CircleNode *window_y_coord(const std::string &name, const luci::Filter &filter,
509 luci::CircleNode *flattened)
511 auto const graph = flattened->graph();
512 auto const origin = luci::get_origin(flattened);
// `div` is a Mul node: it multiplies by the reciprocal of the divider.
514 auto div = none_act_func(graph->nodes()->create<luci::CircleMul>());
516 init_name_and_origin(div, name + "/Div", origin);
518 // Adjustment_coeff is needed to fix computation of quantized tensors
520 // For example float32 value 2.0 could be quantized to 1.996
521 // after floor it will be transformed to 1.0, but desired answer is still something close to 2.0
523 // rounding_adjustment is chosen so it is small enough to not affect float32 computations,
524 // but "Div" change is larger than potential quantization error.
526 // This computation exploits the fact that div is a y coord in maxpool window,
527 // and lies in defined range [0, filter.h())
528 const float rounding_adjustment = 1.0f / (filter.w() * filter.h());
529 const float divider_value = filter.w() - rounding_adjustment;
530 auto divider = create_scalar<loco::DataType::FLOAT32>(graph, 1.0f / divider_value);
531 init_name_and_origin(divider, div->name() + "/Divider", origin);
537 auto floor = graph->nodes()->create<luci::CircleFloor>();
539 init_name_and_origin(floor, name + "/Floor", origin);
543 auto requantizer = create_post_floor_requantize_node(floor);
// Builds the subgraph that derives the column (x) coordinate inside the
// pooling window from the flattened in-window position and the row coordinate.
//
// Visible nodes: an Add named ".../Mod", a Neg, and a Mul with a FLOAT32
// scalar holding `filter_width`.
// NOTE(review): the operand wiring and the return statement are not visible in
// this view; presumably the result is flattened + (-(y_coord * filter_width))
// - confirm.
548 luci::CircleNode *window_x_coord(const std::string &name, float filter_width,
549 luci::CircleNode *flattened, luci::CircleNode *y_coord)
551 auto const graph = flattened->graph();
552 auto const origin = luci::get_origin(flattened);
// `mod` is an Add node; no Mod operator is created here.
554 auto mod = none_act_func(graph->nodes()->create<luci::CircleAdd>());
556 init_name_and_origin(mod, name + "/Mod", origin);
558 auto neg = graph->nodes()->create<luci::CircleNeg>();
560 init_name_and_origin(neg, mod->name() + "/Neg", origin);
562 auto mul = none_act_func(graph->nodes()->create<luci::CircleMul>());
// NOTE(review): the Mul node is named with a "/Neg" suffix (".../Neg/Neg");
// this looks like a copy-paste slip for "/Mul" - confirm before renaming, node
// names may be relied upon elsewhere.
564 init_name_and_origin(mul, neg->name() + "/Neg", origin);
566 auto multipler = create_scalar<loco::DataType::FLOAT32>(graph, filter_width);
567 init_name_and_origin(multipler, mul->name() + "/Multipler", origin);
// Builds the subgraph that turns in-window coordinates into a flattened
// coordinate inside the input plane.
//
// Visible wiring: y_addition = y_coord * input_width and
// addition = x_coord + y_addition.
// NOTE(review): the operands of the outer `add` and the return statement are
// not visible in this view; presumably `add` combines `addition` with
// `corners` - confirm.
583 luci::CircleNode *plane_flattened_coord(const std::string &name, uint32_t input_width,
584 luci::CircleNode *y_coord, luci::CircleNode *x_coord,
585 luci::CircleNode *corners)
587 auto const graph = corners->graph();
588 auto const origin = luci::get_origin(corners);
590 auto add = none_act_func(graph->nodes()->create<luci::CircleAdd>());
592 init_name_and_origin(add, name + "/Add", origin);
594 auto addition = none_act_func(graph->nodes()->create<luci::CircleAdd>());
596 init_name_and_origin(addition, add->name() + "/Add", origin);
598 auto y_addition = none_act_func(graph->nodes()->create<luci::CircleMul>());
600 init_name_and_origin(y_addition, addition->name() + "/Mul", origin);
// input_width (uint32_t) is stored as a FLOAT32 scalar.
602 auto width_scalar = create_scalar<loco::DataType::FLOAT32>(graph, input_width);
603 init_name_and_origin(width_scalar, y_addition->name() + "/Const", origin);
605 y_addition->x(y_coord);
606 y_addition->y(width_scalar);
609 addition->x(x_coord);
610 addition->y(y_addition);
// Builds the subgraph that converts a flattened plane coordinate into a
// flattened coordinate inside the whole input volume.
//
// Visible nodes: a Mul with a FLOAT32 scalar holding `input_depth`, and an Add
// ("/Add_Channel") with a FLOAT32 scalar holding `channel`.
// NOTE(review): the first operand of each node, any condition guarding the
// channel Add and the return statement are not visible in this view.
620 luci::CircleNode *volume_flattened_coords(const std::string &name, uint32_t channel,
621 uint32_t input_depth, luci::CircleNode *plane)
623 auto const graph = plane->graph();
624 auto const origin = luci::get_origin(plane);
627 auto mul = none_act_func(graph->nodes()->create<luci::CircleMul>());
629 init_name_and_origin(mul, name + "/Mul", origin);
631 auto depth_scalar = create_scalar<loco::DataType::FLOAT32>(graph, input_depth);
632 init_name_and_origin(depth_scalar, mul->name() + "/Const", origin);
635 mul->y(depth_scalar);
638 luci::CircleNode *volume = mul;
640 // Add channel number to output
644 auto add_ch = none_act_func(graph->nodes()->create<luci::CircleAdd>());
645 init_name_and_origin(add_ch, name + "/Add_Channel", origin);
647 auto channel_scalar = create_scalar<loco::DataType::FLOAT32>(graph, channel);
648 init_name_and_origin(channel_scalar, add_ch->name() + "/Const", origin);
651 add_ch->y(channel_scalar);
// Builds the argmax half of the replacement subgraph.
//
// The input is split along the depth dimension into one branch per channel.
// Every branch computes the in-window position of the maximum, converts it to
// window y/x coordinates, then to a plane coordinate and finally to a volume
// coordinate. The branch results are concatenated along the depth dimension
// and cast from FLOAT32 to the dtype of output 1 of `cop`.
// The input is read as dim(1) = height, dim(2) = width, dim(3) = depth.
// NOTE(review): the number of created nodes grows linearly with the input
// depth (one chain per channel).
// NOTE(review): `filter` is taken by value (`const luci::Filter`) while other
// helpers take it by const reference - consider aligning.
// NOTE(review): the wiring of the Split input and the return statement are not
// visible in this view.
659 luci::CircleNode *argmax_branch(luci::Padding padding, const luci::Stride &stride,
660 const luci::Filter filter, luci::CircleCustom *cop)
662 auto graph = cop->graph();
663 auto input = loco::must_cast<luci::CircleNode *>(cop->inputs(0));
664 auto output = get_custom_output(cop, 1);
666 auto const depth_dimension = 3;
667 auto const input_depth = input->dim(depth_dimension).value();
668 auto const input_height = input->dim(1).value();
669 auto const input_width = input->dim(2).value();
671 assert(output->rank() == 4);
672 auto const output_height = output->dim(1).value();
673 auto const output_width = output->dim(2).value();
675 auto origin = luci::get_origin(cop);
676 auto name = cop->name() + "/Argmax";
679 auto split = graph->nodes()->create<luci::CircleSplit>();
681 init_name_and_origin(split, name + "/Split", origin);
684 auto split_dim = create_scalar<loco::DataType::S32>(graph, depth_dimension);
685 init_name_and_origin(split_dim, split->name() + "/Dim", origin);
687 split->num_split(int32_t(input_depth));
689 split->split_dim(split_dim);
694 * Note: we need define idx from input_tensor of maximum element in MaxPool's sliding window.
695 * For this we split input tensor by channels, define idx in sliding window and convert this idx
696 * to idx from source input_tensor using FloorDiv, Mul and Add operations with constant tensors.
698 std::vector<luci::CircleNode *> branch_outputs(input_depth);
700 for (uint32_t br_n = 0; br_n < input_depth; ++br_n)
702 auto const branch_name = name + "/depth_" + std::to_string(br_n);
704 // Create CircleSplitOut
705 auto split_out = graph->nodes()->create<luci::CircleSplitOut>();
706 init_name_and_origin(split_out, branch_name + "/SplitOut", origin);
707 split_out->index(int32_t(br_n));
708 split_out->input(split);
710 // Define idx of max element in Window:
712 window_flattened_coord(branch_name + "/WindowFlat", padding, stride, filter, input_height,
713 input_width, output_height, output_width, split_out);
715 auto const window_y = window_y_coord(branch_name + "/WindowY", filter, window_coords);
716 auto const window_x =
717 window_x_coord(branch_name + "/WindowX", filter.w(), window_coords, window_y);
719 // Define idx of max element in Plane
720 // This tensor contains coords of left top corners for each window from input tensor
721 auto corners = create_coords_addition(graph, padding, stride, filter, input_height, input_width,
722 input_depth, output_height, output_width);
723 init_name_and_origin(corners, branch_name + "/Const", origin);
726 plane_flattened_coord(branch_name + "/PlaneFlat", input_width, window_y, window_x, corners);
728 // Define volume coords as final value
729 branch_outputs[br_n] =
730 volume_flattened_coords(branch_name + "/VolumeFlat", br_n, input_depth, plane_coord);
733 // Create Concatenation
734 auto concat = none_act_func(graph->nodes()->create<luci::CircleConcatenation>(input_depth));
736 init_name_and_origin(concat, name + "/Concatenation", origin);
737 concat->axis(depth_dimension);
739 for (uint32_t i = 0; i < input_depth; ++i)
741 concat->values(i, branch_outputs[i]);
745 // Output of argmax_with_maxpool should be S64 or S32
// This repeats the get_custom_output(cop, 1) lookup done at the top of the function.
746 loco::DataType output_dtype = get_custom_output(cop, 1)->dtype();
747 auto output_cast = create_cast(concat, loco::DataType::FLOAT32, output_dtype);
748 init_name_and_origin(output_cast, name + "/Cast", origin);
// Replaces one MaxPoolWithArgmax custom node by an equivalent subgraph.
//
// The custom options are decoded from flexbuffers ("ksize", "strides",
// "padding"). CHECK_OR_FALSE is applied to the following requirements:
// ksize and strides have 4 entries with 1 in the batch and depth slots, the
// input is a rank-4 FLOAT32 tensor with batch size 1, and the node has exactly
// two outputs. When all checks pass, both outputs are replaced by the results
// of max_pool_branch and argmax_branch, and input 0 of the custom node is
// detached.
// NOTE(review): the body of CHECK_OR_FALSE, the declarations of `filter` and
// `stride` and the final return are not visible in this view; presumably a
// failed check returns false and success returns true - confirm.
// NOTE(review): no visible check rejects luci::Padding::UNDEFINED, which
// string_to_padding yields for unknown strings; fill_coords_addition only
// asserts on it - consider a CHECK_OR_FALSE on `padding`.
753 bool resolve_max_pool_with_argmax(luci::CircleCustom *cop)
755 #define CHECK_OR_FALSE(condition) \
756 if (not(condition)) \
759 const std::vector<uint8_t> custom_options = cop->custom_options();
760 auto map = flexbuffers::GetRoot(custom_options).AsMap();
763 // Note: Only `Targmax` equal to DT_INT64 is supported by tflite converter
764 // Note: Only `data_format` equal to "NHWC" is supported by tflite converter
765 // TODO add support of `include_batch_in_index` param
766 auto ksize_param = to_vector<uint32_t>(map["ksize"].AsTypedVector());
767 auto strides_param = to_vector<uint32_t>(map["strides"].AsTypedVector());
768 auto padding_param = map["padding"].As<std::string>();
770 // Batch size and depth of ksize more than 1 is not supported.
771 CHECK_OR_FALSE(ksize_param.size() == 4);
772 CHECK_OR_FALSE(ksize_param[0] == 1 && ksize_param[3] == 1);
774 CHECK_OR_FALSE(strides_param.size() == 4);
775 CHECK_OR_FALSE(strides_param[0] == 1 && strides_param[3] == 1);
778 auto padding = string_to_padding(padding_param);
// Entries 1 and 2 of ksize/strides are used as the height and width components.
782 filter.h(ksize_param[1]);
783 filter.w(ksize_param[2]);
787 stride.h(strides_param[1]);
788 stride.w(strides_param[2]);
791 auto const input = loco::must_cast<luci::CircleNode *>(cop->inputs(0));
792 CHECK_OR_FALSE(input->dtype() == loco::DataType::FLOAT32);
793 CHECK_OR_FALSE(input->rank() == 4);
795 // TODO support batch size > 1 and `include_batch_in_index` option
796 CHECK_OR_FALSE(input->dim(0).value() == 1);
799 auto const outputs = loco::succs(cop);
800 CHECK_OR_FALSE(outputs.size() == 2);
801 assert(outputs.size() == cop->numOutputs());
803 auto output0 = get_custom_output(cop, 0);
804 auto output1 = get_custom_output(cop, 1);
806 // From TF documentation: output of maxpool must have the same type as input
807 assert(output0->dtype() == input->dtype());
808 assert(output1->dtype() == loco::DataType::S64 || output1->dtype() == loco::DataType::S32);
811 auto maxpool = max_pool_branch(padding, stride, filter, cop);
812 auto argmax = argmax_branch(padding, stride, filter, cop);
814 // last argmax branch op is cast, it should have dtype initialized
815 assert(argmax->dtype() == output1->dtype());
817 // replace old node with new subgraph
818 cop->inputs(0, nullptr);
819 loco::replace(output0).with(maxpool);
820 loco::replace(output1).with(argmax);
835 * [CUSTOM(MaxPoolWithArgmax)]
837 * [MaxPool output] [Argmax output]
843 * [Split over channels] [MaxPool2D]
845 * [Requantize] ... ... [MaxPool output]
857 * | [Mul 1/<window width>]
861 * | [DepthwiseConv2D for requantize]
863 * | [Mul window width] |
865 * \ [Neg] [Mul input width]
873 * [Mul number of channels]
875 * [Optional Add with channels id] ... ...
883 bool ResolveCustomOpMaxPoolWithArgmaxPass::run(loco::Graph *g)
885 bool changed = false;
886 for (auto node : loco::active_nodes(loco::output_nodes(g)))
888 auto cop = dynamic_cast<luci::CircleCustom *>(node);
892 if (cop->custom_code() != "MaxPoolWithArgmax")
895 if (!resolve_max_pool_with_argmax(cop))