1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <vpu/frontend/frontend.hpp>
11 #include <unordered_set>
// Pooling method selector for the ROIPooling stage, declared through the
// VPU_DECLARE_ENUM macro. The layer parser in this file stores either
// ROIPoolingMethod::Max or ROIPoolingMethod::Bilinear in the stage attributes,
// and the stage serializes the value as a uint32_t.
17 VPU_DECLARE_ENUM(ROIPoolingMethod,
// Stage node implementing ROI pooling.
// Every hook that inspects the edges asserts exactly two inputs and one output.
// The parameters ("pooled_w", "pooled_h", "spatial_scale", "method") are read
// from the stage attribute map when the stage is serialized.
24 class ROIPoolingStage final : public StageNode {
// Returns a copy of this stage, created with the copy constructor and owned
// by a std::shared_ptr.
26 StagePtr cloneImpl() const override {
27 return std::make_shared<ROIPoolingStage>(*this);
// Requests data orders for input 0 and for the output: each keeps its current
// dims order except that Dim::C is moved via createMovedDim(Dim::C, 2).
// No order is set for input 1.
30 void propagateDataOrderImpl() const override {
31 IE_ASSERT(_inputEdges.size() == 2);
32 IE_ASSERT(_outputEdges.size() == 1);
34 auto input0 = _inputEdges[0]->input();
35 auto output = _outputEdges[0]->output();
// NOTE(review): input 1 is presumably the ROI list and therefore left alone — confirm.
37 _orderInfo.setInput(_inputEdges[0], input0->desc().dimsOrder().createMovedDim(Dim::C, 2));
38 _orderInfo.setOutput(_outputEdges[0], output->desc().dimsOrder().createMovedDim(Dim::C, 2));
// Requires compact strides (StridesRequirement::compact()) on both inputs and
// on the output.
41 void getDataStridesRequirementsImpl() const override {
42 IE_ASSERT(_inputEdges.size() == 2);
43 IE_ASSERT(_outputEdges.size() == 1);
45 _stridesInfo.setInput(_inputEdges[0], StridesRequirement::compact());
46 _stridesInfo.setInput(_inputEdges[1], StridesRequirement::compact());
47 _stridesInfo.setOutput(_outputEdges[0], StridesRequirement::compact());
// NOTE(review): appears to be an empty override (nothing to finalize) — confirm.
50 void finalizeDataLayoutImpl() override {
// NOTE(review): appears to be an empty override (no batch support info set) — confirm.
53 void getBatchSupportInfoImpl() const override {
// NOTE(review): appears to be an empty override (no final checks) — confirm.
56 void finalCheckImpl() const override {
// Appends the stage parameters to the blob in this fixed order:
//   pooled_w (uint32_t), pooled_h (uint32_t), spatial_scale (float),
//   method (ROIPoolingMethod cast to uint32_t).
// attrs().get<T>() is used for each key, so all four attributes are expected
// to have been set before serialization (the layer parser in this file sets them).
59 void serializeParamsImpl(BlobSerializer& serializer) const override {
60 auto pooled_w = attrs().get<int>("pooled_w");
61 auto pooled_h = attrs().get<int>("pooled_h");
62 auto spatial_scale = attrs().get<float>("spatial_scale");
63 auto method = attrs().get<ROIPoolingMethod>("method");
// The signed int attributes are written as unsigned 32-bit values.
65 serializer.append(static_cast<uint32_t>(pooled_w));
66 serializer.append(static_cast<uint32_t>(pooled_h));
67 serializer.append(static_cast<float>(spatial_scale));
68 serializer.append(static_cast<uint32_t>(method));
// Serializes the buffer descriptions of the stage's data objects.
// Asserts two inputs, one output and no temporary buffers.
71 void serializeDataImpl(BlobSerializer& serializer) const override {
72 IE_ASSERT(_inputEdges.size() == 2);
73 IE_ASSERT(_outputEdges.size() == 1);
74 IE_ASSERT(_tempBufferEdges.empty());
76 auto input0 = _inputEdges[0]->input();
77 auto input1 = _inputEdges[1]->input();
78 auto output = _outputEdges[0]->output();
// Note the order: input 0, then the output, then input 1.
// NOTE(review): presumably this matches the order the consumer of the blob expects — confirm.
80 input0->serializeNewBuffer(serializer);
81 output->serializeNewBuffer(serializer);
82 input1->serializeNewBuffer(serializer);
// Creates a ROIPoolingStage (StageType::ROIPooling) in `model` for the given
// IE layer and copies the layer parameters into the stage attributes.
//
// model   - model that receives the new stage.
// layer   - source layer; read parameters and their defaults:
//             "pooled_w" (int, 7), "pooled_h" (int, 7),
//             "spatial_scale" (float, 0.0625f), "method" (string, "max").
// inputs  - must contain exactly 2 data objects (asserted).
// outputs - must contain exactly 1 data object (asserted).
88 void FrontEnd::parseROIPooling(
89 const Model::Ptr& model,
90 const ie::CNNLayerPtr& layer,
91 const DataVector& inputs,
92 const DataVector& outputs) {
// String comparator used for the "method" parameter below.
// NOTE(review): CaselessEq is presumably a case-insensitive equality — confirm.
93 ie::details::CaselessEq<std::string> cmp;
95 IE_ASSERT(inputs.size() == 2);
96 IE_ASSERT(outputs.size() == 1);
98 auto stage = model->addNewStage<ROIPoolingStage>(
100 StageType::ROIPooling,
// Store the pooling parameters under the keys that
// ROIPoolingStage::serializeParamsImpl reads back.
105 stage->attrs().set<int>("pooled_w", layer->GetParamAsInt("pooled_w", 7));
106 stage->attrs().set<int>("pooled_h", layer->GetParamAsInt("pooled_h", 7));
107 stage->attrs().set<float>("spatial_scale", layer->GetParamAsFloat("spatial_scale", 0.0625f));
// Map the textual method onto ROIPoolingMethod: "bilinear" selects Bilinear.
109 auto method = layer->GetParamAsString("method", "max");
110 if (cmp(method, "bilinear")) {
111 stage->attrs().set("method", ROIPoolingMethod::Bilinear);
// NOTE(review): presumably the fallback branch — any other method string selects Max; confirm.
113 stage->attrs().set("method", ROIPoolingMethod::Max);