Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / include / ngraph / builder / make_constant.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include "autobroadcast.hpp"
20 #include "ngraph/node.hpp"
21 #include "ngraph/op/broadcast.hpp"
22 #include "ngraph/op/constant.hpp"
23 #include "ngraph/type/float16.hpp"
24
25 namespace ngraph
26 {
27     namespace builder
28     {
29         template <class T>
30         std::shared_ptr<Node>
31             make_constant(const element::Type& type, const Shape& shape, const T& num)
32         {
33             std::shared_ptr<Node> val = nullptr;
34
35 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
36 #pragma GCC diagnostic push
37 #pragma GCC diagnostic error "-Wswitch"
38 #pragma GCC diagnostic error "-Wswitch-enum"
39 #endif
40             switch (type)
41             {
42             case element::Type_t::f32:
43                 val = std::make_shared<ngraph::op::Constant>(
44                     type, ngraph::Shape{}, std::vector<float>{static_cast<float>(num)});
45                 break;
46             case element::Type_t::f64:
47                 val = std::make_shared<ngraph::op::Constant>(
48                     type, ngraph::Shape{}, std::vector<double>{static_cast<double>(num)});
49                 break;
50             case element::Type_t::f16:
51                 val = std::make_shared<ngraph::op::Constant>(
52                     type,
53                     ngraph::Shape{},
54                     std::vector<ngraph::float16>{ngraph::float16(static_cast<float>(num))});
55                 break;
56             case element::Type_t::bf16:
57                 val = std::make_shared<ngraph::op::Constant>(
58                     type,
59                     ngraph::Shape{},
60                     std::vector<ngraph::bfloat16>{ngraph::bfloat16(static_cast<float>(num))});
61                 break;
62             case element::Type_t::i64:
63                 val = std::make_shared<ngraph::op::Constant>(
64                     type, ngraph::Shape{}, std::vector<int64_t>{static_cast<int64_t>(num)});
65                 break;
66             case element::Type_t::i32:
67                 val = std::make_shared<ngraph::op::Constant>(
68                     type, ngraph::Shape{}, std::vector<int32_t>{static_cast<int32_t>(num)});
69                 break;
70             case element::Type_t::i16:
71                 val = std::make_shared<ngraph::op::Constant>(
72                     type, ngraph::Shape{}, std::vector<int16_t>{static_cast<int16_t>(num)});
73                 break;
74             case element::Type_t::i8:
75                 val = std::make_shared<ngraph::op::Constant>(
76                     type, ngraph::Shape{}, std::vector<int8_t>{static_cast<int8_t>(num)});
77                 break;
78             case element::Type_t::u64:
79                 val = std::make_shared<ngraph::op::Constant>(
80                     type, ngraph::Shape{}, std::vector<uint64_t>{static_cast<uint64_t>(num)});
81                 break;
82             case element::Type_t::u32:
83                 val = std::make_shared<ngraph::op::Constant>(
84                     type, ngraph::Shape{}, std::vector<uint32_t>{static_cast<uint32_t>(num)});
85                 break;
86             case element::Type_t::u16:
87                 val = std::make_shared<ngraph::op::Constant>(
88                     type, ngraph::Shape{}, std::vector<uint16_t>{static_cast<uint16_t>(num)});
89                 break;
90             case element::Type_t::u8:
91                 val = std::make_shared<ngraph::op::Constant>(
92                     type, ngraph::Shape{}, std::vector<uint8_t>{static_cast<uint8_t>(num)});
93                 break;
94             case element::Type_t::dynamic:
95                 throw ngraph_error("make_constant: Unsupported element type 'dynamic'");
96             case element::Type_t::boolean:
97                 throw ngraph_error("make_constant: Unsupported element type 'boolean'");
98             case element::Type_t::u1:
99                 throw ngraph_error("make_constant: Unsupported element type 'u1'");
100             case element::Type_t::undefined:
101                 throw ngraph_error("make_constant: Unsupported element type 'undefined'");
102             }
103 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
104 #pragma GCC diagnostic pop
105 #endif
106
107             if (shape.size() > 0)
108             {
109                 ngraph::AxisSet axes;
110                 for (size_t i = 0; i < shape.size(); i++)
111                 {
112                     axes.insert(i);
113                 }
114                 val = builder::opset1::make_broadcast(val, shape, axes).get_node_shared_ptr();
115             }
116
117             return val->add_provenance_group_members_above({});
118         }
119     }
120 }