From 7cda3bb275d54f466475fc16cec4345b850eab76 Mon Sep 17 00:00:00 2001 From: Liubov Batanina Date: Mon, 29 Jun 2020 23:07:08 +0300 Subject: [PATCH] Fixed Softmax reference (#1148) --- .../single_layer_tests/softmax.cpp | 32 ++++++++++++++++++++-- ngraph/test/runtime/opset0_downgrade.cpp | 12 -------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/softmax.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/softmax.cpp index 69af36d..bfb446c 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/softmax.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/softmax.cpp @@ -21,13 +21,15 @@ const std::vector inputLayouts2D = { const std::vector inputShapes2D = { InferenceEngine::SizeVector {1, 100}, + InferenceEngine::SizeVector {100, 1}, + InferenceEngine::SizeVector {10, 10}, }; const std::vector axis2D = { - 1 + 0, 1 }; -const auto params = testing::Combine( +const auto params2D = testing::Combine( testing::ValuesIn(netPrecisions), testing::ValuesIn(inputLayouts2D), testing::ValuesIn(inputShapes2D), @@ -39,7 +41,31 @@ const auto params = testing::Combine( INSTANTIATE_TEST_CASE_P( SoftMax2D, SoftMaxLayerTest, - params, + params2D, + SoftMaxLayerTest::getTestCaseName +); + +const std::vector inputShapes4D = { + InferenceEngine::SizeVector {1, 100, 1, 1}, + InferenceEngine::SizeVector {1, 3, 4, 3}, + InferenceEngine::SizeVector {2, 3, 4, 5}, +}; + +const std::vector axis4D = {0, 1, 2, 3}; + +const auto params4D = testing::Combine( + testing::ValuesIn(netPrecisions), + testing::Values(InferenceEngine::Layout::NCHW), + testing::ValuesIn(inputShapes4D), + testing::ValuesIn(axis4D), + testing::Values(CommonTestUtils::DEVICE_CPU), + testing::Values(std::map()) +); + +INSTANTIATE_TEST_CASE_P( + SoftMax4D, + SoftMaxLayerTest, + params4D, SoftMaxLayerTest::getTestCaseName ); diff --git a/ngraph/test/runtime/opset0_downgrade.cpp b/ngraph/test/runtime/opset0_downgrade.cpp index 679b7ad..31400a0 100644 --- a/ngraph/test/runtime/opset0_downgrade.cpp +++ b/ngraph/test/runtime/opset0_downgrade.cpp @@ -688,18 +688,6 @@ namespace return replacement_node; } - shared_ptr op_cast(shared_ptr node) - { - const auto axis = node->get_axis(); - const auto data = node->input(0); - const auto data_shape = data.get_shape(); - std::vector axes(data_shape.size() - axis); - std::iota(std::begin(axes), std::end(axes), axis); - const auto replacement_node = make_shared(node->input_value(0), axes); - replace_node(node, replacement_node); - return replacement_node; - } - shared_ptr op_cast(shared_ptr node) { const auto num_splits = node->get_num_splits(); -- 2.7.4