From 2a96917e2acd7f3a5521cfd5117fec44d5add6f9 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 28 Jul 2020 14:01:13 +0200 Subject: [PATCH] Treat 1d single-element tensors as scalars. (#1498) --- ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp b/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp index 75a0e84..7468487 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp @@ -75,10 +75,14 @@ namespace ngraph const std::shared_ptr input, const std::set allowed_types) { - const auto validated_input_rank = input->get_output_partial_shape(0).rank(); + const auto validated_input_shape = input->get_output_partial_shape(0); + const auto validated_input_rank = validated_input_shape.rank(); - NGRAPH_CHECK( - validated_input_rank.same_scheme({0}), input_name, " needs to be a scalar."); + NGRAPH_CHECK(validated_input_rank.same_scheme({0}) || + (validated_input_rank.same_scheme({1}) && + validated_input_shape[0].get_length() == 1), + input_name, + " needs to be a scalar or 1D, single-element tensor."); if (!allowed_types.empty()) { -- 2.7.4