Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / passes / hw_conv_tiling / hw_convolution_tiler.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <algorithm>
6 #include <limits>
7 #include <vector>
8 #include <memory>
9 #include <utility>
10 #include <vpu/passes/hw_conv_tiling/hw_convolution_tiler.hpp>
11
12 namespace vpu {
13
14 namespace HWTilingNS {
15
16 bool operator<(const TilingOption& a, const TilingOption& b) {
17     return a.cost < b.cost || (isDoubleEqual(a.cost, b.cost) && a.totalNumTiles < b.totalNumTiles);
18 }
19
20 class ConvInputToOutputDirection;
21 class ConvOutputToInputDirection;
22
23 // Input -> Output case
24 class ConvInputToOutputDirection: public GraphDataTiling {
25 public:
26     explicit ConvInputToOutputDirection(const ConvolutionOptions &co): GraphDataTiling(co, Direction::INPUT_TO_OUTPUT) {}
27     ConvInputToOutputDirection(const ConvInputToOutputDirection &other): GraphDataTiling(other) {}
28     void initTileSizes() override {
29         _useCeil = ceilNeeded();
30
31         _inputTileDims.set(Dim::W, std::min(CNN_MAX_INPUT_WIDTH, _co._inputDims[Dim::W]));
32         _inputTileDims.set(Dim::H, std::min(CNN_MAX_INPUT_HEIGHT, _co._inputDims[Dim::H]));
33         _inputTileDims.set(Dim::C, std::min(CNN_MAX_INPUT_CHANNELS, _co._inputDims[Dim::C]));
34
35         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
36         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
37         _outputTileDims.set(Dim::C, _co._outputDims[Dim::C]);
38
39         correctOutputPlaneSize();
40     }
41
42     // Input -> Output case
43     void setInputNOutputTileDimensions(const int tileDimW, const int tileDimH, const int tileDimC) override {
44         _inputTileDims.set(Dim::W, tileDimW);
45         _inputTileDims.set(Dim::H, tileDimH);
46         _inputTileDims.set(Dim::C, tileDimC);
47
48         correctOutputPlaneSize();
49     }
50
51     // Input -> Output case
52     void applyTilingOption(const TilingOption &tilingOption) override {
53         int tileDimW = divUp(_co._inputDims[Dim::W], tilingOption.numWidthTiles);
54         int tileDimH = divUp(_co._inputDims[Dim::H], tilingOption.numHeightTiles);
55         const int tileDimC = divUp(_co._inputDims[Dim::C], tilingOption.numChannelTiles);
56
57         tileDimW = divUp(tileDimW, _co._kernelStride) * _co._kernelStride;
58         tileDimH = divUp(tileDimH, _co._kernelStride) * _co._kernelStride;
59
60         _inputTileDims.set(Dim::W, tileDimW);
61         _inputTileDims.set(Dim::H, tileDimH);
62         _inputTileDims.set(Dim::C, tileDimC);
63
64         correctOutputPlaneSize();
65     }
66
67     void correctPlaneSize() override {
68         correctOutputPlaneSize();
69     }
70
71     void correctOutputPlaneSize() {
72         int maxOutputWidth = calcOutputSize(_inputTileDims[Dim::W], _co._kernelSizeX, _co._kernelStride,
73                 _co._paddingLeft, _co._paddingRight, _useCeil);
74         if (_co._withPool) {
75             maxOutputWidth /= 2;
76         }
77         _outputTileDims.set(Dim::W, std::min(_outputTileDims[Dim::W], maxOutputWidth));
78
79         int maxOutputHeight = calcOutputSize(_inputTileDims[Dim::H], _co._kernelSizeY, _co._kernelStride,
80                 _co._paddingTop, _co._paddingBottom, _useCeil);
81         if (_co._withPool) {
82             maxOutputHeight /= 2;
83         }
84         _outputTileDims.set(Dim::H, std::min(_outputTileDims[Dim::H], maxOutputHeight));
85     }
86
87     const DimValues &splitOverTensorDims() override {
88         return _co._inputDims;
89     }
90
91     void patternMatching() override;
92
93 private:
94     bool ceilNeeded() {
95         int tempX = _co._inputDims[Dim::W] + _co._paddingLeft + _co._paddingRight - _co._kernelSizeX;
96         int tempY = _co._inputDims[Dim::H] + _co._paddingTop + _co._paddingBottom - _co._kernelSizeY;
97
98         int outWidthWithOutCeil = (tempX + _co._kernelStride) / _co._kernelStride;
99         int outHeightWithOutCeil = (tempY + _co._kernelStride) / _co._kernelStride;
100
101         int outWidthWithCeil = static_cast<int>(std::ceil(static_cast<double>(tempX) / _co._kernelStride + 1));
102         int outHeightWithCeil = static_cast<int>(std::ceil(static_cast<double>(tempY) / _co._kernelStride + 1));
103
104         if ((_co._origOutputDims[Dim::W] != outWidthWithCeil) && (_co._origOutputDims[Dim::W] != outWidthWithOutCeil)) {
105             VPU_THROW_EXCEPTION
106                     << "Internal error: Output in " << _co._stageName << " has incorrect width dimension. Expected: "
107                     << outWidthWithCeil << " or " << outWidthWithOutCeil << " Actual: " << _co._origOutputDims[Dim::W];
108         }
109
110         if ((_co._origOutputDims[Dim::H] != outHeightWithCeil) && (_co._origOutputDims[Dim::H] != outHeightWithOutCeil)) {
111             VPU_THROW_EXCEPTION
112                     << "Internal error: Output in " << _co._stageName << " has incorrect height dimension. Expected: "
113                     << outHeightWithCeil << " or " << outHeightWithOutCeil << " Actual: " << _co._origOutputDims[Dim::H];
114         }
115
116         if ((_co._origOutputDims[Dim::W] == outWidthWithOutCeil) && (_co._origOutputDims[Dim::H] == outHeightWithOutCeil)) {
117             return false;
118         } else {
119             return true;
120         }
121     }
122 };
123
124 // Output -> Input case
125 class ConvOutputToInputDirection: public GraphDataTiling {
126 public:
127     explicit ConvOutputToInputDirection(const ConvolutionOptions &co): GraphDataTiling(co, Direction::OUTPUT_TO_INPUT) {}
128     ConvOutputToInputDirection(const ConvOutputToInputDirection &other): GraphDataTiling(other) {}
129     void initTileSizes() override {
130         _useCeil = false;   // no ceiling needed for ConvOutputToInputDirection
131
132         _outputTileDims.set(Dim::W, std::min(CNN_MAX_INPUT_WIDTH, _co._outputDims[Dim::W]));
133         _outputTileDims.set(Dim::H, std::min(CNN_MAX_INPUT_HEIGHT, _co._outputDims[Dim::H]));
134         _outputTileDims.set(Dim::C, _co._outputDims[Dim::C]);
135
136         _inputTileDims.set(Dim::W, std::min(CNN_MAX_INPUT_WIDTH, _co._inputDims[Dim::W]));
137         _inputTileDims.set(Dim::H, std::min(CNN_MAX_INPUT_HEIGHT, _co._inputDims[Dim::H]));
138         _inputTileDims.set(Dim::C, std::min(CNN_MAX_INPUT_CHANNELS, _co._inputDims[Dim::C]));
139
140         correctInputPlaneSize();
141     }
142     // Output -> Input case
143     void setInputNOutputTileDimensions(const int tileDimW, const int tileDimH, const int tileDimC) override {
144         _outputTileDims.set(Dim::W, tileDimW);
145         _outputTileDims.set(Dim::H, tileDimH);
146         _outputTileDims.set(Dim::C, tileDimC);
147
148         correctInputPlaneSize();
149     }
150
151     // Output -> Input case
152     void applyTilingOption(const TilingOption &tilingOption) override {
153         const int tileDimW = divUp(_co._outputDims[Dim::W], tilingOption.numWidthTiles);
154         const int tileDimH = divUp(_co._outputDims[Dim::H], tilingOption.numHeightTiles);
155         // split only input tensor over C dim
156         const int tileDimC = divUp(_co._inputDims[Dim::C], tilingOption.numChannelTiles);
157
158         _outputTileDims.set(Dim::W, tileDimW);
159         _outputTileDims.set(Dim::H, tileDimH);
160         _inputTileDims.set(Dim::C, tileDimC);
161
162         correctInputPlaneSize();
163     }
164
165     int calcInputSize(
166             int outputSize,
167             int kernelSize, int kernelStride,
168             int padBefore, int padAfter
169     ) {
170         return (outputSize - 1) * kernelStride + kernelSize - padBefore - padAfter;
171     }
172
173     void correctPlaneSize() override {
174         correctInputPlaneSize();
175     }
176
177     void correctInputPlaneSize() {
178         int maxInputWidth = calcInputSize(_outputTileDims[Dim::W], _co._kernelSizeX, _co._kernelStride, _co._paddingLeft,
179                                           _co._paddingRight);
180         if (_co._withPool) {
181             maxInputWidth *= 2;
182         }
183         _inputTileDims.set(Dim::W, std::min(_inputTileDims[Dim::W], maxInputWidth));
184
185         int maxInputHeight = calcInputSize(_outputTileDims[Dim::H], _co._kernelSizeY, _co._kernelStride, _co._paddingTop,
186                                            _co._paddingBottom);
187         if (_co._withPool) {
188             maxInputHeight *= 2;
189         }
190         _inputTileDims.set(Dim::H, std::min(_inputTileDims[Dim::H], maxInputHeight));
191     }
192
193     const DimValues &splitOverTensorDims() override {
194         return _co._outputDims;
195     }
196
197     void patternMatching() override {
198         // noop
199     }
200 };
201
202 HWConvolutionTiler::HWConvolutionTiler(const ConvolutionOptions &co,
203                    Direction direction,
204                    size_t maxTilingOptions) :
205         _co(co),
206         _searcher(_co, direction, maxTilingOptions) {
207     _tilingPossible = tileForHW();
208 }
209
210 bool HWConvolutionTiler::tileForHW() {
211     const std::vector<TilingOption> &tilingOptions = _searcher.tilingOptions();
212     if (tilingOptions.empty()) {
213             return false;
214     }
215
216     for (const TilingOption &tilingOption : tilingOptions) {
217         const HWConvolutionTileLayoutCut tileLayoutCut = _searcher.tileLayoutCut(tilingOption);
218         if (tileLayoutCut.tileCutPossible()) {
219             _hwTilings.push_back(tileLayoutCut.hwTiling());
220         }
221     }
222
223     return _hwTilings.size() != 0;
224 }
225
226 void ConvInputToOutputDirection::patternMatching() {
227     if (!_co._withPool &&
228         _co._kernelSizeX == 3 && _co._kernelSizeY == 3 && _co._paddingLeft == 1 && _co._paddingRight == 1  &&
229         _co._paddingTop == 1 && _co._paddingBottom == 1  && _co._kernelStride == 1 &&
230         _co._inputDims[Dim::C] == 512 && _co._inputDims[Dim::H] == 28 && _co._inputDims[Dim::W] == 28 &&
231         _co._outputDims[Dim::C] == 512) {
232         _inputTileDims.set(Dim::H, 28);
233         _inputTileDims.set(Dim::C, 172);
234         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
235         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
236         correctPlaneSize();
237         return;
238     }
239
240     if (!_co._withPool &&
241         _co._kernelSizeX == 3 && _co._kernelSizeY == 3 && _co._paddingLeft == 1 && _co._paddingRight == 1  &&
242         _co._paddingTop == 1 && _co._paddingBottom == 1  && _co._kernelStride == 1 &&
243         _co._inputDims[Dim::C] == 256 && _co._inputDims[Dim::H] == 56 && _co._inputDims[Dim::W] == 56 &&
244         _co._outputDims[Dim::C] == 256) {
245         _inputTileDims.set(Dim::H, 30);
246         _inputTileDims.set(Dim::C, 128);
247         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
248         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
249         correctPlaneSize();
250         return;
251     }
252
253     if (!_co._withPool &&
254         _co._kernelSizeX == 3 && _co._kernelSizeY == 3 && _co._paddingLeft == 1 && _co._paddingRight == 1  &&
255         _co._paddingTop == 1 && _co._paddingBottom == 1  && _co._kernelStride == 1 &&
256         _co._inputDims[Dim::C] == 64 && _co._inputDims[Dim::H] == 224 && _co._inputDims[Dim::W] == 224 &&
257         _co._outputDims[Dim::C] == 64) {
258         _inputTileDims.set(Dim::H, 82);
259         _inputTileDims.set(Dim::W, 82);
260         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
261         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
262         correctPlaneSize();
263         return;
264     }
265
266     if (_co._inputDims[Dim::C] == 512 &&
267         _co._inputDims[Dim::H] == 7 &&
268         _co._inputDims[Dim::W] == 7 &&
269         _co._outputDims[Dim::C] == 4096) {
270         _inputTileDims.set(Dim::C, 64);
271         correctPlaneSize();
272         return;
273     }
274
275     if (!_co._withPool &&
276         _co._kernelSizeX == 3 && _co._kernelSizeY == 3 && _co._paddingLeft == 1 && _co._paddingRight == 1  &&
277         _co._paddingTop == 1 && _co._paddingBottom == 1  && _co._kernelStride == 1 &&
278         _co._inputDims[Dim::C] == 128 && _co._inputDims[Dim::H] == 112 && _co._inputDims[Dim::W] == 112 &&
279         _co._outputDims[Dim::C] == 128) {
280         _inputTileDims.set(Dim::H, 32);
281         _inputTileDims.set(Dim::W, 112);
282         _inputTileDims.set(Dim::C, 32);
283         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
284         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
285         correctPlaneSize();
286         return;
287     }
288
289     if (_co._inputDims[Dim::C] == 1088 &&
290         _co._inputDims[Dim::H] == 17 &&
291         _co._inputDims[Dim::W] == 17 &&
292         (_co._outputDims[Dim::C] == 128 || _co._outputDims[Dim::C] == 192)) {
293         _inputTileDims.set(Dim::H, 17);
294         _inputTileDims.set(Dim::C, 544);
295         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
296         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
297         correctPlaneSize();
298         return;
299     }
300
301     if (_co._inputDims[Dim::C] == 1024 &&
302         _co._inputDims[Dim::H] == 17 &&
303         _co._inputDims[Dim::W] == 17 &&
304         _co._outputDims[Dim::C] == 384) {
305         _inputTileDims.set(Dim::H, 17);
306         _inputTileDims.set(Dim::C, 512);
307         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
308         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
309         correctPlaneSize();
310         return;
311     }
312
313     if (!_co._withPool &&
314         _co._kernelSizeX == 3 && _co._kernelSizeY == 3 && _co._paddingLeft == 0 && _co._paddingRight == 0  &&
315         _co._paddingTop == 0 && _co._paddingBottom == 0  && _co._kernelStride == 2 &&
316         _co._inputDims[Dim::C] == 384 && _co._inputDims[Dim::H] == 35 && _co._inputDims[Dim::W] == 35 &&
317         _co._outputDims[Dim::C] == 384) {
318         _inputTileDims.set(Dim::C, 194);
319         _inputTileDims.set(Dim::H, 35);
320         _inputTileDims.set(Dim::W, 35);
321         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
322         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
323         correctPlaneSize();
324         return;
325     }
326
327     if (_co._inputDims[Dim::C] == 192 &&
328         _co._inputDims[Dim::H] == 71 &&
329         _co._inputDims[Dim::W] == 71 &&
330         _co._outputDims[Dim::H] == 35) {
331         _inputTileDims.set(Dim::W, 71);
332         _inputTileDims.set(Dim::C, 96);
333         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
334         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
335         correctPlaneSize();
336         return;
337     }
338
339     if (!_co._withPool &&
340         _co._inputDims[Dim::C] == 256 &&
341         _co._inputDims[Dim::H] == 128 &&
342         _co._inputDims[Dim::W] == 128 &&
343         _co._outputDims[Dim::C] == 256) {
344         _inputTileDims.set(Dim::W, 128);
345         _inputTileDims.set(Dim::H, 15);
346         _inputTileDims.set(Dim::C, 64);
347         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
348         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
349         correctPlaneSize();
350         return;
351     }
352
353     if (!_co._withPool &&
354         _co._inputDims[Dim::C] == 512 &&
355         _co._inputDims[Dim::H] == 64 &&
356         _co._inputDims[Dim::W] == 64 &&
357         _co._outputDims[Dim::C] == 512) {
358         _inputTileDims.set(Dim::W, 64);
359         _inputTileDims.set(Dim::H, 10);
360         _inputTileDims.set(Dim::C, 128);
361         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
362         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
363         correctPlaneSize();
364         return;
365     }
366
367     if (!_co._withPool &&
368         _co._kernelSizeX == 1 && _co._kernelSizeY == 1 && _co._paddingLeft == 0 && _co._paddingRight == 0  &&
369         _co._paddingTop == 0 && _co._paddingBottom == 0  && _co._kernelStride == 1 &&
370         _co._inputDims[Dim::C] == 384 &&
371         _co._inputDims[Dim::H] == 56 &&
372         _co._inputDims[Dim::W] == 56 &&
373         _co._outputDims[Dim::C] == 64) {
374         _inputTileDims.set(Dim::C, 384);
375         _inputTileDims.set(Dim::H, 56);
376         _inputTileDims.set(Dim::W, 20);
377         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
378         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
379         correctPlaneSize();
380         return;
381     }
382
383     if (!_co._withPool &&
384         _co._kernelSizeX == 1 && _co._kernelSizeY == 1 && _co._paddingLeft == 0 && _co._paddingRight == 0  &&
385         _co._paddingTop == 0 && _co._paddingBottom == 0  && _co._kernelStride == 1 &&
386         _co._inputDims[Dim::C] == 2112 &&
387         _co._inputDims[Dim::H] == 14 &&
388         _co._inputDims[Dim::W] == 14 &&
389         _co._outputDims[Dim::C] == 1056) {
390         _inputTileDims.set(Dim::C, 556);
391         _inputTileDims.set(Dim::H, 14);
392         _inputTileDims.set(Dim::W, 14);
393         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
394         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
395         correctPlaneSize();
396         return;
397     }
398
399     if (!_co._withPool &&
400         _co._kernelSizeX == 3 && _co._kernelSizeY == 3 && _co._paddingLeft == 1 && _co._paddingRight == 1  &&
401         _co._paddingTop == 1 && _co._paddingBottom == 1  && _co._kernelStride == 2 &&
402         _co._inputDims[Dim::C] == 256 &&
403         _co._inputDims[Dim::H] == 52 &&
404         _co._inputDims[Dim::W] == 52 &&
405         _co._outputDims[Dim::C] == 512) {
406         _inputTileDims.set(Dim::C, 128);
407         _inputTileDims.set(Dim::H, 52);
408         _inputTileDims.set(Dim::W, 52);
409         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
410         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
411         correctPlaneSize();
412         return;
413     }
414
415     if (!_co._withPool &&
416         _co._kernelSizeX == 3 && _co._kernelSizeY == 3 && _co._paddingLeft == 1 && _co._paddingRight == 1  &&
417         _co._paddingTop == 1 && _co._paddingBottom == 1  && _co._kernelStride == 1 &&
418         _co._inputDims[Dim::C] == 256 &&
419         _co._inputDims[Dim::H] == 23 &&
420         _co._inputDims[Dim::W] == 23 &&
421         _co._outputDims[Dim::C] == 640) {
422         _inputTileDims.set(Dim::C, 256);
423         _inputTileDims.set(Dim::H, 14);
424         _inputTileDims.set(Dim::W, 23);
425         _outputTileDims.set(Dim::H, _co._outputDims[Dim::H]);
426         _outputTileDims.set(Dim::W, _co._outputDims[Dim::W]);
427         correctPlaneSize();
428         return;
429     }
430 }
431
432 std::unique_ptr<GraphDataTiling> ConvGraphDataTilingFactory::makeDirTiling(const ConvolutionOptions &co,
433         Direction direction) {
434     if (direction == Direction::INPUT_TO_OUTPUT) {
435         return std::unique_ptr<GraphDataTiling>(new ConvInputToOutputDirection(co));
436     } else if (direction == Direction::OUTPUT_TO_INPUT) {
437         return std::unique_ptr<GraphDataTiling>(new ConvOutputToInputDirection(co));
438     } else {
439         IE_ASSERT(false) << "Unsupported direction";
440     }
441 }
442
443 std::unique_ptr<GraphDataTiling> ConvGraphDataTilingFactory::makeDirTiling(const GraphDataTiling &o) {
444     if (o.getDirection() == Direction::INPUT_TO_OUTPUT) {
445         return std::unique_ptr<GraphDataTiling>(
446                 new ConvInputToOutputDirection(dynamic_cast<const ConvInputToOutputDirection&>(o)));
447     } else if (o.getDirection() == Direction::OUTPUT_TO_INPUT) {
448         return std::unique_ptr<GraphDataTiling>(
449                 new ConvOutputToInputDirection(dynamic_cast<const ConvOutputToInputDirection&>(o)));
450     } else {
451         IE_ASSERT(false) << "Unsupported direction";
452     }
453 }
454
455 //
456 // Looks for the optimal tiling accordingly to the cost function. Modifies dimensions in dirTiling during search.
457 //
458 std::vector<TilingOption> HWConvolutionTilingSearcher::selectBetterTiling() const {
459     const auto &env = CompileEnv::get();
460     GraphDataTiling &dirTiling = *_dirTiling;
461     FixedMaxHeap<TilingOption> tilingOptions(_maxTilingOptions);
462
463     // TODO: estimate this numbers
464     const int maxNumWidthTiles = 15;
465     const int maxNumHeightTiles = 15;
466     const int maxNumChannelTiles = _co._withPool ? 1 : 15;
467
468     const auto outputTileInitial = dirTiling.getOutputTileDims();
469     const auto inputTileInitial = dirTiling.getInputTileDims();
470
471     auto minInputTileDimW = 64;
472     auto minInputTileDimH = _co._kernelSizeY;
473     if (_co._withPool) {
474         minInputTileDimW *= 2;
475         minInputTileDimH *= 2;
476     }
477
478     const DimValues &splitOver = dirTiling.splitOverTensorDims();
479     const auto direction = dirTiling.getDirection();
480     // split over Input tensor for the Channel dimension always
481     for (int numChannelTiles = 1; numChannelTiles <= maxNumChannelTiles; numChannelTiles++) {
482         const int tileSizeDimC = divUp(_co._inputDims[Dim::C], numChannelTiles);
483
484         // here split and iterate either over input tensors or over output tensors depending on the direction.
485         for (int numWidthTiles = 1; numWidthTiles <= maxNumWidthTiles; numWidthTiles++) {
486             int tileSizeDimW = divUp(splitOver[Dim::W], numWidthTiles);
487
488             //
489             // Filter-out too small SoW input tiles when loops split input tensors.
490             //
491
492             if (numWidthTiles > 1 && direction == Direction::INPUT_TO_OUTPUT) {
493                 tileSizeDimW = divUp(tileSizeDimW, _co._kernelStride) * _co._kernelStride;
494
495                 if (tileSizeDimW < minInputTileDimW) {
496                     break;
497                 }
498             }
499
500             for (int numHeightTiles = 1; numHeightTiles <= maxNumHeightTiles; numHeightTiles++) {
501                 int tileSizeDimH = divUp(splitOver[Dim::H], numHeightTiles);
502
503                 //
504                 // Filter-out too small SoH input tiles when loops split input tensors.
505                 //
506
507                 if (numHeightTiles > 1 && direction == Direction::INPUT_TO_OUTPUT) {
508                     tileSizeDimH = divUp(tileSizeDimH, _co._kernelStride) * _co._kernelStride;
509
510                     if (tileSizeDimH < minInputTileDimH) {
511                         break;
512                     }
513                 }
514
515                 //
516                 // Try current tile size.
517                 //
518
519                 dirTiling.resetInputTileDims(inputTileInitial);
520                 dirTiling.resetOutputTileDims(outputTileInitial);
521
522                 dirTiling.setInputNOutputTileDimensions(tileSizeDimW, tileSizeDimH, tileSizeDimC);
523
524                 //
525                 // Limitations for Conv+Pool case.
526                 //
527
528                 if (_co._withPool) {
529                     if (dirTiling.getOutputTileDims()[Dim::W] <= 2 ||
530                         dirTiling.getOutputTileDims()[Dim::H] <= 2) {
531                         break;
532                     }
533                 }
534
535                 //
536                 // Check that tiling is valid.
537                 //
538
539                 // todo: check internal in/out hardcodes
540                 const auto heightTiles = calcHeightTiles(_co, dirTiling.getOutputTileDims(),
541                                                          dirTiling.useCeil());
542                 const auto widthTiles = calcWidthTiles(_co, dirTiling.getOutputTileDims(), dirTiling.useCeil());
543
544                 if (heightTiles.empty()) {
545                     continue;
546                 }
547                 if (widthTiles.empty()) {
548                     break;
549                 }
550
551                 bool isOK = true;
552                 double solutionCost = 0.0;
553
554                 for (const auto &heightTile : heightTiles) {
555                     for (const auto &widthTile : widthTiles) {
556                         //
557                         // Limitations for Conv+Pool case.
558                         //
559
560                         if (_co._withPool) {
561                             if (widthTile.inputWithJunk % 2 != 0 ||
562                                 heightTile.inputWithJunk % 2 != 0 ||
563                                 widthTile.outputWithJunk % 2 != 0 ||
564                                 widthTile.outputWithJunk <= 2 ||
565                                 heightTile.outputWithJunk <= 2) {
566                                 isOK = false;
567                                 break;
568                             }
569                         }
570
571                         //
572                         // Can use this tile.
573                         //
574
575                         auto tileInfo = splitHwConvIntoOutChannelsTiles(  // left asis, not new ver in new api
576                                 widthTile.inputWithJunk, heightTile.inputWithJunk, tileSizeDimC,
577                                 outputTileInitial[Dim::C],
578                                 _co._kernelSizeX, _co._kernelSizeY, _co._kernelStride);
579
580                         if (tileInfo.numDescr == 0) {
581                             isOK = false;
582                             break;
583                         }
584
585                         //
586                         // Output tile fits to CMX limitation.
587                         //
588
589                         DimValues fullOutputTileDims;
590                         fullOutputTileDims.set(Dim::W, widthTile.outputWithJunk);
591                         fullOutputTileDims.set(Dim::H, heightTile.outputWithJunk);
592                         fullOutputTileDims.set(Dim::C, outputTileInitial[Dim::C]);
593
594                         // TODO: support HCW
595                         if (calculateHwBufferSize(fullOutputTileDims) > env.resources.cmxLimit) {
596                             isOK = false;
597                             break;
598                         }
599
600                         //
601                         // Calc tile cost.
602                         //
603
604                         solutionCost += tileInfo.cost * numChannelTiles;
605
606                         // Alignment for output
607                         if ((widthTile.outputStartIndex * sizeof(fp16_t)) % 16 != 0) {
608                             solutionCost += 1.0
609                                             * widthTile.outputWithJunk
610                                             * heightTile.outputWithJunk
611                                             * outputTileInitial[Dim::C];
612                         }
613
614                         // Alignment for input
615                         if ((widthTile.inputStartIndex * sizeof(fp16_t)) % 16 != 0) {
616                             solutionCost += 1.0
617                                             * widthTile.inputWithJunk
618                                             * heightTile.inputWithJunk
619                                             * tileInfo.extendedInputDimC;
620                         }
621
622                         // SoC overhead
623                         solutionCost += 1.0
624                                         * (numChannelTiles - 1)
625                                         * widthTile.outputWithJunk
626                                         * heightTile.outputWithJunk
627                                         * outputTileInitial[Dim::C];
628                     }
629
630                     if (!isOK) {
631                         break;
632                     }
633                 }
634
635                 if (!isOK) {
636                     continue;
637                 }
638
639                 //
640                 // Put to the pool of best options.
641                 //
642
643                 const int totalNumTiles = numWidthTiles * numHeightTiles * numChannelTiles;
644
645                 const TilingOption to =
646                         {numWidthTiles, numHeightTiles, numChannelTiles, totalNumTiles, solutionCost};
647                 tilingOptions.push(to);
648
649                 // Skip smaller SoC tiling.
650                 break;
651             }
652         }
653     }
654
655     dirTiling.resetInputTileDims(inputTileInitial);
656     dirTiling.resetOutputTileDims(outputTileInitial);
657
658     return tilingOptions.sorted();
659 }
660
661 HWConvolutionTileLayoutCut HWConvolutionTilingSearcher::tileLayoutCut(const TilingOption &option) const {
662     return HWConvolutionTileLayoutCut(*_dirTiling, option);
663 }
664
665 std::ostream& operator<<(std::ostream &o, const TilingOption &to) {
666     o << "WHC: "
667         << to.numWidthTiles << "x"
668         << to.numHeightTiles << "x"
669         << to.numChannelTiles
670         << " Tot: " << to.totalNumTiles << " " << " cost: " << to.cost;
671
672     return o;
673 }
674
675 // based on height of the tile for output tensor
676 SmallVector<HwPlaneTileInfo> calcHeightTiles(const ConvolutionOptions &_co,
677                                              const DimValues &outputTileDims, bool useCeil) {
678     SmallVector<HwPlaneTileInfo> heightTiles;
679
680     if (outputTileDims[Dim::H] == _co._outputDims[Dim::H]) {
681         HwPlaneTileInfo info;
682         info.inputWithJunk = _co._inputDims[Dim::H];
683         info.outputWithJunk = _co._outputDims[Dim::H];
684         info.outputJunkBefore = 0;
685         info.outputJunkAfter = 0;
686         info.inputStartIndex = 0;
687         info.inputEndIndex = _co._inputDims[Dim::H];
688         info.outputStartIndex = 0;
689         info.outputEndIndex = _co._outputDims[Dim::H];
690
691         heightTiles.emplace_back(info);
692     } else {
693         if (_co._withPool) {
694             heightTiles = splitIntoPlaneTilesWithPool(
695                     _co._inputDims[Dim::H],
696                     _co._kernelSizeY,
697                     _co._kernelStride,
698                     _co._paddingTop,
699                     outputTileDims[Dim::H]);
700         } else {
701             heightTiles = splitIntoPlaneTiles(
702                     _co._inputDims[Dim::H],
703                     _co._outputDims[Dim::H],
704                     _co._kernelSizeY,
705                     _co._kernelStride,
706                     _co._paddingTop, _co._paddingBottom,
707                     outputTileDims[Dim::H],
708                     useCeil);
709         }
710     }
711
712     return heightTiles;
713 }
714
715 SmallVector<HwPlaneTileInfo> calcWidthTiles(const ConvolutionOptions &_co,
716                                             const DimValues &outputTileDims, bool useCeil) {
717     SmallVector<HwPlaneTileInfo> widthTiles;
718
719     if (outputTileDims[Dim::W] == _co._outputDims[Dim::W]) {
720         HwPlaneTileInfo info;
721         info.inputWithJunk = _co._inputDims[Dim::W];
722         info.outputWithJunk = _co._outputDims[Dim::W];
723         info.outputJunkBefore = 0;
724         info.outputJunkAfter = 0;
725         info.inputStartIndex = 0;
726         info.inputEndIndex = _co._inputDims[Dim::W];
727         info.outputStartIndex = 0;
728         info.outputEndIndex = _co._outputDims[Dim::W];
729
730         widthTiles.emplace_back(info);
731     } else {
732         if (_co._withPool) {
733             widthTiles = splitIntoPlaneTilesWithPool(
734                     _co._inputDims[Dim::W],
735                     _co._kernelSizeX,
736                     _co._kernelStride,
737                     _co._paddingLeft,
738                     outputTileDims[Dim::W]);
739         } else {
740             widthTiles = splitIntoPlaneTiles(
741                     _co._inputDims[Dim::W],
742                     _co._outputDims[Dim::W],
743                     _co._kernelSizeX,
744                     _co._kernelStride,
745                     _co._paddingLeft, _co._paddingRight,
746                     outputTileDims[Dim::W],
747                     useCeil);
748         }
749     }
750
751     return widthTiles;
752 }
753
754 }  // namespace HWTilingNS
755
756 }  // namespace vpu
757