Use LogSoftmax-5 in the onnx_importer (#2602)
author Tomasz Dołbniak <tomasz.dolbniak@intel.com>
Wed, 21 Oct 2020 08:50:16 +0000 (10:50 +0200)
committer GitHub <noreply@github.com>
Wed, 21 Oct 2020 08:50:16 +0000 (10:50 +0200)
13 files changed:
ngraph/frontend/onnx_import/include/onnx_import/op/log_softmax.hpp
ngraph/frontend/onnx_import/src/op/log_softmax.cpp
ngraph/frontend/onnx_import/src/ops_bridge.cpp
ngraph/python/tests/__init__.py
ngraph/python/tests/test_onnx/test_backend.py
ngraph/python/tests/test_onnx/test_ops_unary.py
ngraph/test/models/onnx/logsoftmax13_1D.prototxt [new file with mode: 0644]
ngraph/test/models/onnx/logsoftmax13_2D.prototxt [new file with mode: 0644]
ngraph/test/models/onnx/logsoftmax_0D.prototxt [new file with mode: 0644]
ngraph/test/models/onnx/logsoftmax_1D.prototxt [new file with mode: 0644]
ngraph/test/onnx/onnx_import.in.cpp
ngraph/test/runtime/ie/unit_test.manifest
ngraph/test/runtime/interpreter/unit_test.manifest

diff --git a/ngraph/frontend/onnx_import/include/onnx_import/op/log_softmax.hpp b/ngraph/frontend/onnx_import/include/onnx_import/op/log_softmax.hpp
index 76c251d..aaf66d5 100644
@@ -31,6 +31,12 @@ namespace ngraph
 
             } // namespace set_1
 
+            namespace set_13
+            {
+                OutputVector log_softmax(const Node& node);
+
+            } // namespace set_13
+
         } // namespace op
 
     } // namespace onnx_import
diff --git a/ngraph/frontend/onnx_import/src/op/log_softmax.cpp b/ngraph/frontend/onnx_import/src/op/log_softmax.cpp
index d7aec43..c19ca2b 100644
@@ -17,6 +17,7 @@
 #include <memory>
 
 #include "log_softmax.hpp"
+#include "ngraph/builder/reshape.hpp"
 #include "ngraph/validation_util.hpp"
 #include "onnx_import/default_opset.hpp"
 
@@ -24,25 +25,82 @@ namespace ngraph
 {
     namespace onnx_import
     {
-        namespace op
+        namespace detail
         {
-            namespace set_1
+            std::shared_ptr<ngraph::Node> onnx_logsoftmax(const Output<ngraph::Node> data,
+                                                          const int64_t axis)
             {
-                OutputVector log_softmax(const Node& node)
+                const auto coerced_data = ngraph::builder::opset1::flatten(data, axis);
+
+                const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1});
+                const auto max =
+                    std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1, true);
+
+                const auto data_minus_max =
+                    std::make_shared<default_opset::Subtract>(coerced_data, max);
+
+                const auto result = std::make_shared<default_opset::LogSoftmax>(data_minus_max, 1);
+                if (data.get_partial_shape().is_static())
+                {
+                    return ngraph::builder::opset1::reshape(result, data.get_shape());
+                }
+                else
                 {
-                    OutputVector inputs{node.get_ng_inputs()};
-                    const auto data = inputs.at(0);
-                    const auto data_rank = data.get_partial_shape().rank();
+                    const auto data_shape = std::make_shared<default_opset::ShapeOf>(data);
+                    return std::make_shared<default_opset::Reshape>(result, data_shape, false);
+                }
+            }
+
+            OutputVector log_softmax(const Node& node, const int64_t DEFAULT_AXIS)
+            {
+                OutputVector inputs{node.get_ng_inputs()};
+                const auto data = inputs.at(0);
+                const auto data_rank = data.get_partial_shape().rank();
+
+                NGRAPH_CHECK(data_rank.is_static(),
+                             "ONNX LogSoftmax data rank needs to be known (static)");
 
-                    const auto axis = node.get_attribute_value<int64_t>("axis", 1);
+                const auto axis = node.get_attribute_value<int64_t>("axis", DEFAULT_AXIS);
+
+                std::shared_ptr<ngraph::Node> result;
+                switch (data_rank.get_length())
+                {
+                case 0:
+                {
+                    // log-softmax of a scalar: log(softmax(x)) == log(1) == 0
+                    result = default_opset::Constant::create(data.get_element_type(), Shape{}, {0});
+                    break;
+                }
+                case 1:
+                {
+                    // check that the axis is one of the values allowed for 1D input (-1 or 0)
+                    ngraph::normalize_axis(node.get_description(), axis, data_rank);
+                    result = std::make_shared<default_opset::LogSoftmax>(data, 0);
+                    break;
+                }
+                default:
+                {
                     const auto normalized_axis =
                         ngraph::normalize_axis(node.get_description(), axis, data_rank);
 
-                    const auto softmax =
-                        std::make_shared<default_opset::Softmax>(data, normalized_axis);
-                    return {std::make_shared<default_opset::Log>(softmax)};
+                    result = onnx_logsoftmax(data, normalized_axis);
+                    break;
                 }
+                }
+
+                return {result};
+            }
+        } // namespace detail
 
+        namespace op
+        {
+            namespace set_1
+            {
+                OutputVector log_softmax(const Node& node) { return detail::log_softmax(node, 1); }
+            } // namespace set_1
+
+            namespace set_13
+            {
+                OutputVector log_softmax(const Node& node) { return detail::log_softmax(node, -1); }
-            } // namespace set_1
+            } // namespace set_13
 
         } // namespace op
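
The decomposition built above can be mirrored in plain NumPy to see what the
importer emits: flatten the input to 2D at the normalized axis, subtract the
per-row maximum for numerical stability, take log-softmax along the second
dimension, and restore the original shape. A minimal sketch (the helper name
onnx_logsoftmax_reference is ours, not part of the patch):

    import numpy as np

    def onnx_logsoftmax_reference(data, axis):
        # Coerce to 2D: dims [0, axis) become rows, the rest become columns
        # (ngraph::builder::opset1::flatten in the patch).
        shape = data.shape
        coerced = data.reshape(int(np.prod(shape[:axis])), -1)
        # ReduceMax + Subtract: shift each row so exp() cannot overflow;
        # log-softmax is invariant to a per-row shift.
        stable = coerced - coerced.max(axis=1, keepdims=True)
        # default_opset::LogSoftmax on axis 1: x - log(sum(exp(x)))
        result = stable - np.log(np.exp(stable).sum(axis=1, keepdims=True))
        # Reshape back to the original (static) shape.
        return result.reshape(shape)
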
diff --git a/ngraph/frontend/onnx_import/src/ops_bridge.cpp b/ngraph/frontend/onnx_import/src/ops_bridge.cpp
index 2e896cd..fabfe58 100644
@@ -360,6 +360,7 @@ namespace ngraph
             REGISTER_OPERATOR("Less", 1, less);
             REGISTER_OPERATOR("Log", 1, log);
             REGISTER_OPERATOR("LogSoftmax", 1, log_softmax);
+            REGISTER_OPERATOR("LogSoftmax", 13, log_softmax);
             // REGISTER_OPERATOR("Loop", 1, loop); // Loop operator disabled for the 2021.1 release
             REGISTER_OPERATOR("LpNormalization", 1, lp_norm);
             REGISTER_OPERATOR("LRN", 1, lrn);
diff --git a/ngraph/python/tests/__init__.py b/ngraph/python/tests/__init__.py
index c2f10c8..4640c8e 100644
@@ -91,7 +91,6 @@ xfail_issue_35927 = xfail_test(reason="RuntimeError: B has zero dimension that i
 xfail_issue_35929 = xfail_test(reason="RuntimeError: Incorrect precision f64!")
 xfail_issue_35930 = xfail_test(reason="onnx.onnx_cpp2py_export.checker.ValidationError: "
                                       "Required attribute 'to' is missing.")
-xfail_issue_35932 = xfail_test(reason="Assertion error - logsoftmax results mismatch")
 xfail_issue_36437 = xfail_test(reason="RuntimeError: Cannot find blob with name: <value>")
 xfail_issue_36476 = xfail_test(reason="RuntimeError: [NOT_IMPLEMENTED] Input image format U32 is "
                                "not supported yet...")
diff --git a/ngraph/python/tests/test_onnx/test_backend.py b/ngraph/python/tests/test_onnx/test_backend.py
index 3fdc27a..bdd021e 100644
@@ -344,10 +344,7 @@ tests_expected_to_fail = [
     (xfail_issue_38091,
         "OnnxBackendNodeModelTest.test_round_cpu",
         "OnnxBackendNodeModelTest.test_mvn_cpu",
-        "OnnxBackendNodeModelTest.test_elu_example_cpu",
-        "OnnxBackendNodeModelTest.test_logsoftmax_axis_0_cpu",
-        "OnnxBackendNodeModelTest.test_logsoftmax_axis_1_cpu",
-        "OnnxBackendNodeModelTest.test_logsoftmax_default_axis_cpu"),
+        "OnnxBackendNodeModelTest.test_elu_example_cpu"),
     (xfail_issue_35929,
         "OnnxBackendPyTorchOperatorModelTest.test_operator_add_size1_broadcast_cpu",
         "OnnxBackendPyTorchOperatorModelTest.test_operator_add_size1_singleton_broadcast_cpu",
diff --git a/ngraph/python/tests/test_onnx/test_ops_unary.py b/ngraph/python/tests/test_onnx/test_ops_unary.py
index 77169e0..f9e5693 100644
@@ -24,8 +24,7 @@ from tests.runtime import get_runtime
 from tests.test_onnx.utils import get_node_model, import_onnx_model, run_model, run_node
 from tests import (xfail_issue_35929,
                    xfail_issue_34323,
-                   xfail_issue_35930,
-                   xfail_issue_35932)
+                   xfail_issue_35930)
 
 
 @pytest.mark.parametrize(
@@ -285,7 +284,6 @@ def test_softmax():
         ng_results = run_node(node, [data])
 
 
-@xfail_issue_35932
 def test_logsoftmax():
     def logsoftmax_2d(x):
         max_x = np.max(x, axis=1).reshape((-1, 1))
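
The hunk above ends inside the logsoftmax_2d reference; a standard numerically
stable completion of that helper (a sketch, not necessarily the exact test
code) looks like this:

    import numpy as np

    def logsoftmax_2d(x):
        # subtract the row max, then apply log(softmax(x)) in one pass
        max_x = np.max(x, axis=1).reshape((-1, 1))
        exp_x = np.exp(x - max_x)
        return x - max_x - np.log(np.sum(exp_x, axis=1).reshape((-1, 1)))
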
diff --git a/ngraph/test/models/onnx/logsoftmax13_1D.prototxt b/ngraph/test/models/onnx/logsoftmax13_1D.prototxt
new file mode 100644
index 0000000..b75dd3f
--- /dev/null
@@ -0,0 +1,39 @@
+ir_version: 7
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "data"
+    output: "y"
+    op_type: "LogSoftmax"
+  }
+  name: "LogSoftmax test"
+  input {
+    name: "data"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 13
+}
diff --git a/ngraph/test/models/onnx/logsoftmax13_2D.prototxt b/ngraph/test/models/onnx/logsoftmax13_2D.prototxt
new file mode 100644
index 0000000..7117079
--- /dev/null
@@ -0,0 +1,45 @@
+ir_version: 7
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "x"
+    output: "y"
+    op_type: "LogSoftmax"
+  }
+  name: "LogSoftmax test"
+  input {
+    name: "x"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 13
+}
diff --git a/ngraph/test/models/onnx/logsoftmax_0D.prototxt b/ngraph/test/models/onnx/logsoftmax_0D.prototxt
new file mode 100644
index 0000000..cd9526b
--- /dev/null
@@ -0,0 +1,37 @@
+ir_version: 7
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "data"
+    output: "y"
+    op_type: "LogSoftmax"
+  }
+  name: "LogSoftmax test"
+  input {
+    name: "data"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 1
+}
diff --git a/ngraph/test/models/onnx/logsoftmax_1D.prototxt b/ngraph/test/models/onnx/logsoftmax_1D.prototxt
new file mode 100644
index 0000000..46799d2
--- /dev/null
@@ -0,0 +1,44 @@
+ir_version: 7
+producer_name: "nGraph ONNX Importer"
+graph {
+  node {
+    input: "data"
+    output: "y"
+    op_type: "LogSoftmax"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  name: "LogSoftmax test"
+  input {
+    name: "data"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 1
+}
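
Test models like the four prototxt files above can be regenerated with the
onnx Python helpers. A sketch for logsoftmax_1D.prototxt (assuming the onnx
package is installed; the emitted ir_version tracks the installed onnx
release, so the output may not byte-match the checked-in file):

    from onnx import TensorProto, helper

    # Rebuild logsoftmax_1D.prototxt: opset 1, explicit axis=0,
    # 1D float input/output of length 3.
    node = helper.make_node("LogSoftmax", inputs=["data"], outputs=["y"], axis=0)
    graph = helper.make_graph(
        [node],
        "LogSoftmax test",
        [helper.make_tensor_value_info("data", TensorProto.FLOAT, [3])],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [3])],
    )
    model = helper.make_model(graph,
                              producer_name="nGraph ONNX Importer",
                              opset_imports=[helper.make_opsetid("", 1)])
    print(model)  # text form comparable to the file above
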
diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp
index e98ef1a..1d45332 100644
@@ -2719,3 +2719,55 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_group_norm)
     test_case.add_expected_output<float>(shape, output);
     test_case.run();
 }
+
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax_0D)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax_0D.prototxt"));
+
+    auto test_case = test::TestCase<TestEngine>(function);
+    test_case.add_input<float>({3.141592});
+    test_case.add_expected_output<float>({0.0});
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax_1D)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax_1D.prototxt"));
+    auto test_case = test::TestCase<TestEngine>(function);
+
+    test_case.add_input<float>({-1.0f, 0.0f, 1.0f});
+    test_case.add_expected_output<float>(Shape{3}, {-2.4076061, -1.407606, -0.407606});
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_1D)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax13_1D.prototxt"));
+    auto test_case = test::TestCase<TestEngine>(function);
+
+    test_case.add_input<float>({-1.0f, 0.0f, 1.0f});
+    test_case.add_expected_output<float>(Shape{3}, {-2.4076061, -1.407606, -0.407606});
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_logsoftmax13_2D)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/logsoftmax13_2D.prototxt"));
+    auto test_case = test::TestCase<TestEngine>(function);
+
+    test_case.add_input<float>({0.0f, 1.0f, 2.0f, 3.0f, 10000, 10001, 10002, 10003});
+    test_case.add_expected_output<float>(Shape{2, 4},
+                                         {-3.4401896,
+                                          -2.4401896,
+                                          -1.4401896,
+                                          -0.44018966,
+                                          -3.4401896,
+                                          -2.4401896,
+                                          -1.4401896,
+                                          -0.44018966});
+    test_case.run_with_tolerance_as_fp();
+}
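
The expected values in the 2D test follow from the max-subtraction step: after
the per-row maximum is removed, both rows reduce to {-3, -2, -1, 0}, so the
10000-range inputs produce exactly the same outputs as the 0-3 row without
overflowing exp(). A quick check of the constants:

    import numpy as np

    x = np.array([[0.0, 1.0, 2.0, 3.0],
                  [10000.0, 10001.0, 10002.0, 10003.0]], dtype=np.float32)
    shifted = x - x.max(axis=1, keepdims=True)   # both rows -> [-3, -2, -1, 0]
    out = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    print(out)  # both rows approx. [-3.44019, -2.44019, -1.44019, -0.44019]
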
diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest
index 2b8e949..fd516f5 100644
@@ -69,6 +69,7 @@ bool_const_op
 onnx_model_tile
 onnx_model_tile_static
 onnx_model_softmax_0D
+onnx_model_logsoftmax_0D
 builder_opset1_collapse_none
 
 # nGraph function's output number 0 was not found in the CNNNetwork built from it.
diff --git a/ngraph/test/runtime/interpreter/unit_test.manifest b/ngraph/test/runtime/interpreter/unit_test.manifest
index 5f7d0b5..35b3218 100644
@@ -141,3 +141,7 @@ lstm_cell_bias_peepholes_clip_input_forget
 
 # unsupported element type f16 
 INTERPRETER.ctc_greedy_decoder_f16
+
+# LogSoftmax's reference implementation doesn't handle scalar input properly
+onnx_model_logsoftmax_0D
+