a2a62cffaf93fd582cdc3dcb2b517573a080d801
[platform/upstream/dldt.git] / inference-engine / tests / ngraph_functions / src / low_precision_transformations / mat_mul_function.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ngraph_functions/low_precision_transformations/mat_mul_function.hpp"
6
7 #include <queue>
8 #include <memory>
9
10 #include <ngraph/opsets/opset1.hpp>
11 #include "ngraph_ops/type_relaxed.hpp"
12 #include "ngraph_functions/subgraph_builders.hpp"
13 #include "low_precision/network_helper.hpp"
14 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
15
16 namespace ngraph {
17 namespace builder {
18 namespace subgraph {
19
20 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
21     const ngraph::element::Type precision,
22     const ngraph::Shape& inputShape1,
23     const FakeQuantizeOnData& fqOnData1,
24     const ngraph::Shape& inputShape2,
25     const FakeQuantizeOnData& fqOnData2) {
26     const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape1);
27     input1->set_friendly_name("input1");
28
29     const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape2);
30     input2->set_friendly_name("input2");
31
32     const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
33         makeFakeQuantize(input1, precision, fqOnData1),
34         makeFakeQuantize(input2, precision, fqOnData2),
35         false,
36         false);
37     matMul->set_friendly_name("matMul");
38
39     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
40
41     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
42         ngraph::ResultVector{ result },
43         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
44         "MatMulTransformation");
45     return function;
46 }
47
48 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
49     const ngraph::Shape& inputShape1,
50     const ngraph::element::Type precisionBeforeDequantization1,
51     const DequantizationOperations& dequantization1,
52     const ngraph::Shape& inputShape2,
53     const ngraph::element::Type precisionBeforeDequantization2,
54     const DequantizationOperations& dequantization2) {
55     if (!dequantization1.convert.empty() && (precisionBeforeDequantization1 == dequantization1.convert.outPrecision)) {
56         THROW_IE_EXCEPTION << "unexpected input arguments for branch 1";
57     }
58
59     if (!dequantization2.convert.empty() && (precisionBeforeDequantization2 == dequantization2.convert.outPrecision)) {
60         THROW_IE_EXCEPTION << "unexpected input arguments for branch 2";
61     }
62
63     const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization1, inputShape1);
64     input1->set_friendly_name("input1");
65
66     const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization2, inputShape2);
67     input2->set_friendly_name("input2");
68
69     const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
70         makeDequantization(input1, dequantization1),
71         makeDequantization(input2, dequantization2),
72         false,
73         false);
74     matMul->set_friendly_name("matMul");
75
76     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
77
78     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
79         ngraph::ResultVector{ result },
80         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
81         "MatMulTransformation");
82     return function;
83 }
84
85 std::shared_ptr<ngraph::Function> getOriginalWithConstant2(
86     const ngraph::element::Type precision) {
87     return nullptr;
88 }
89
90 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
91     const ngraph::element::Type precision,
92     const ngraph::Shape& inputShape,
93     const ngraph::element::Type precisionBeforeDequantization,
94     const DequantizationOperations& dequantizationOperations,
95     const ngraph::Shape& weightsConstShape,
96     const std::vector<float>& weightsConstValues,
97     const FakeQuantizeOnWeights& fqOnWeights) {
98     const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
99         precisionBeforeDequantization,
100         inputShape);
101     input->set_friendly_name("input1");
102
103     auto lastDequantization = makeDequantization(input, dequantizationOperations);
104
105     const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
106         precision,
107         weightsConstShape,
108         weightsConstValues);
109
110     auto fakeQuantize = makeFakeQuantize(weightsConst, precision, fqOnWeights);
111
112     const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
113         lastDequantization,
114         fakeQuantize,
115         false,
116         false);
117     matMul->set_friendly_name("matMul");
118
119     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
120
121     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
122         ngraph::ResultVector{ result },
123         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
124         "MatMulTransformation");
125     return function;
126 }
127
128 std::shared_ptr<ngraph::Function> MatMulFunction::getReference(
129     const ngraph::element::Type precision,
130     const ngraph::Shape& inputShape1,
131     const ngraph::element::Type precisionBeforeDequantization1,
132     const DequantizationOperations& dequantization1,
133     const ngraph::Shape& inputShape2,
134     const ngraph::element::Type precisionBeforeDequantization2,
135     const DequantizationOperations& dequantization2,
136     const DequantizationOperations& resultDequantizationOperations) {
137     if (!dequantization1.convert.empty() && (precisionBeforeDequantization1 == dequantization1.convert.outPrecision)) {
138         THROW_IE_EXCEPTION << "unexpected input arguments for branch 1";
139     }
140
141     if (!dequantization2.convert.empty() && (precisionBeforeDequantization2 == dequantization2.convert.outPrecision)) {
142         THROW_IE_EXCEPTION << "unexpected input arguments for branch 2";
143     }
144
145     const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization1, inputShape1);
146     input1->set_friendly_name("input1");
147
148     const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization2, inputShape2);
149     input2->set_friendly_name("input2");
150
151     auto dequantization1Op = makeDequantization(input1, dequantization1);
152     auto dequantization2Op = makeDequantization(input2, dequantization2);
153
154     std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::MatMul>>(
155         std::vector<element::Type>{ element::f32, element::f32 }, std::vector<element::Type>{},
156         ngraph::op::TemporaryReplaceOutputType(dequantization1Op, element::f32).get(),
157         ngraph::op::TemporaryReplaceOutputType(dequantization2Op, element::f32).get(),
158         false,
159         false);
160
161     matMul->set_friendly_name("matMul");
162     ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(matMul, precision);
163     auto dequantizationAfter = makeDequantization(matMul, resultDequantizationOperations);
164     dequantizationAfter->set_friendly_name("matMul");
165
166     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(dequantizationAfter);
167
168     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
169         ngraph::ResultVector{ result },
170         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
171         "MatMulTransformation");
172     return function;
173 }
174
175 std::shared_ptr<ngraph::Function> MatMulFunction::getReference(
176     const ngraph::element::Type precision,
177     const ngraph::Shape& inputShape,
178     const ngraph::element::Type precisionBeforeDequantization,
179     const DequantizationOperations& dequantization,
180     const ngraph::element::Type weightsConstPrecision,
181     const ngraph::Shape& weightsConstShape,
182     const std::vector<float>& weightsConstValues,
183     const DequantizationOperations& resultDequantization) {
184     const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
185         precisionBeforeDequantization,
186         inputShape);
187     input->set_friendly_name("input1");
188
189     const std::shared_ptr<ngraph::Node> lastDequantizationBefore = makeDequantization(input, dequantization);
190
191     const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
192         weightsConstPrecision,
193         weightsConstShape,
194         weightsConstValues);
195
196     const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::MatMul>>(
197         std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
198         ngraph::op::TemporaryReplaceOutputType(lastDequantizationBefore, element::f32).get(),
199         ngraph::op::TemporaryReplaceOutputType(weightsConst, element::f32).get(),
200         false,
201         false);
202     matMul->set_friendly_name("matMul");
203     ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(matMul, precision);
204
205     const std::shared_ptr<ngraph::Node> lastDequantizationAfter = makeDequantization(matMul, resultDequantization);
206     lastDequantizationAfter->set_friendly_name("matMul");
207
208     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(lastDequantizationAfter);
209
210     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
211         ngraph::ResultVector{ result },
212         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
213         "MatMulTransformation");
214     return function;
215 }
216
217 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
218     const ngraph::element::Type precision,
219     const ngraph::Shape& inputShape,
220     const FakeQuantizeOnData& fqOnData,
221     const ngraph::Shape& weightsConstShape,
222     const std::vector<float>& weightsConstValues,
223     const FakeQuantizeOnWeights& fqOnWeights) {
224     const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
225         precision,
226         inputShape);
227     input->set_friendly_name("input1");
228
229     auto lastDequantization = makeFakeQuantize(input, precision, fqOnData);
230
231     const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
232         precision,
233         weightsConstShape,
234         weightsConstValues);
235
236     auto fakeQuantize = makeFakeQuantize(weightsConst, precision, fqOnWeights);
237
238     const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
239         lastDequantization,
240         fakeQuantize,
241         false,
242         false);
243     matMul->set_friendly_name("matMul");
244
245     std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
246
247     std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
248         ngraph::ResultVector{ result },
249         std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
250         "MatMulTransformation");
251     return function;
252 }
253
254 }  // namespace subgraph
255 }  // namespace builder
256 }  // namespace ngraph