1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ngraph_functions/low_precision_transformations/mat_mul_function.hpp"
10 #include <ngraph/opsets/opset1.hpp>
11 #include "ngraph_ops/type_relaxed.hpp"
12 #include "ngraph_functions/subgraph_builders.hpp"
13 #include "low_precision/network_helper.hpp"
14 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
20 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
21 const ngraph::element::Type precision,
22 const ngraph::Shape& inputShape1,
23 const FakeQuantizeOnData& fqOnData1,
24 const ngraph::Shape& inputShape2,
25 const FakeQuantizeOnData& fqOnData2) {
26 const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape1);
27 input1->set_friendly_name("input1");
29 const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape2);
30 input2->set_friendly_name("input2");
32 const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
33 makeFakeQuantize(input1, precision, fqOnData1),
34 makeFakeQuantize(input2, precision, fqOnData2),
37 matMul->set_friendly_name("matMul");
39 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
41 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
42 ngraph::ResultVector{ result },
43 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
44 "MatMulTransformation");
48 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
49 const ngraph::Shape& inputShape1,
50 const ngraph::element::Type precisionBeforeDequantization1,
51 const DequantizationOperations& dequantization1,
52 const ngraph::Shape& inputShape2,
53 const ngraph::element::Type precisionBeforeDequantization2,
54 const DequantizationOperations& dequantization2) {
55 if (!dequantization1.convert.empty() && (precisionBeforeDequantization1 == dequantization1.convert.outPrecision)) {
56 THROW_IE_EXCEPTION << "unexpected input arguments for branch 1";
59 if (!dequantization2.convert.empty() && (precisionBeforeDequantization2 == dequantization2.convert.outPrecision)) {
60 THROW_IE_EXCEPTION << "unexpected input arguments for branch 2";
63 const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization1, inputShape1);
64 input1->set_friendly_name("input1");
66 const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization2, inputShape2);
67 input2->set_friendly_name("input2");
69 const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
70 makeDequantization(input1, dequantization1),
71 makeDequantization(input2, dequantization2),
74 matMul->set_friendly_name("matMul");
76 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
78 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
79 ngraph::ResultVector{ result },
80 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
81 "MatMulTransformation");
85 std::shared_ptr<ngraph::Function> getOriginalWithConstant2(
86 const ngraph::element::Type precision) {
90 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
91 const ngraph::element::Type precision,
92 const ngraph::Shape& inputShape,
93 const ngraph::element::Type precisionBeforeDequantization,
94 const DequantizationOperations& dequantizationOperations,
95 const ngraph::Shape& weightsConstShape,
96 const std::vector<float>& weightsConstValues,
97 const FakeQuantizeOnWeights& fqOnWeights) {
98 const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
99 precisionBeforeDequantization,
101 input->set_friendly_name("input1");
103 auto lastDequantization = makeDequantization(input, dequantizationOperations);
105 const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
110 auto fakeQuantize = makeFakeQuantize(weightsConst, precision, fqOnWeights);
112 const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
117 matMul->set_friendly_name("matMul");
118 auto& rtInfo = matMul->get_rt_info();
119 rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("matMul");
121 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
123 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
124 ngraph::ResultVector{ result },
125 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
126 "MatMulTransformation");
130 std::shared_ptr<ngraph::Function> MatMulFunction::getReference(
131 const ngraph::element::Type precision,
132 const ngraph::Shape& inputShape1,
133 const ngraph::element::Type precisionBeforeDequantization1,
134 const DequantizationOperations& dequantization1,
135 const ngraph::Shape& inputShape2,
136 const ngraph::element::Type precisionBeforeDequantization2,
137 const DequantizationOperations& dequantization2,
138 const DequantizationOperations& resultDequantizationOperations) {
139 if (!dequantization1.convert.empty() && (precisionBeforeDequantization1 == dequantization1.convert.outPrecision)) {
140 THROW_IE_EXCEPTION << "unexpected input arguments for branch 1";
143 if (!dequantization2.convert.empty() && (precisionBeforeDequantization2 == dequantization2.convert.outPrecision)) {
144 THROW_IE_EXCEPTION << "unexpected input arguments for branch 2";
147 const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization1, inputShape1);
148 input1->set_friendly_name("input1");
150 const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization2, inputShape2);
151 input2->set_friendly_name("input2");
153 auto dequantization1Op = makeDequantization(input1, dequantization1);
154 auto dequantization2Op = makeDequantization(input2, dequantization2);
156 std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::MatMul>>(
157 std::vector<element::Type>{ element::f32, element::f32 }, std::vector<element::Type>{},
158 ngraph::op::TemporaryReplaceOutputType(dequantization1Op, element::f32).get(),
159 ngraph::op::TemporaryReplaceOutputType(dequantization2Op, element::f32).get(),
163 matMul->set_friendly_name("matMul");
164 ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(matMul, precision);
165 auto dequantizationAfter = makeDequantization(matMul, resultDequantizationOperations);
166 dequantizationAfter->set_friendly_name("matMul");
168 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(dequantizationAfter);
170 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
171 ngraph::ResultVector{ result },
172 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
173 "MatMulTransformation");
177 std::shared_ptr<ngraph::Function> MatMulFunction::getReference(
178 const ngraph::element::Type precision,
179 const ngraph::Shape& inputShape,
180 const ngraph::element::Type precisionBeforeDequantization,
181 const DequantizationOperations& dequantization,
182 const ngraph::element::Type weightsConstPrecision,
183 const ngraph::Shape& weightsConstShape,
184 const std::vector<float>& weightsConstValues,
185 const DequantizationOperations& resultDequantization) {
186 const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
187 precisionBeforeDequantization,
189 input->set_friendly_name("input1");
191 const std::shared_ptr<ngraph::Node> lastDequantizationBefore = makeDequantization(input, dequantization);
193 const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
194 weightsConstPrecision,
198 const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::MatMul>>(
199 std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
200 ngraph::op::TemporaryReplaceOutputType(lastDequantizationBefore, element::f32).get(),
201 ngraph::op::TemporaryReplaceOutputType(weightsConst, element::f32).get(),
204 matMul->set_friendly_name("matMul");
205 auto& rtInfo = matMul->get_rt_info();
206 rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("matMul");
207 ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(matMul, precision);
209 const std::shared_ptr<ngraph::Node> lastDequantizationAfter = makeDequantization(matMul, resultDequantization);
210 lastDequantizationAfter->set_friendly_name("matMul");
212 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(lastDequantizationAfter);
214 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
215 ngraph::ResultVector{ result },
216 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
217 "MatMulTransformation");
221 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
222 const ngraph::element::Type precision,
223 const ngraph::Shape& inputShape,
224 const FakeQuantizeOnData& fqOnData,
225 const ngraph::Shape& weightsConstShape,
226 const std::vector<float>& weightsConstValues,
227 const FakeQuantizeOnWeights& fqOnWeights) {
228 const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
231 input->set_friendly_name("input1");
233 auto lastDequantization = makeFakeQuantize(input, precision, fqOnData);
235 const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
240 auto fakeQuantize = makeFakeQuantize(weightsConst, precision, fqOnWeights);
242 const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
247 matMul->set_friendly_name("matMul");
249 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
251 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
252 ngraph::ResultVector{ result },
253 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
254 "MatMulTransformation");
258 } // namespace subgraph
259 } // namespace builder
260 } // namespace ngraph