/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "MaxPoolWithArgmax.h"

#include <flatbuffers/flexbuffers.h>

#include <cassert>
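
// MaxPoolWithArgmax is emitted as a custom operator, so there is no built-in
// options table to fill in; value() returns an empty offset and every
// attribute is packed into the FlexBuffer produced by custom_value() below.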
flatbuffers::Offset<void> MaxPoolWithArgmaxChef::value(flatbuffers::FlatBufferBuilder &fbb) const
{
  return flatbuffers::Offset<void>();
}

flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
MaxPoolWithArgmaxChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
  auto &operation = (*_operation);

  assert(operation.type() == "MaxPoolWithArgmax");

  /**
   * REGISTER_OP("MaxPoolWithArgmax")
   *   .Attr("ksize: list(int) >= 4")
   *   .Attr("strides: list(int) >= 4")
   *   .Attr("Targmax: {int32, int64} = DT_INT64")
   *   .Attr(GetPaddingAttrString())
   *   .Attr("include_batch_in_index: bool = false")
   *   .Input("input: T")
   *   .Output("output: T")
   *   .Output("argmax: Targmax")
   *   .Attr("T: realnumbertype")
   *   .SetShapeFn([](InferenceContext* c) {
   *     TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
   *     c->set_output(1, c->output(0));
   *     return Status::OK();
   *   });
   */
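
  // All attributes listed above are serialized into a FlexBuffer map that becomes
  // the operator's custom_options in the generated model, roughly:
  //   { ksize: [1, W, H, 1], strides: [1, W, H, 1], Targmax: <int type>,
  //     padding: "SAME" | "VALID", include_batch_in_index: <bool>, T: FLOAT32 }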
  auto flex_buffers = std::make_unique<flexbuffers::Builder>();
  size_t map_start = flex_buffers->StartMap();
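
  // "ksize" and "strides" must be 4-element lists (see the op definition above);
  // the leading and trailing entries cover the batch and channel dimensions and
  // are fixed to 1 here.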
  auto start = flex_buffers->StartVector("ksize");
  flex_buffers->Add(1);
  flex_buffers->Add(operation.max_pool_with_argmax_options().filter_width());
  flex_buffers->Add(operation.max_pool_with_argmax_options().filter_height());
  flex_buffers->Add(1);
  flex_buffers->EndVector(start, /*typed=*/true, /*fixed=*/false);
  start = flex_buffers->StartVector("strides");
  flex_buffers->Add(1);
  flex_buffers->Add(operation.max_pool_with_argmax_options().stride_w());
  flex_buffers->Add(operation.max_pool_with_argmax_options().stride_h());
  flex_buffers->Add(1);
  flex_buffers->EndVector(start, /*typed=*/true, /*fixed=*/false);
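  // "Targmax" selects the index dtype of the argmax output (INT32 or INT64),
  // "padding" is written in the string form TensorFlow expects, and the input
  // dtype "T" is fixed to FLOAT32.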
  auto output_type = operation.max_pool_with_argmax_options().output_type();
  assert(output_type == tflchef::INT64 || output_type == tflchef::INT32);
  flex_buffers->Int("Targmax", output_type);
  std::string padding = operation.max_pool_with_argmax_options().padding() ? "VALID" : "SAME";
  flex_buffers->String("padding", padding);
  flex_buffers->Bool("include_batch_in_index",
                     operation.max_pool_with_argmax_options().include_batch_in_index());
  flex_buffers->Int("T", tflchef::FLOAT32);
  flex_buffers->EndMap(map_start);
  flex_buffers->Finish();
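
  // Copy the finished FlexBuffer into the model's FlatBufferBuilder; the caller
  // stores the returned vector as this operator's custom_options.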
  auto circle_custom_options = fbb.CreateVector(flex_buffers->GetBuffer());
  return circle_custom_options;
}
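
// Factory used by tflchef to build a MaxPoolWithArgmaxChef from a recipe entry.
// For illustration only (assumed recipe syntax; field names taken from the
// accessors above, values are placeholders), such an entry might look like:
//   operation {
//     type: "MaxPoolWithArgmax"
//     max_pool_with_argmax_options {
//       padding: VALID
//       stride_w: 2
//       stride_h: 2
//       filter_width: 2
//       filter_height: 2
//       output_type: INT64
//       include_batch_in_index: false
//     }
//   }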
std::unique_ptr<OpChef>
MaxPoolWithArgmaxChefFactory::create(const tflchef::Operation *operation) const
{
  return std::unique_ptr<OpChef>{new MaxPoolWithArgmaxChef{operation}};
}