[ONNX] Add type conversion for Pow op inputs (#2589)
authorMateusz Tabaka <mateusz.tabaka@intel.com>
Tue, 20 Oct 2020 09:19:03 +0000 (11:19 +0200)
committerGitHub <noreply@github.com>
Tue, 20 Oct 2020 09:19:03 +0000 (11:19 +0200)
Co-authored-by: mitruska <katarzyna.mitrus@intel.com>
ngraph/frontend/onnx_import/include/onnx_import/op/pow.hpp
ngraph/frontend/onnx_import/src/op/pow.cpp [new file with mode: 0644]
ngraph/python/tests/__init__.py
ngraph/python/tests/test_onnx/test_backend.py
ngraph/test/models/onnx/pow_float32_float32.prototxt [new file with mode: 0644]
ngraph/test/models/onnx/pow_float32_int32.prototxt [new file with mode: 0644]
ngraph/test/models/onnx/pow_int32_float32.prototxt [new file with mode: 0644]
ngraph/test/onnx/onnx_import.in.cpp

index 94e95e6..bf2ebe8 100644 (file)
 
 #pragma once
 
-#include <memory>
-
-#include "ngraph/node.hpp"
 #include "onnx_import/core/node.hpp"
-#include "onnx_import/default_opset.hpp"
 
 namespace ngraph
 {
@@ -30,11 +26,7 @@ namespace ngraph
         {
             namespace set_1
             {
-                inline OutputVector pow(const Node& node)
-                {
-                    return {std::make_shared<default_opset::Power>(node.get_ng_inputs().at(0),
-                                                                   node.get_ng_inputs().at(1))};
-                }
+                OutputVector pow(const Node& node);
 
             } // namespace set_1
 
diff --git a/ngraph/frontend/onnx_import/src/op/pow.cpp b/ngraph/frontend/onnx_import/src/op/pow.cpp
new file mode 100644 (file)
index 0000000..16c3906
--- /dev/null
@@ -0,0 +1,67 @@
+//*****************************************************************************
+// Copyright 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include <memory>
+
+#include "ngraph/node.hpp"
+#include "onnx_import/default_opset.hpp"
+#include "onnx_import/op/pow.hpp"
+
+namespace ngraph
+{
+    namespace onnx_import
+    {
+        namespace op
+        {
+            namespace set_1
+            {
+                OutputVector pow(const Node& node)
+                {
+                    auto inputs = node.get_ng_inputs();
+                    NGRAPH_CHECK(inputs.size() == 2,
+                                 "Power operation requires 2 inputs. Got: ",
+                                 inputs.size());
+
+                    auto base = inputs[0];
+                    auto exponent = inputs[1];
+                    auto base_type = inputs[0].get_element_type();
+                    auto exponent_type = inputs[1].get_element_type();
+                    if (exponent_type != base_type)
+                    {
+                        if (exponent_type.is_integral() ||
+                            (base_type.is_real() &&
+                             base_type.bitwidth() >= exponent_type.bitwidth()))
+                        {
+                            exponent =
+                                std::make_shared<default_opset::Convert>(exponent, base_type);
+                        }
+                        else
+                        {
+                            base = std::make_shared<default_opset::Convert>(base, exponent_type);
+                            auto power = std::make_shared<default_opset::Power>(base, exponent);
+                            return {std::make_shared<default_opset::Convert>(power, base_type)};
+                        }
+                    }
+                    return {std::make_shared<default_opset::Power>(base, exponent)};
+                }
+
+            } // namespace set_1
+
+        } // namespace op
+
+    } // namespace onnx_import
+
+} // namespace ngraph
index d926461..c2f10c8 100644 (file)
@@ -154,10 +154,6 @@ xfail_issue_38715 = xfail_test(reason="RuntimeError: While validating ONNX node
 xfail_issue_38717 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
                                       "GreaterOrEqual")
 xfail_issue_38719 = xfail_test(reason="nGraph does not support the following ONNX operations: GatherND")
-xfail_issue_38721 = xfail_test(reason="RuntimeError: While validating ONNX node '<Node(Pow): z>': "
-                                      "While validating node 'v1::Power Power_<number>"
-                                      "(x[0]:f32{3}, y[0]:i64{3}) -> (dynamic?)' with friendly_name "
-                                      "'Power_<number>': Argument element types are inconsistent.")
 xfail_issue_38722 = xfail_test(reason="RuntimeError: While validating ONNX nodes MatMulInteger"
                                       "and QLinearMatMul"
                                       "Input0 scale and input0 zero point shape must be same and 1")
index db0fb74..3fdc27a 100644 (file)
@@ -68,7 +68,6 @@ from tests import (BACKEND_NAME,
                    xfail_issue_33589,
                    xfail_issue_38719,
                    xfail_issue_33535,
-                   xfail_issue_38721,
                    xfail_issue_38722,
                    xfail_issue_38723,
                    xfail_issue_38724,
@@ -189,7 +188,11 @@ tests_expected_to_fail = [
         "OnnxBackendPyTorchConvertedModelTest.test_Embedding_sparse_cpu",
         "OnnxBackendNodeModelTest.test_constantofshape_int_shape_zero_cpu",
         "OnnxBackendNodeModelTest.test_max_int64_cpu",
+        "OnnxBackendNodeModelTest.test_pow_types_float32_int64_cpu",
+        "OnnxBackendNodeModelTest.test_pow_types_float_cpu",
+        "OnnxBackendNodeModelTest.test_pow_types_int64_float32_cpu",
         "OnnxBackendNodeModelTest.test_pow_types_int64_int64_cpu",
+        "OnnxBackendNodeModelTest.test_pow_types_int_cpu",
         "OnnxBackendNodeModelTest.test_min_int64_cpu",
         "OnnxBackendNodeModelTest.test_gather_negative_indices_cpu",
         "OnnxBackendNodeModelTest.test_scatternd_cpu"),
@@ -248,7 +251,8 @@ tests_expected_to_fail = [
         "OnnxBackendNodeModelTest.test_min_uint32_cpu"),
     (xfail_issue_36478,
         "OnnxBackendNodeModelTest.test_max_uint64_cpu",
-        "OnnxBackendNodeModelTest.test_min_uint64_cpu"),
+        "OnnxBackendNodeModelTest.test_min_uint64_cpu",
+        "OnnxBackendNodeModelTest.test_pow_types_float32_uint64_cpu"),
     (xfail_issue_36437,
         "OnnxBackendNodeModelTest.test_argmax_default_axis_example_cpu",
         "OnnxBackendNodeModelTest.test_argmax_default_axis_random_cpu",
@@ -273,7 +277,8 @@ tests_expected_to_fail = [
         "OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu",
         "OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_example_select_last_index_cpu",
         "OnnxBackendNodeModelTest.test_argmin_keepdims_example_select_last_index_cpu",
-        "OnnxBackendNodeModelTest.test_argmin_keepdims_random_select_last_index_cpu"),
+        "OnnxBackendNodeModelTest.test_argmin_keepdims_random_select_last_index_cpu",
+        "OnnxBackendNodeModelTest.test_pow_types_float32_uint32_cpu"),
     (xfail_issue_38088,
         "OnnxBackendPyTorchConvertedModelTest.test_GLU_cpu"),
     (xfail_issue_38089,
@@ -598,15 +603,6 @@ tests_expected_to_fail = [
         "OnnxBackendNodeModelTest.test_dynamicquantizelinear_min_adjusted_cpu",
         "OnnxBackendNodeModelTest.test_dynamicquantizelinear_cpu",
         "OnnxBackendNodeModelTest.test_dynamicquantizelinear_max_adjusted_cpu"),
-    (xfail_issue_38721,
-        "OnnxBackendNodeModelTest.test_pow_types_int_cpu",
-        "OnnxBackendNodeModelTest.test_pow_types_int64_float32_cpu",
-        "OnnxBackendNodeModelTest.test_pow_types_int32_float32_cpu",
-        "OnnxBackendNodeModelTest.test_pow_types_float_cpu",
-        "OnnxBackendNodeModelTest.test_pow_types_float32_uint64_cpu",
-        "OnnxBackendNodeModelTest.test_pow_types_float32_uint32_cpu",
-        "OnnxBackendNodeModelTest.test_pow_types_float32_int64_cpu",
-        "OnnxBackendNodeModelTest.test_pow_types_float32_int32_cpu"),
     (xfail_issue_38722,
         "OnnxBackendNodeModelTest.test_matmulinteger_cpu",
         "OnnxBackendNodeModelTest.test_qlinearmatmul_2D_cpu",
diff --git a/ngraph/test/models/onnx/pow_float32_float32.prototxt b/ngraph/test/models/onnx/pow_float32_float32.prototxt
new file mode 100644 (file)
index 0000000..4fb1c23
--- /dev/null
@@ -0,0 +1,60 @@
+ir_version: 7
+producer_name: "onnx-importer-test"
+graph {
+  node {
+    input: "X"
+    input: "N"
+    output: "Y"
+    op_type: "Pow"
+  }
+  name: "test-model-lstm"
+  input {
+    name: "X"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "N"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "Y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  domain: ""
+  version: 12
+}
diff --git a/ngraph/test/models/onnx/pow_float32_int32.prototxt b/ngraph/test/models/onnx/pow_float32_int32.prototxt
new file mode 100644 (file)
index 0000000..8100a44
--- /dev/null
@@ -0,0 +1,59 @@
+producer_name: "onnx-importer-test"
+graph {
+  node {
+    input: "X"
+    input: "N"
+    output: "Y"
+    op_type: "Pow"
+  }
+  name: "test-model-lstm"
+  input {
+    name: "X"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "N"
+    type {
+      tensor_type {
+        elem_type: 6
+        shape {
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "Y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  domain: ""
+  version: 12
+}
diff --git a/ngraph/test/models/onnx/pow_int32_float32.prototxt b/ngraph/test/models/onnx/pow_int32_float32.prototxt
new file mode 100644 (file)
index 0000000..901164c
--- /dev/null
@@ -0,0 +1,60 @@
+ir_version: 7
+producer_name: "onnx-importer-test"
+graph {
+  node {
+    input: "X"
+    input: "N"
+    output: "Y"
+    op_type: "Pow"
+  }
+  name: "test-model-lstm"
+  input {
+    name: "X"
+    type {
+      tensor_type {
+        elem_type: 6
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "N"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "Y"
+    type {
+      tensor_type {
+        elem_type: 6
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  domain: ""
+  version: 12
+}
index e254423..e98ef1a 100644 (file)
@@ -2300,6 +2300,48 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pad_constant)
     test_case.run();
 }
 
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pow_float32_float32)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/pow_float32_float32.prototxt"));
+    auto test_case = test::TestCase<TestEngine>(function);
+
+    test_case.add_input<float>({1.f, 2.f, 3.f, 4.f}); // base
+    test_case.add_input<float>({3.5f});               // exponent
+
+    test_case.add_expected_output<float>(Shape{1, 4}, {1.f, 11.313708f, 46.765373f, 128.f});
+
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pow_float32_int32)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/pow_float32_int32.prototxt"));
+    auto test_case = test::TestCase<TestEngine>(function);
+
+    test_case.add_input<float>({1.f, 2.f, 3.f, 4.f}); // base
+    test_case.add_input<int>({3});                    // exponent
+
+    test_case.add_expected_output<float>(Shape{1, 4}, {1.f, 8.f, 27.f, 64.f});
+
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pow_int32_float32)
+{
+    const auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/pow_int32_float32.prototxt"));
+    auto test_case = test::TestCase<TestEngine>(function);
+
+    test_case.add_input<int>({1, 2, 3, 4}); // base
+    test_case.add_input<float>({3.5f});     // exponent
+
+    test_case.add_expected_output<int>(Shape{1, 4}, {1, 11, 46, 128});
+
+    test_case.run();
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reciprocal)
 {
     const auto function = onnx_import::import_onnx_model(