Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / gna_plugin / quantization / model_quantizer.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6 #include <vector>
7 #include "gna_plugin_config.hpp"
8 #include "layer_transform.hpp"
9 #include "graph_tools.hpp"
10 #include "details/ie_cnn_network_tools.h"
11 #include "layer_quantizer.hpp"
12 #include "scale_factor_calc.hpp"
13
14 namespace GNAPluginNS {
15 /**
16  * Quantize entire cnn - network
17  * @tparam T - type trait for weights and biases
18  */
19 template<class T>
20 class ModelQuantizer {
21  public:
22     CNNNetworkPtr quantize(InferenceEngine::ICNNNetwork &model, float scaleFactor) const {
23         return quantize(model, [](InferenceEngine::CNNNetPtr &){}, scaleFactor);
24     }
25
26     template <class PreQuantisationCb>
27     CNNNetworkPtr quantize(InferenceEngine::ICNNNetwork &model, const PreQuantisationCb &cb, float scaleFactor) const {
28         auto visitor = [&](InferenceEngine::CNNLayerPtr lp) {
29             return InferenceEngine::injectData<QuantizedLayerParams>(lp);
30         };
31         auto copiedNet = InferenceEngine::CNNNetCopy(model, visitor);
32
33         // TODO: probably not the best way of using dynamic cast in order to transform Precision
34         // one of solution is to create not copyNet overloads, that accepts 2 functors, one for layer copy
35         // and another one for net copy
36         auto rawNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(copiedNet.get());
37         rawNet->setPrecision(T::mandatory().getNetPrecision());
38
39         // allow client code to access copied topology, to avoid copies if user would like to chain quantisation with
40         // another preprocessing
41         cb(copiedNet);
42
43         LayersQuantizer<T> lc(scaleFactor);
44         auto sortedNewNet = InferenceEngine::details::CNNNetSortTopologically(*copiedNet.get());
45         gnalog() << "Sorted layers: " << std::endl;
46         for (auto &&layer : sortedNewNet) {
47             gnalog() << layer->name << std::endl;
48         }
49
50         // weights scale is a hint, not all weightable layers preserve it in all possible precisions
51         propagateScaleFactor(sortedNewNet, T::mandatory().getWeightsPrecision().size(), scaleFactor);
52
53         // sorted order gives possibility for propagate quantisation along depended layers
54         for (auto &&layer : sortedNewNet) {
55             transformLayer(layer, lc);
56         }
57
58         return copiedNet;
59     }
60
61  private :
62     void propagateScaleFactor(std::vector<InferenceEngine::CNNLayerPtr> & net, int weightsBytesSize, float scaleFactor) const {
63         ScaleFactorCalculator sf(net, weightsBytesSize, scaleFactor);
64
65         while (!sf.allLayersProcessed()) {
66             for (auto &&layer : sf.getStartLayers()) {
67                 transformLayer(layer, sf);
68                 // transforming until we reached cases where output scale updated due to situation in downstream layer
69                 if (sf.needToRestart()) {
70                     break;
71                 }
72             }
73         }
74     }
75 };
76 }  // namespace GNAPluginNS