const auto beta = node.get_attribute_value<float>("beta", 1);
const auto alpha_node = default_opset::Constant::create(
- element::Type_t::f32, Shape{}, std::vector<float>{alpha});
+ input_b->get_element_type(), Shape{}, {alpha});
const auto beta_node = default_opset::Constant::create(
- element::Type_t::f32, Shape{}, std::vector<float>{beta});
+ input_c->get_element_type(), Shape{}, {beta});
const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);
const auto beta = node.get_attribute_value<float>("beta", 1);
const auto alpha_node = default_opset::Constant::create(
- element::Type_t::f32, Shape{}, std::vector<float>{alpha});
+ input_b->get_element_type(), Shape{}, {alpha});
const auto beta_node = default_opset::Constant::create(
- element::Type_t::f32, Shape{}, std::vector<float>{beta});
+ input_c->get_element_type(), Shape{}, {beta});
const bool trans_a = node.get_attribute_value<int64_t>("transA", 0);
const bool trans_b = node.get_attribute_value<int64_t>("transB", 0);