Publishing 2019 R1.1 content and Myriad plugin sources (#162)
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / rnn.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/frontend/frontend.hpp>
6
7 #include <vector>
8 #include <string>
9 #include <memory>
10 #include <set>
11
12 #include <vpu/utils/numeric.hpp>
13
14 namespace vpu {
15
16 namespace {
17
18 class LSTMCellStage final : public StageNode {
19 private:
20     StagePtr cloneImpl() const override {
21         return std::make_shared<LSTMCellStage>(*this);
22     }
23
24     DataMap<float> propagateScaleFactorsImpl(
25             const DataMap<float>&,
26             ScalePropagationStep) override {
27         IE_ASSERT(_inputEdges.size() == 5);
28         IE_ASSERT(_outputEdges.size() == 2);
29
30         DataMap<float> out;
31
32         for (const auto& inEdge : _inputEdges) {
33             out[inEdge->input()] = 1.0f;
34         }
35         for (const auto& outEdge : _outputEdges) {
36             out[outEdge->output()] = 1.0f;
37         }
38
39         return out;
40     }
41
42     DataMap<DimsOrder> propagateDataOrderImpl() const override {
43         IE_ASSERT(_inputEdges.size() == 5);
44         IE_ASSERT(_outputEdges.size() == 2);
45
46         auto output = _outputEdges[0]->output();
47         auto input = _inputEdges[0]->input();
48
49         DimsOrder inputDimsOrder = input->desc().dimsOrder();
50         DimsOrder outputDimsOrder = output->desc().dimsOrder();
51
52         if (inputDimsOrder.numDims() >= 3) inputDimsOrder.moveDim(Dim::C, 2);  // ->...CHW
53         if (outputDimsOrder.numDims() >= 3) outputDimsOrder.moveDim(Dim::C, 2);
54
55         DataMap<DimsOrder> out;
56         out[input] = inputDimsOrder;
57         out[output] = outputDimsOrder;
58
59         return out;
60     }
61
62     DataMap<StridesRequirement> getDataStridesRequirementsImpl() const override {
63         IE_ASSERT(_inputEdges.size() == 5);
64         IE_ASSERT(_outputEdges.size() == 2);
65
66         DataMap<StridesRequirement> out;
67
68         for (const auto& inEdge : _inputEdges) {
69             out[inEdge->input()] = StridesRequirement::compact();
70         }
71         for (const auto& outEdge : _outputEdges) {
72             out[outEdge->output()] = StridesRequirement::compact();
73         }
74
75         return out;
76     }
77
78     void finalizeDataLayoutImpl() override {
79     }
80
81     DataMap<BatchSupport> getBatchSupportInfoImpl() const override {
82         return DataMap<BatchSupport>();
83     }
84
85     void finalCheckImpl() const override {
86     }
87
88     void serializeParamsImpl(BlobSerializer& serializer) const override {
89         auto RNNForward = attrs().get<bool>("RNNForward");
90         auto nCells = attrs().get<int>("nCells");
91         auto nBatches = attrs().get<int>("nBatches");
92         serializer.append(static_cast<int>(RNNForward));
93         serializer.append(static_cast<int>(nCells));
94         serializer.append(static_cast<int>(nBatches));
95     }
96
97     void serializeDataImpl(BlobSerializer& serializer) const override {
98         IE_ASSERT(_inputEdges.size() == 5);
99         IE_ASSERT(_outputEdges.size() == 2);
100
101         int nCells = attrs().get<int>("nCells");
102
103         bool useTempBuffer = (nCells > 1);
104         IE_ASSERT((_tempBufferEdges.size() == 1 && useTempBuffer) || !useTempBuffer);
105
106         for (const auto& inEdge : _inputEdges) {
107             inEdge->input()->serializeNewBuffer(serializer);
108         }
109         for (const auto& outEdge : _outputEdges) {
110             outEdge->output()->serializeNewBuffer(serializer);
111         }
112
113         if (useTempBuffer)
114             _tempBufferEdges[0]->tempBuffer()->serializeNewBuffer(serializer);
115     }
116 };
117
118 }  // namespace
119
120 static void RNNRelayout(
121                  const fp16_t* src,
122                  fp16_t* dst0,
123                  fp16_t* dst1,
124
125                  const int ngates,
126                  const int state_size,
127                  const int input_size
128                 ) {
129     int counter = 0;
130     for (int j = 0; j < ngates * state_size; j++) {
131         for (int i = 0; i < input_size; i++) {
132             dst0[(input_size) * j + i] = src[counter++];
133         }
134         for (int i = 0; i < state_size; i++) {
135             dst1[(state_size) * j + i] = src[counter++];
136         }
137     }
138 }
139
140 void FrontEnd::parseRNN(
141         const Model::Ptr& model,
142         const ie::CNNLayerPtr& _layer,
143         const DataVector &inputs,
144         const DataVector &outputs) {
145     IE_ASSERT(inputs.size() == 3);
146     IE_ASSERT(outputs.size() == 1);
147
148     auto layer = std::dynamic_pointer_cast<ie::RNNSequenceLayer>(_layer);
149     IE_ASSERT(layer != nullptr);
150
151     const int ngates = 4;
152
153     Data weights, biases;
154     std::tie(weights, biases) = getWeightsAndBiases(model, layer);
155
156     size_t nCells = inputs[0]->desc().dim(Dim::H);
157     size_t nBatches = inputs[0]->desc().dim(Dim::C);
158     IE_ASSERT(nCells >= 1);
159     IE_ASSERT(nBatches >= 1);
160
161     size_t input_size = inputs[0]->desc().dim(Dim::W);
162     IE_ASSERT(input_size == inputs[0]->desc().totalDimSize() / nCells / nBatches);
163
164     size_t state_size = inputs[1]->desc().totalDimSize() / nBatches;
165     size_t cell_state_size = inputs[2]->desc().totalDimSize() / nBatches;
166     IE_ASSERT(state_size == cell_state_size);
167
168     size_t weightsSize = weights->desc().totalDimSize();
169     IE_ASSERT(state_size * (input_size + state_size) * ngates == weightsSize);
170
171     size_t biasesSize = biases->desc().totalDimSize();
172     IE_ASSERT(state_size * ngates == biasesSize);
173
174     /* weights repacking */
175     auto newWeightsBlob = ie::make_shared_blob<fp16_t>(ie::Precision::FP16, ie::Layout::C, {weightsSize});
176     newWeightsBlob->allocate();
177     auto newWeightsPtr = newWeightsBlob->buffer().as<fp16_t*>();
178     auto content = weights->content();
179     IE_ASSERT(content != nullptr);
180     auto origWeights = content->get<fp16_t>();
181     IE_ASSERT(origWeights != nullptr);
182     RNNRelayout(origWeights,
183                 newWeightsPtr,
184                 newWeightsPtr + ngates * state_size * input_size,
185
186                 ngates,
187                 state_size,
188                 input_size);
189
190     auto newWeights = model->addConstData(
191         _layer->name + "@weights",
192         weights->desc(),
193         ieBlobContent(newWeightsBlob));
194
195     auto stateCellFinal = model->addFakeData();
196     auto stage = model->addNewStage<LSTMCellStage>(
197         layer->name,
198         StageType::LSTMCell,
199         layer,
200         {inputs[0], inputs[1], inputs[2], newWeights, biases},
201         {outputs[0], stateCellFinal});
202
203     if (nCells > 1)
204         model->addTempBuffer(stage, DataDesc({state_size}));
205
206     bool RNNForward = layer->direction == ie::RNNSequenceLayer::FWD;
207     stage->attrs().set<bool>("RNNForward", RNNForward);
208     stage->attrs().set<int>("nCells", nCells);
209     stage->attrs().set<int>("nBatches", nBatches);
210 }
211
212 void FrontEnd::parseLSTMCell(
213         const Model::Ptr& model,
214         const ie::CNNLayerPtr& _layer,
215         const DataVector &inputs,
216         const DataVector &outputs) {
217     IE_ASSERT(inputs.size() == 3);
218     IE_ASSERT(outputs.size() == 2);
219
220     auto layer = std::dynamic_pointer_cast<ie::LSTMCell>(_layer);
221     IE_ASSERT(layer != nullptr);
222
223     Data weights, biases;
224     std::tie(weights, biases) = getWeightsAndBiases(model, layer);
225
226     auto stage = model->addNewStage<LSTMCellStage>(
227             layer->name,
228             StageType::LSTMCell,
229             layer,
230             {inputs[0], inputs[1], inputs[2], weights, biases},
231             outputs);
232     stage->attrs().set<bool>("RNNForward", true);
233     stage->attrs().set<int>("nCells", 1);
234     stage->attrs().set<int>("nBatches", 1);
235 }
236
237 }  // namespace vpu