1 // Copyright (c) 2018 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
#include "condition_inst.h"
#include "network_impl.h"
#include "implementation_map.h"
#include "math_utils.h"

#include <algorithm>
#include <vector>
22 namespace cldnn { namespace gpu {
24 struct condition_gpu : typed_primitive_impl<condition>
26 const condition_node& outer;
28 condition_gpu(const condition_node& outer)
32 event_impl::ptr execute_impl(const std::vector<event_impl::ptr>& events, condition_inst& instance) override
34 for (auto& a : events)
38 auto ev = instance.get_network().get_engine().create_user_event(false);
40 bool exec_branch = choose_branch_to_exec(instance);
41 memory_impl::ptr memory_to_copy;
43 memory_to_copy = &execute_branch(instance.get_net_true(), instance.result_id(), instance.input_memory());
45 memory_to_copy = &execute_branch(instance.get_net_false(), instance.result_id(), instance.input_memory());
47 mem_lock<float> inp_ptr{ memory_to_copy };
48 mem_lock<float> out_ptr{ instance.output_memory() };
49 std::copy(inp_ptr.begin(), inp_ptr.end(), out_ptr.begin());
50 dynamic_cast<cldnn::user_event*>(ev.get())->set(); // set as complete
54 static primitive_impl* create(const condition_node& arg)
56 return new condition_gpu(arg);
63 bool check_condition(const float value_1, const float value_2, const cond_functions& func) const
67 case cond_functions::EQUAL: return value_1 == value_2;
69 case cond_functions::GREATER: return value_1 > value_2;
71 case cond_functions::LESS: return value_1 < value_2;
74 throw("Unknown comparision function for: " + outer.id());
80 Loop over memory and check condition.
81 Returns boolean flag, which says what branch should be executed.
83 bool choose_branch_to_exec(condition_inst& instance) const
85 mem_lock<float> lock_compare_data{ instance.compare_memory() };
86 auto compare_layout = instance.compare_memory().get_layout();
87 auto compare_ptr = lock_compare_data.begin();
89 mem_lock<float> lock_input{ instance.input_memory() };
90 auto input_layout = instance.input_memory().get_layout();
91 auto input_ptr = lock_input.begin();
93 auto function = instance.argument.function;
94 auto& offset = instance.argument.offset;
95 auto& range = compare_layout.size;
97 for (auto b = 0; b < range.batch[0]; b++)
99 for (auto f = 0; f < range.feature[0]; f++)
101 for (auto y = 0; y < range.spatial[1]; y++)
103 for (auto x = 0; x < range.spatial[0]; x++)
105 auto input_idx = input_layout.get_linear_offset({
107 f + offset.feature[0],
108 x + offset.spatial[0],
109 y + offset.spatial[1]
111 auto compare_idx = compare_layout.get_linear_offset({ b, f, x, y });
112 if (!check_condition(input_ptr[input_idx], compare_ptr[compare_idx], function)) return false;
122 memory_impl& execute_branch(network_impl::ptr branch, const primitive_id& input_id, memory_impl& input_memory) const
124 branch->set_input_data(input_id, input_memory);
126 return branch->get_outputs().at(0)->output_memory();
134 implementation_map<condition>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx),
135 condition_gpu::create);
136 implementation_map<condition>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb),
137 condition_gpu::create);