1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
19 #include "autobroadcast.hpp"
20 #include "ngraph/node.hpp"
21 #include "ngraph/op/broadcast.hpp"
22 #include "ngraph/op/constant.hpp"
23 #include "ngraph/type/float16.hpp"
31 make_constant(const element::Type& type, const Shape& shape, const T& num)
33 std::shared_ptr<Node> val = nullptr;
35 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
36 #pragma GCC diagnostic push
37 #pragma GCC diagnostic error "-Wswitch"
38 #pragma GCC diagnostic error "-Wswitch-enum"
42 case element::Type_t::f32:
43 val = std::make_shared<ngraph::op::Constant>(
44 type, ngraph::Shape{}, std::vector<float>{static_cast<float>(num)});
46 case element::Type_t::f64:
47 val = std::make_shared<ngraph::op::Constant>(
48 type, ngraph::Shape{}, std::vector<double>{static_cast<double>(num)});
50 case element::Type_t::f16:
51 val = std::make_shared<ngraph::op::Constant>(
54 std::vector<ngraph::float16>{ngraph::float16(static_cast<float>(num))});
56 case element::Type_t::bf16:
57 val = std::make_shared<ngraph::op::Constant>(
60 std::vector<ngraph::bfloat16>{ngraph::bfloat16(static_cast<float>(num))});
62 case element::Type_t::i64:
63 val = std::make_shared<ngraph::op::Constant>(
64 type, ngraph::Shape{}, std::vector<int64_t>{static_cast<int64_t>(num)});
66 case element::Type_t::i32:
67 val = std::make_shared<ngraph::op::Constant>(
68 type, ngraph::Shape{}, std::vector<int32_t>{static_cast<int32_t>(num)});
70 case element::Type_t::i16:
71 val = std::make_shared<ngraph::op::Constant>(
72 type, ngraph::Shape{}, std::vector<int16_t>{static_cast<int16_t>(num)});
74 case element::Type_t::i8:
75 val = std::make_shared<ngraph::op::Constant>(
76 type, ngraph::Shape{}, std::vector<int8_t>{static_cast<int8_t>(num)});
78 case element::Type_t::u64:
79 val = std::make_shared<ngraph::op::Constant>(
80 type, ngraph::Shape{}, std::vector<uint64_t>{static_cast<uint64_t>(num)});
82 case element::Type_t::u32:
83 val = std::make_shared<ngraph::op::Constant>(
84 type, ngraph::Shape{}, std::vector<uint32_t>{static_cast<uint32_t>(num)});
86 case element::Type_t::u16:
87 val = std::make_shared<ngraph::op::Constant>(
88 type, ngraph::Shape{}, std::vector<uint16_t>{static_cast<uint16_t>(num)});
90 case element::Type_t::u8:
91 val = std::make_shared<ngraph::op::Constant>(
92 type, ngraph::Shape{}, std::vector<uint8_t>{static_cast<uint8_t>(num)});
94 case element::Type_t::dynamic:
95 throw ngraph_error("make_constant: Unsupported element type 'dynamic'");
96 case element::Type_t::boolean:
97 throw ngraph_error("make_constant: Unsupported element type 'boolean'");
98 case element::Type_t::u1:
99 throw ngraph_error("make_constant: Unsupported element type 'u1'");
100 case element::Type_t::undefined:
101 throw ngraph_error("make_constant: Unsupported element type 'undefined'");
103 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
104 #pragma GCC diagnostic pop
107 if (shape.size() > 0)
109 ngraph::AxisSet axes;
110 for (size_t i = 0; i < shape.size(); i++)
114 val = builder::opset1::make_broadcast(val, shape, axes).get_node_shared_ptr();
117 return val->add_provenance_group_members_above({});