Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / condition_gpu.cpp
1 // Copyright (c) 2018 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include "condition_inst.h"
#include "network_impl.h"
#include "implementation_map.h"
#include "math_utils.h"

#include <algorithm>
#include <stdexcept>
21
22 namespace cldnn { namespace gpu {
23
24 struct condition_gpu : typed_primitive_impl<condition>
25 {
26     const condition_node& outer;
27
28     condition_gpu(const condition_node& outer)
29         : outer(outer)
30     {}
31
32     event_impl::ptr execute_impl(const std::vector<event_impl::ptr>& events, condition_inst& instance) override
33     {
34         for (auto& a : events)
35         {
36             a->wait();
37         }
38         auto ev = instance.get_network().get_engine().create_user_event(false);
39
40         bool exec_branch = choose_branch_to_exec(instance);
41         memory_impl::ptr memory_to_copy;
42         if (exec_branch)
43             memory_to_copy = &execute_branch(instance.get_net_true(), instance.result_id(), instance.input_memory());
44         else
45             memory_to_copy = &execute_branch(instance.get_net_false(), instance.result_id(), instance.input_memory());
46         //just copy memory
47         mem_lock<float> inp_ptr{ memory_to_copy };
48         mem_lock<float> out_ptr{ instance.output_memory() };
49         std::copy(inp_ptr.begin(), inp_ptr.end(), out_ptr.begin());
50         dynamic_cast<cldnn::user_event*>(ev.get())->set(); // set as complete
51         return ev;
52     }
53
54     static primitive_impl* create(const condition_node& arg)
55     { 
56         return new condition_gpu(arg);
57     }
58
59 private:
60     /*
61     Add functions here.
62     */
63     bool check_condition(const float value_1, const float value_2, const cond_functions& func) const
64     {
65         switch (func)
66         {
67         case cond_functions::EQUAL: return value_1 == value_2;
68             break;
69         case cond_functions::GREATER: return value_1 > value_2;
70             break;
71         case cond_functions::LESS: return value_1 < value_2;
72             break;
73         default:
74             throw("Unknown comparision function for: " + outer.id());
75             break;
76         }
77     }
78
79     /*
80     Loop over memory and check condition.
81     Returns boolean flag, which says what branch should be executed.
82     */
83     bool choose_branch_to_exec(condition_inst& instance) const
84     {
85         mem_lock<float> lock_compare_data{ instance.compare_memory() };
86         auto compare_layout = instance.compare_memory().get_layout();
87         auto compare_ptr = lock_compare_data.begin();
88
89         mem_lock<float> lock_input{ instance.input_memory() };
90         auto input_layout = instance.input_memory().get_layout();
91         auto input_ptr = lock_input.begin();
92
93         auto function = instance.argument.function;
94         auto& offset = instance.argument.offset;
95         auto& range = compare_layout.size;
96
97         for (auto b = 0; b < range.batch[0]; b++)
98         {
99             for (auto f = 0; f < range.feature[0]; f++)
100             {
101                 for (auto y = 0; y < range.spatial[1]; y++)
102                 {
103                     for (auto x = 0; x < range.spatial[0]; x++)
104                     {
105                         auto input_idx = input_layout.get_linear_offset({
106                             b + offset.batch[0],
107                             f + offset.feature[0],
108                             x + offset.spatial[0],
109                             y + offset.spatial[1]
110                             });
111                         auto compare_idx = compare_layout.get_linear_offset({ b, f, x, y });
112                         if (!check_condition(input_ptr[input_idx], compare_ptr[compare_idx], function)) return false;
113                     }
114                 }
115             }
116         }
117         return true;
118     }
119
120     
121
122     memory_impl& execute_branch(network_impl::ptr branch, const primitive_id& input_id, memory_impl& input_memory) const
123     {
124         branch->set_input_data(input_id, input_memory);
125         branch->execute({});
126         return branch->get_outputs().at(0)->output_memory();
127     }
128
129 };
130
131 namespace {
132     struct attach {
133         attach() {
134             implementation_map<condition>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx),
135                 condition_gpu::create);
136             implementation_map<condition>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb),
137                 condition_gpu::create);
138         }
139         ~attach() = default;
140     };
141     attach attach_impl;
142 }
143
144 }