cda718ca4d975cd949c05b7629351a4fd1d91ba4
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / region_yolo.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/frontend/frontend.hpp>
6
7 #include <vector>
8 #include <unordered_set>
9 #include <memory>
10 #include <set>
11
12 namespace vpu {
13
14 namespace {
15
16 class RegionYoloStage final : public StageNode {
17 private:
18     StagePtr cloneImpl() const override {
19         return std::make_shared<RegionYoloStage>(*this);
20     }
21
22     void propagateDataOrderImpl() const override {
23         IE_ASSERT(_inputEdges.size() == 1);
24         IE_ASSERT(_outputEdges.size() == 1);
25
26         auto output = _outputEdges[0]->output();
27
28         if (!attrs().get<bool>("doSoftMax")) {
29             _orderInfo.setOutput(_outputEdges[0], output->desc().dimsOrder().createMovedDim(Dim::C, 2));  // CHW
30         }
31     }
32
33     void getDataStridesRequirementsImpl() const override {
34         IE_ASSERT(_inputEdges.size() == 1);
35         IE_ASSERT(_outputEdges.size() == 1);
36
37         if (attrs().get<bool>("doSoftMax")) {
38             // Major dimension must be compact.
39             _stridesInfo.setInput(_inputEdges[0], StridesRequirement().add(2, DimStride::Compact));
40         }
41     }
42
43     void finalizeDataLayoutImpl() override {
44     }
45
46     void getBatchSupportInfoImpl() const override {
47         IE_ASSERT(_inputEdges.size() == 1);
48         IE_ASSERT(_outputEdges.size() == 1);
49
50         _batchInfo.setInput(_inputEdges[0], BatchSupport::Split);
51         _batchInfo.setOutput(_outputEdges[0], BatchSupport::Split);
52     }
53
54     void finalCheckImpl() const override {
55     }
56
57     void serializeParamsImpl(BlobSerializer& serializer) const override {
58         auto classes = attrs().get<int>("classes");
59         auto coords = attrs().get<int>("coords");
60         auto num = attrs().get<int>("num");
61         auto maskSize = attrs().get<int>("maskSize");
62         auto doSoftMax = attrs().get<bool>("doSoftMax");
63
64         serializer.append(static_cast<int32_t>(classes));
65         serializer.append(static_cast<int32_t>(coords));
66         serializer.append(static_cast<int32_t>(num));
67         serializer.append(static_cast<int32_t>(maskSize));
68         serializer.append(static_cast<int32_t>(doSoftMax));
69     }
70
71     void serializeDataImpl(BlobSerializer& serializer) const override {
72         IE_ASSERT(_inputEdges.size() == 1);
73         IE_ASSERT(_outputEdges.size() == 1);
74         IE_ASSERT(_tempBufferEdges.empty());
75
76         auto input = _inputEdges[0]->input();
77         auto output = _outputEdges[0]->output();
78
79         input->serializeOldBuffer(handle_from_this(), serializer);
80         output->serializeOldBuffer(handle_from_this(), serializer);
81     }
82 };
83
84 }  // namespace
85
86 void FrontEnd::parseRegionYolo(
87         const Model::Ptr& model,
88         const ie::CNNLayerPtr& layer,
89         const DataVector& inputs,
90         const DataVector& outputs) {
91     IE_ASSERT(inputs.size() == 1);
92     IE_ASSERT(outputs.size() == 1);
93
94     auto mask = layer->GetParamAsInts("mask", {});
95
96     auto stage = model->addNewStage<RegionYoloStage>(
97         layer->name,
98         StageType::RegionYolo,
99         layer,
100         inputs,
101         outputs);
102
103     stage->attrs().set<int>("classes", layer->GetParamAsInt("classes", 20));
104     stage->attrs().set<int>("coords", layer->GetParamAsInt("coords", 4));
105     stage->attrs().set<int>("num", layer->GetParamAsInt("num", 5));
106     stage->attrs().set<int>("maskSize", static_cast<int>(mask.size()));
107     stage->attrs().set<bool>("doSoftMax", layer->GetParamAsInt("do_softmax", 1));
108 }
109
110 }  // namespace vpu