1 // Copyright (c) 2018 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "../C/condition.h"
19 #include "primitive.hpp"
20 #include "topology.hpp"
24 /// @addtogroup cpp_api C++ API
26 /// @addtogroup cpp_topology Network Topology
28 /// @addtogroup cpp_primitives Primitives
30 /// @brief Comparison function selected for the condition primitive (enum with int32_t underlying type).
31 enum cond_functions : int32_t
38 /// @brief Adds a primitive which works like "if".
41 /// @n Applies a comparison between 2 inputs.
42 /// @n Compare data - the sizes of that input specify the range of the comparison.
43 /// @n Offset - offset in memory, when comparing values.
44 struct condition : public primitive_base<condition, CLDNN_PRIMITIVE_DESC(condition)>
46 CLDNN_DECLARE_PRIMITIVE(condition)
48 /// @brief Constructs condition primitive / layer.
50 /// @param id An identifier of new primitive.
51 /// @param input An identifier of primitive which is an input for newly created
52 /// condition primitive.
53 /// @param topology_true Topology containing primitives, which will be executed when comparison results in true.
55 /// @param topology_false Topology containing primitives, which will be executed when comparison results in false.
57 /// @param compare_data An identifier of primitive which contains compare values.
58 /// @param func Function used during comparison.
59 /// @param offset Offset for compare data.
60 /// @param output_padding Optional padding for output from primitive.
62 const primitive_id& id,
63 const primitive_id& input,
64 const topology& topology_true,
65 const topology& topology_false,
66 const primitive_id& compare_data,
67 const cond_functions& func,
68 const tensor& offset = { 0, 0, 0, 0 },
69 const padding& output_padding = padding()
71 : primitive_base(id, { input }, output_padding)
72 , topology_true(topology_true)
73 , topology_false(topology_false)
74 , compare_data(compare_data)
80 /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{condition}
81 condition(const dto* dto)
83 , topology_true(dto->topology_true)
84 , topology_false(dto->topology_false)
85 , compare_data(dto->compare_data)
86 , function(static_cast<cond_functions>(dto->function))
91 /// @brief Topology, which will be executed when comparison returns true.
92 topology topology_true;
93 /// @brief Topology, which will be executed when comparison returns false.
94 topology topology_false;
95 /// @brief An identifier of primitive which contains compare values.
96 primitive_id compare_data;
97 /// @brief Function used during comparison.
98 cond_functions function;
99 /// @brief Offset for compare data.
/// @brief Fills the C API descriptor with compare_data, function and the values returned by get() on both topologies.
102 void update_dto(dto& dto) const override
104 dto.compare_data = compare_data.c_str();
105 dto.function = static_cast<cldnn_cond_functions>(function);
107 dto.topology_true = topology_true.get();
108 dto.topology_false = topology_false.get();
/// @brief Reports compare_data as a dependency of this primitive.
111 std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
113 return { compare_data };