#include "common_test_utils/ngraph_test_utils.hpp"
#include "simple_low_precision_transformer.hpp"
#include "ngraph_functions/low_precision_transformations/max_pool_function.hpp"
+#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
+
using namespace testing;
using namespace ngraph::pass;
class MaxPoolTransformationTestValues {
public:
- low_precision::LayerTransformation::Params params;
- std::vector<float> subtractValues;
- std::vector<float> mutliplyValues;
+ class Actual {
+ public:
+ ngraph::element::Type precisionBeforeDequantization;
+ ngraph::builder::subgraph::DequantizationOperations dequantization1;
+ ngraph::builder::subgraph::DequantizationOperations dequantization2;
+ };
+
+ class Expected {
+ public:
+ ngraph::element::Type precisionBeforeDequantization;
+ ngraph::builder::subgraph::DequantizationOperations dequantization1;
+ ngraph::builder::subgraph::DequantizationOperations dequantization2;
+ };
+
+ ngraph::pass::low_precision::LayerTransformation::Params params;
+ Actual actual;
+ Expected expected;
};
typedef std::tuple<
- ngraph::element::Type,
ngraph::Shape,
MaxPoolTransformationTestValues> MaxPoolTransformationParams;
class MaxPoolTransformation : public LayerTransformation, public testing::WithParamInterface<MaxPoolTransformationParams> {
public:
void SetUp() override {
- const ngraph::element::Type precision = std::get<0>(GetParam());
- const ngraph::Shape shape = std::get<1>(GetParam());
- const MaxPoolTransformationTestValues testValues = std::get<2>(GetParam());
+ const ngraph::Shape shape = std::get<0>(GetParam());
+ const MaxPoolTransformationTestValues testValues = std::get<1>(GetParam());
- actualFunction = ngraph::builder::subgraph::MaxPoolFunction::getOriginal(
- precision,
+ actualFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
shape,
- {
- testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
- testValues.subtractValues,
- testValues.mutliplyValues
- });
+ testValues.actual.precisionBeforeDequantization,
+ testValues.actual.dequantization1,
+ testValues.actual.dequantization2);
SimpleLowPrecisionTransformer transform;
transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
transform.transform(actualFunction);
- referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::getReference(
- precision,
+ referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
shape,
- {
- testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
- testValues.subtractValues,
- testValues.mutliplyValues
- });
+ testValues.expected.precisionBeforeDequantization,
+ testValues.expected.dequantization1,
+ testValues.expected.dequantization2);
}
static std::string getTestCaseName(testing::TestParamInfo<MaxPoolTransformationParams> obj) {
- const ngraph::element::Type precision = std::get<0>(obj.param);
- const ngraph::Shape shape = std::get<1>(obj.param);
- const MaxPoolTransformationTestValues testValues = std::get<2>(obj.param);
+ const ngraph::Shape shape = std::get<0>(obj.param);
+ const MaxPoolTransformationTestValues testValues = std::get<1>(obj.param);
std::ostringstream result;
result <<
- LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
- testValues.subtractValues.size() << "_" <<
- testValues.mutliplyValues.size() << "_";
+ LayerTransformation::getTestCaseNameByParams(testValues.actual.precisionBeforeDequantization, shape, testValues.params) << "_" <<
+ testValues.actual.dequantization1 << "_" <<
+ testValues.actual.dequantization2 << "_" <<
+ testValues.expected.dequantization1 << "_" <<
+ testValues.expected.dequantization2 << "_";
return result.str();
}
};
TEST_P(MaxPoolTransformation, CompareFunctions) {
actualFunction->validate_nodes_and_infer_types();
- auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
+ auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
ASSERT_TRUE(res.first) << res.second;
}
-const std::vector<ngraph::element::Type> precisions = {
- ngraph::element::f32,
- // ngraph::element::f16
-};
-
const std::vector<ngraph::Shape> shapes = {
- { 1, 32, 72, 48 }
+ { 1, 32, 72, 48 },
+ { 4, 32, 72, 48 }
};
const std::vector<MaxPoolTransformationTestValues> testValues = {
- { LayerTransformation::createParamsU8I8(), { 128 }, { 0.02f } },
- { LayerTransformation::createParamsU8I8(), {}, { 0.02f } },
- { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), { 128 }, { 0.02f } },
- { LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), {}, { 0.02f } },
- { LayerTransformation::createParamsI8I8(), { 128 }, { 0.02f } },
+ // Multiply
+ {
+ LayerTransformation::createParamsU8I8(),
+ {
+ ngraph::element::u8,
+ { {}, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }},
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ { ngraph::element::f32, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }}
+ }
+ },
+ // Subtract + Multiply
+ {
+ LayerTransformation::createParamsU8I8(),
+ {
+ ngraph::element::u8,
+ {
+ {},
+ { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
+ { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
+ },
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ {
+ ngraph::element::f32,
+ { {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
+ { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
+ }
+ }
+ },
+ // Convert + Subtract + Multiply
+ {
+ LayerTransformation::createParamsU8I8(),
+ {
+ ngraph::element::u8,
+ { ngraph::element::f32, { 128 }, { 0.02f }},
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ { ngraph::element::f32, { 128 }, { 0.02f }}
+ }
+ },
+ // Convert + Subtract + Multiply
+ {
+ LayerTransformation::createParamsU8I8(),
+ {
+ ngraph::element::u8,
+ { ngraph::element::f32, {}, { 0.02f }},
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ { ngraph::element::f32, {}, { 0.02f }}
+ }
+ },
+ // Convert + Subtract + Multiply
+ {
+ LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
+ {
+ ngraph::element::u8,
+ { ngraph::element::f32, { 128 }, { 0.02f }},
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ { ngraph::element::f32, { 128 }, { 0.02f }}
+ }
+ },
+ // Convert + Subtract + Multiply
+ {
+ LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
+ {
+ ngraph::element::u8,
+ { ngraph::element::f32, {}, { 0.02f }},
+ {}
+ },
+ {
+ ngraph::element::u8,
+ {},
+ { ngraph::element::f32, {}, { 0.02f }}
+ }
+ }
};
INSTANTIATE_TEST_CASE_P(
LPT,
MaxPoolTransformation,
::testing::Combine(
- ::testing::ValuesIn(precisions),
::testing::ValuesIn(shapes),
::testing::ValuesIn(testValues)),
MaxPoolTransformation::getTestCaseName);