Alpha and beta nodes element types fix (#1150)
authorTomasz Dołbniak <tomasz.dolbniak@intel.com>
Tue, 30 Jun 2020 10:04:11 +0000 (12:04 +0200)
committerGitHub <noreply@github.com>
Tue, 30 Jun 2020 10:04:11 +0000 (12:04 +0200)
ngraph/src/ngraph/frontend/onnx_import/op/gemm.cpp

index 20c240b..4772bf3 100644 (file)
@@ -53,9 +53,9 @@ namespace ngraph
                     const auto beta = node.get_attribute_value<float>("beta", 1);
 
                     const auto alpha_node = default_opset::Constant::create(
-                        element::Type_t::f32, Shape{}, std::vector<float>{alpha});
+                        input_b->get_element_type(), Shape{}, {alpha});
                     const auto beta_node = default_opset::Constant::create(
-                        element::Type_t::f32, Shape{}, std::vector<float>{beta});
+                        input_c->get_element_type(), Shape{}, {beta});
 
                     const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
                     const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
@@ -114,9 +114,9 @@ namespace ngraph
                     const auto beta = node.get_attribute_value<float>("beta", 1);
 
                     const auto alpha_node = default_opset::Constant::create(
-                        element::Type_t::f32, Shape{}, std::vector<float>{alpha});
+                        input_b->get_element_type(), Shape{}, {alpha});
                     const auto beta_node = default_opset::Constant::create(
-                        element::Type_t::f32, Shape{}, std::vector<float>{beta});
+                        input_c->get_element_type(), Shape{}, {beta});
 
                     const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
                     const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);