1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ngraph_functions/low_precision_transformations/mat_mul_function.hpp"
10 #include <ngraph/opsets/opset1.hpp>
11 #include "ngraph_ops/type_relaxed.hpp"
12 #include "ngraph_functions/subgraph_builders.hpp"
13 #include "low_precision/network_helper.hpp"
14 #include "ngraph_functions/low_precision_transformations/common/builders.hpp"
20 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
21 const ngraph::element::Type precision,
22 const ngraph::Shape& inputShape1,
23 const FakeQuantizeOnData& fqOnData1,
24 const ngraph::Shape& inputShape2,
25 const FakeQuantizeOnData& fqOnData2) {
26 const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape1);
27 input1->set_friendly_name("input1");
29 const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape2);
30 input2->set_friendly_name("input2");
32 const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
33 makeFakeQuantize(input1, precision, fqOnData1),
34 makeFakeQuantize(input2, precision, fqOnData2),
37 matMul->set_friendly_name("matMul");
39 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
41 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
42 ngraph::ResultVector{ result },
43 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
44 "MatMulTransformation");
48 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
49 const ngraph::Shape& inputShape1,
50 const ngraph::element::Type precisionBeforeDequantization1,
51 const DequantizationOperations& dequantization1,
52 const ngraph::Shape& inputShape2,
53 const ngraph::element::Type precisionBeforeDequantization2,
54 const DequantizationOperations& dequantization2) {
55 if (!dequantization1.convert.empty() && (precisionBeforeDequantization1 == dequantization1.convert.outPrecision)) {
56 THROW_IE_EXCEPTION << "unexpected input arguments for branch 1";
59 if (!dequantization2.convert.empty() && (precisionBeforeDequantization2 == dequantization2.convert.outPrecision)) {
60 THROW_IE_EXCEPTION << "unexpected input arguments for branch 2";
63 const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization1, inputShape1);
64 input1->set_friendly_name("input1");
66 const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization2, inputShape2);
67 input2->set_friendly_name("input2");
69 const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
70 makeDequantization(input1, dequantization1),
71 makeDequantization(input2, dequantization2),
74 matMul->set_friendly_name("matMul");
76 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
78 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
79 ngraph::ResultVector{ result },
80 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
81 "MatMulTransformation");
// NOTE(review): namespace-scope free function (it has no MatMulFunction:: qualifier)
// taking only an element type. Its body is not visible in this listing, and
// nothing visible in this file calls it. Confirm whether it is still needed,
// or remove it.
std::shared_ptr<ngraph::Function> getOriginalWithConstant2(
    const ngraph::element::Type precision) {
90 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
91 const ngraph::element::Type precision,
92 const ngraph::Shape& inputShape,
93 const ngraph::element::Type precisionBeforeDequantization,
94 const DequantizationOperations& dequantizationOperations,
95 const ngraph::Shape& weightsConstShape,
96 const std::vector<float>& weightsConstValues,
97 const FakeQuantizeOnWeights& fqOnWeights) {
98 const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
99 precisionBeforeDequantization,
101 input->set_friendly_name("input1");
103 auto lastDequantization = makeDequantization(input, dequantizationOperations);
105 const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
110 auto fakeQuantize = makeFakeQuantize(weightsConst, precision, fqOnWeights);
112 const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
117 matMul->set_friendly_name("matMul");
119 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
121 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
122 ngraph::ResultVector{ result },
123 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
124 "MatMulTransformation");
128 std::shared_ptr<ngraph::Function> MatMulFunction::getReference(
129 const ngraph::element::Type precision,
130 const ngraph::Shape& inputShape1,
131 const ngraph::element::Type precisionBeforeDequantization1,
132 const DequantizationOperations& dequantization1,
133 const ngraph::Shape& inputShape2,
134 const ngraph::element::Type precisionBeforeDequantization2,
135 const DequantizationOperations& dequantization2,
136 const DequantizationOperations& resultDequantizationOperations) {
137 if (!dequantization1.convert.empty() && (precisionBeforeDequantization1 == dequantization1.convert.outPrecision)) {
138 THROW_IE_EXCEPTION << "unexpected input arguments for branch 1";
141 if (!dequantization2.convert.empty() && (precisionBeforeDequantization2 == dequantization2.convert.outPrecision)) {
142 THROW_IE_EXCEPTION << "unexpected input arguments for branch 2";
145 const std::shared_ptr<ngraph::opset1::Parameter> input1 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization1, inputShape1);
146 input1->set_friendly_name("input1");
148 const std::shared_ptr<ngraph::opset1::Parameter> input2 = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization2, inputShape2);
149 input2->set_friendly_name("input2");
151 auto dequantization1Op = makeDequantization(input1, dequantization1);
152 auto dequantization2Op = makeDequantization(input2, dequantization2);
154 std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::MatMul>>(
155 std::vector<element::Type>{ element::f32, element::f32 }, std::vector<element::Type>{},
156 ngraph::op::TemporaryReplaceOutputType(dequantization1Op, element::f32).get(),
157 ngraph::op::TemporaryReplaceOutputType(dequantization2Op, element::f32).get(),
161 matMul->set_friendly_name("matMul");
162 ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(matMul, precision);
163 auto dequantizationAfter = makeDequantization(matMul, resultDequantizationOperations);
164 dequantizationAfter->set_friendly_name("matMul");
166 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(dequantizationAfter);
168 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
169 ngraph::ResultVector{ result },
170 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input1, input2 },
171 "MatMulTransformation");
175 std::shared_ptr<ngraph::Function> MatMulFunction::getReference(
176 const ngraph::element::Type precision,
177 const ngraph::Shape& inputShape,
178 const ngraph::element::Type precisionBeforeDequantization,
179 const DequantizationOperations& dequantization,
180 const ngraph::element::Type weightsConstPrecision,
181 const ngraph::Shape& weightsConstShape,
182 const std::vector<float>& weightsConstValues,
183 const DequantizationOperations& resultDequantization) {
184 const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
185 precisionBeforeDequantization,
187 input->set_friendly_name("input1");
189 const std::shared_ptr<ngraph::Node> lastDequantizationBefore = makeDequantization(input, dequantization);
191 const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
192 weightsConstPrecision,
196 const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::MatMul>>(
197 std::vector<ngraph::element::Type>{ element::f32, element::f32 }, std::vector<ngraph::element::Type>{},
198 ngraph::op::TemporaryReplaceOutputType(lastDequantizationBefore, element::f32).get(),
199 ngraph::op::TemporaryReplaceOutputType(weightsConst, element::f32).get(),
202 matMul->set_friendly_name("matMul");
203 ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(matMul, precision);
205 const std::shared_ptr<ngraph::Node> lastDequantizationAfter = makeDequantization(matMul, resultDequantization);
206 lastDequantizationAfter->set_friendly_name("matMul");
208 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(lastDequantizationAfter);
210 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
211 ngraph::ResultVector{ result },
212 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
213 "MatMulTransformation");
217 std::shared_ptr<ngraph::Function> MatMulFunction::getOriginal(
218 const ngraph::element::Type precision,
219 const ngraph::Shape& inputShape,
220 const FakeQuantizeOnData& fqOnData,
221 const ngraph::Shape& weightsConstShape,
222 const std::vector<float>& weightsConstValues,
223 const FakeQuantizeOnWeights& fqOnWeights) {
224 const std::shared_ptr<ngraph::opset1::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
227 input->set_friendly_name("input1");
229 auto lastDequantization = makeFakeQuantize(input, precision, fqOnData);
231 const std::shared_ptr<ngraph::opset1::Constant> weightsConst = std::make_shared<ngraph::opset1::Constant>(
236 auto fakeQuantize = makeFakeQuantize(weightsConst, precision, fqOnWeights);
238 const std::shared_ptr<ngraph::opset1::MatMul> matMul = std::make_shared<ngraph::opset1::MatMul>(
243 matMul->set_friendly_name("matMul");
245 std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(matMul);
247 std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
248 ngraph::ResultVector{ result },
249 std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
250 "MatMulTransformation");
254 } // namespace subgraph
255 } // namespace builder
256 } // namespace ngraph