Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / condition.hpp
1 // Copyright (c) 2018 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 ///////////////////////////////////////////////////////////////////////////////////////////////////
16 #pragma once
17
18 #include "../C/condition.h"
19 #include "primitive.hpp"
20 #include "topology.hpp"
21
22 namespace cldnn
23 {
24 /// @addtogroup cpp_api C++ API
25 /// @{
26 /// @addtogroup cpp_topology Network Topology
27 /// @{
28 /// @addtogroup cpp_primitives Primitives
29 /// @{
30 /// @brief Function, which will be used during comparison.
31 enum cond_functions : int32_t
32 {
33     EQUAL,
34     GREATER,
35     LESS
36 };
37
38 /// @brief Adds primitive, which works like "if".
39 ///
40 /// @details
41 /// @n   Applies comparision between 2 inputs.
42 /// @n   Compare data - sizes of that input specifes the range of the comparison.
43 /// @n   Offset - offset in memory, when comparing values.
44 struct condition : public primitive_base<condition, CLDNN_PRIMITIVE_DESC(condition)>
45 {
46     CLDNN_DECLARE_PRIMITIVE(condition)
47
48         /// @brief Constructs condition primitive / layer.
49         ///
50         /// @param id                 An identifier of new primitive.
51         /// @param input              An identifier of primitive which is an input for newly created
52         ///                           condition primitive.
53         /// @param topology_true      Topolgoy containg primitives, which will be executed when comparsion results  
54         ///                           true.
55         /// @param topology_false     Topolgoy containg primitives, which will be executed when comparsion results  
56         ///                           false..
57         /// @param compare_Data       An identifier of primitive which contains compare values
58         /// @param func               Used function during comparison.
59         /// @param offseg             Offset for compare data.
60         /// @param output_padding     Optional padding for output from primitive.
61         condition(
62             const primitive_id& id,
63             const primitive_id& input,
64             const topology& topology_true,
65             const topology& topology_false,
66             const primitive_id& compare_data,
67             const cond_functions& func,
68             const tensor& offset = { 0, 0, 0, 0 },
69             const padding& output_padding = padding()
70         )
71         : primitive_base(id, { input }, output_padding)
72         , topology_true(topology_true)
73         , topology_false(topology_false)
74         , compare_data(compare_data)
75         , function(func)
76         , offset(offset)
77     {}
78
79
80     /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{condition}
81     condition(const dto* dto)
82         : primitive_base(dto)
83         , topology_true(dto->topology_true)
84         , topology_false(dto->topology_false)
85         , compare_data(dto->compare_data)
86         , function(static_cast<cond_functions>(dto->function))
87         , offset(dto->offset)
88     {}
89
90
91     /// @brief An identifier of topology, which will be executed when comparison returns true.
92     topology topology_true;
93     /// @brief An identifier of topology, which will be executed when comparison returns false.
94     topology topology_false;
95     /// @brief An identifier of primitive which contains compare values.
96     primitive_id compare_data;
97     /// @brief Used function during comparison.
98     cond_functions function;
99     /// @brief Offset for compare data.
100     tensor offset;
101 protected:
102     void update_dto(dto& dto) const override
103     {
104         dto.compare_data = compare_data.c_str();
105         dto.function = static_cast<cldnn_cond_functions>(function);
106         dto.offset = offset;
107         dto.topology_true = topology_true.get();
108         dto.topology_false = topology_false.get();
109     }
110
111     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
112     {
113         return { compare_data };
114     }
115 };
116 }
117 /// @}
118 /// @}
119 /// @}