[LPT] integration branch: Reshape fix, Concat generalization, runtime info usage...
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / mat_mul_function.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/mat_mul_function.hpp"
6
7 #include <queue>
8 #include <memory>
9
10 #include <ngraph/opsets/opset1.hpp>
11 #include "ngraph_ops/type_relaxed.hpp"
12 #include "ngraph_functions/subgraph_builders.hpp"
13 #include "low_precision/network_helper.hpp"
14 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
15
16 namespace ngraph {
17 namespace builder {
18 namespace subgraph {
19
20 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
21     const ngraph::element::Type precision,
22     const ngraph::Shape& inputShape1,
23     const FakeQuantizeOnData& fqOnData1,
24     const ngraph::Shape& inputShape2,
25     const FakeQuantizeOnData& fqOnData2) {
26     const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape1);
27     input1->set_friendly_name("input1");
28
29     const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape2);
30     input2->set_friendly_name("input2");
31
32     const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
33         makeFakeQuantize(input1, precision, fqOnData1),
34         makeFakeQuantize(input2, precision, fqOnData2),
35         false,
36         false);
37     matMul->set_friendly_name("matMul");
38
39     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
40
41     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
42         ngraph::ResultVector{ result },
43         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
44         "MatMulTransformation");
45     return function;
46 }
47
48 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
49     const ngraph::Shape& inputShape1,
50     const ngraph::element::Type precisionBeforeDequantization1,
51     const DequantizationOperations& dequantization1,
52     const ngraph::Shape& inputShape2,
53     const ngraph::element::Type precisionBeforeDequantization2,
54     const DequantizationOperations& dequantization2) {
55     if (!dequantization1.convert.empty() && (precisionBeforeDequantization1 == dequantization1.convert.outPrecision)) {
56         THROW_IE_EXCEPTION << "unexpected input arguments for branch 1";
57     }
58
59     if (!dequantization2.convert.empty() && (precisionBeforeDequantization2 == dequantization2.convert.outPrecision)) {
60         THROW_IE_EXCEPTION << "unexpected input arguments for branch 2";
61     }
62
63     const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization1, inputShape1);
64     input1->set_friendly_name("input1");
65
66     const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization2, inputShape2);
67     input2->set_friendly_name("input2");
68
69     const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
70         makeDequantization(input1, dequantization1),
71         makeDequantization(input2, dequantization2),
72         false,
73         false);
74     matMul->set_friendly_name("matMul");
75
76     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
77
78     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
79         ngraph::ResultVector{ result },
80         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
81         "MatMulTransformation");
82     return function;
83 }
84
85 std::shared_ptr<ngraph::Function> getOriginalWithConstant2(
86     const ngraph::element::Type precision) {
87     return nullptr;
88 }
89
90 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
91     const ngraph::element::Type precision,
92     const ngraph::Shape& inputShape,
93     const ngraph::element::Type precisionBeforeDequantization,
94     const DequantizationOperations& dequantizationOperations,
95     const ngraph::Shape& weightsConstShape,
96     const std::vector<float>& weightsConstValues,
97     const FakeQuantizeOnWeights& fqOnWeights) {
98     const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
99         precisionBeforeDequantization,
100         inputShape);
101     input->set_friendly_name("input1");
102
103     auto lastDequantization = makeDequantization(input, dequantizationOperations);
104
105     const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
106         precision,
107         weightsConstShape,
108         weightsConstValues);
109
110     auto fakeQuantize = makeFakeQuantize(weightsConst, precision, fqOnWeights);
111
112     const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
113         lastDequantization,
114         fakeQuantize,
115         false,
116         false);
117     matMul->set_friendly_name("matMul");
118     auto& rtInfo = matMul->get_rt_info();
119     rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("matMul");
120
121     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
122
123     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
124         ngraph::ResultVector{ result },
125         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
126         "MatMulTransformation");
127     return function;
128 }
129
130 std::shared_ptr<ngraph::Function> MatMulFunction::getReference(
131     const ngraph::element::Type precision,
132     const ngraph::Shape& inputShape1,
133     const ngraph::element::Type precisionBeforeDequantization1,
134     const DequantizationOperations& dequantization1,
135     const ngraph::Shape& inputShape2,
136     const ngraph::element::Type precisionBeforeDequantization2,
137     const DequantizationOperations& dequantization2,
138     const DequantizationOperations& resultDequantizationOperations) {
139     if (!dequantization1.convert.empty() && (precisionBeforeDequantization1 == dequantization1.convert.outPrecision)) {
140         THROW_IE_EXCEPTION << "unexpected input arguments for branch 1";
141     }
142
143     if (!dequantization2.convert.empty() && (precisionBeforeDequantization2 == dequantization2.convert.outPrecision)) {
144         THROW_IE_EXCEPTION << "unexpected input arguments for branch 2";
145     }
146
147     const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization1, inputShape1);
148     input1->set_friendly_name("input1");
149
150     const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization2, inputShape2);
151     input2->set_friendly_name("input2");
152
153     auto dequantization1Op = makeDequantization(input1, dequantization1);
154     auto dequantization2Op = makeDequantization(input2, dequantization2);
155
156     std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::MatMul>>(
157         std::vector<element::Type>{ element::f32, element::f32 }, std::vector<element::Type>{},
158         ngraph::op::TemporaryReplaceOutputType(dequantization1Op, element::f32).get(),
159         ngraph::op::TemporaryReplaceOutputType(dequantization2Op, element::f32).get(),
160         false,
161         false);
162
163     matMul->set_friendly_name("matMul");
164     ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(matMul, precision);
165     auto dequantizationAfter = makeDequantization(matMul, resultDequantizationOperations);
166     dequantizationAfter->set_friendly_name("matMul");
167
168     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(dequantizationAfter);
169
170     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
171         ngraph::ResultVector{ result },
172         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
173         "MatMulTransformation");
174     return function;
175 }
176
177 std::shared_ptr<ngraph::Function> MatMulFunction::getReference(
178     const ngraph::element::Type precision,
179     const ngraph::Shape& inputShape,
180     const ngraph::element::Type precisionBeforeDequantization,
181     const DequantizationOperations& dequantization,
182     const ngraph::element::Type weightsConstPrecision,
183     const ngraph::Shape& weightsConstShape,
184     const std::vector<float>& weightsConstValues,
185     const DequantizationOperations& resultDequantization) {
186     const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
187         precisionBeforeDequantization,
188         inputShape);
189     input->set_friendly_name("input1");
190
191     const std::shared_ptr<ngraph::Node> lastDequantizationBefore = makeDequantization(input, dequantization);
192
193     const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
194         weightsConstPrecision,
195         weightsConstShape,
196         weightsConstValues);
197
198     const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::MatMul>>(
199         std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
200         ngraph::op::TemporaryReplaceOutputType(lastDequantizationBefore, element::f32).get(),
201         ngraph::op::TemporaryReplaceOutputType(weightsConst, element::f32).get(),
202         false,
203         false);
204     matMul->set_friendly_name("matMul");
205     auto& rtInfo = matMul->get_rt_info();
206     rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("matMul");
207     ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(matMul, precision);
208
209     const std::shared_ptr<ngraph::Node> lastDequantizationAfter = makeDequantization(matMul, resultDequantization);
210     lastDequantizationAfter->set_friendly_name("matMul");
211
212     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(lastDequantizationAfter);
213
214     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
215         ngraph::ResultVector{ result },
216         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
217         "MatMulTransformation");
218     return function;
219 }
220
221 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
222     const ngraph::element::Type precision,
223     const ngraph::Shape& inputShape,
224     const FakeQuantizeOnData& fqOnData,
225     const ngraph::Shape& weightsConstShape,
226     const std::vector<float>& weightsConstValues,
227     const FakeQuantizeOnWeights& fqOnWeights) {
228     const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
229         precision,
230         inputShape);
231     input->set_friendly_name("input1");
232
233     auto lastDequantization = makeFakeQuantize(input, precision, fqOnData);
234
235     const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
236         precision,
237         weightsConstShape,
238         weightsConstValues);
239
240     auto fakeQuantize = makeFakeQuantize(weightsConst, precision, fqOnWeights);
241
242     const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
243         lastDequantization,
244         fakeQuantize,
245         false,
246         false);
247     matMul->set_friendly_name("matMul");
248
249     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
250
251     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
252         ngraph::ResultVector{ result },
253         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
254         "MatMulTransformation");
255     return function;
256 }
257
258 }  // namespace subgraph
259 }  // namespace builder
260 }  // namespace ngraph