Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / topology.hpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include <cstdint>
20 #include "cldnn_defs.h"
21 #include "compounds.h"
22 #include "primitive.hpp"
23
24 namespace cldnn {
25
26 /// @addtogroup cpp_api C++ API
27 /// @{
28
29 /// @defgroup cpp_topology Network Topology
30 /// @{
31
32 /// @brief Network topology to be defined by user.
33 struct topology
34 {
35     /// @brief Constructs empty network topology.
36     topology()
37         : _impl(check_status<cldnn_topology>("failed to create topology", cldnn_create_topology))
38     {}
39
40     /// @brief Constructs topology containing primitives provided in argument(s).
41     template<class ...Args>
42     topology(const Args&... args)
43         : topology()
44     {
45         add<Args...>(args...);
46     }
47
48     /// @brief Copy construction.
49     topology(const topology& other) :_impl(other._impl)
50     {
51         retain();
52     }
53
54     /// @brief Copy assignment.
55     topology& operator=(const topology& other)
56     {
57         if (_impl == other._impl) return *this;
58         release();
59         _impl = other._impl;
60         retain();
61         return *this;
62     }
63
64     /// Construct C++ topology based on C API @p cldnn_topology
65     topology(const cldnn_topology& other) 
66         :_impl(other)
67     {
68         if (_impl == nullptr) throw std::invalid_argument("implementation pointer should not be null");
69     }
70
71     /// @brief Releases wrapped C API @ref cldnn_topology.
72     ~topology()
73     {
74         release();
75     }
76
77     friend bool operator==(const topology& lhs, const topology& rhs) { return lhs._impl == rhs._impl; }
78     friend bool operator!=(const topology& lhs, const topology& rhs) { return !(lhs == rhs); }
79
80     /// @brief Adds a primitive to topology.
81     template<class PType>
82     void add(PType const& desc)
83     {
84         check_status<void>("primitive add failed", [&](status_t* status) { cldnn_add_primitive(_impl, desc.get_dto(), status); });
85     }
86
87     /// @brief Adds primitives to topology.
88     template<class PType, class ...Args>
89     void add(PType const& desc, Args const&... args)
90     {
91         check_status<void>("primitive add failed", [&](status_t* status) { cldnn_add_primitive(_impl, desc.get_dto(), status); });
92         add<Args...>(args...);
93     }
94
95     /// @brief Returns wrapped C API @ref cldnn_topology.
96     cldnn_topology get() const { return _impl; }
97
98     const std::vector<primitive_id> get_primitive_ids() const
99     {
100         size_t size_ret = 0;
101         status_t err_invalid_arg = CLDNN_SUCCESS;
102         cldnn_get_primitive_ids(_impl, nullptr, 0, &size_ret, &err_invalid_arg);
103         assert(err_invalid_arg == CLDNN_INVALID_ARG);
104         assert(size_ret > 0);
105         std::vector<char> names_buf(size_ret);
106
107         check_status<void>("get topology ids failed", [&](status_t* status)
108         {
109             cldnn_get_primitive_ids(_impl, names_buf.data(), names_buf.size(), &size_ret, status);
110         });
111         assert(names_buf.size() == size_ret);
112
113         std::vector<primitive_id> result;
114         for (auto buf_ptr = names_buf.data(); *buf_ptr != 0; buf_ptr += result.back().size() + 1)
115         {
116             result.emplace_back(buf_ptr);
117         }
118         return result;
119     }
120
121     void change_input_layout(primitive_id id, layout new_layout)
122     {
123         check_status<void>("Change input layout failed.", [&](status_t* status)
124         {
125             cldnn_change_input_layout(_impl, id.c_str(), new_layout, status);
126         });
127     }
128
129 private:
130     friend struct engine;
131     friend struct network;
132     cldnn_topology _impl;
133
134     void retain()
135     {
136         check_status<void>("retain topology failed", [=](status_t* status) { cldnn_retain_topology(_impl, status); });
137     }
138     void release()
139     {
140         check_status<void>("retain topology failed", [=](status_t* status) { cldnn_release_topology(_impl, status); });
141     }
142 };
143
144 CLDNN_API_CLASS(topology)
145 /// @}
146 /// @}
147 }