Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / broadcast_gpu.cpp
index 8c72bdc..fc3667a 100644 (file)
@@ -35,6 +35,25 @@ struct broadcast_gpu : typed_primitive_gpu_impl<broadcast>
         auto bc_params          = get_default_params<kernel_selector::broadcast_params>(arg, 1);
         auto bc_optional_params = get_default_optional_params<kernel_selector::broadcast_optional_params>(arg.get_program());
 
+        const auto& broadcast_axes = arg.get_primitive()->broadcast_axes;
+        uint16_t index = (uint16_t) 0;
+        uint16_t input_index = (uint16_t) broadcast_axes.size();
+
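+        // bfyx format (4 dims): dims listed in broadcast_axes are numbered from 0, the remaining dims from broadcast_axes.size()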
+        for (size_t i = 0; i < 4; ++i)
+        {
+            if (std::find(broadcast_axes.begin(), broadcast_axes.end(), i) != broadcast_axes.end())
+            {
+                bc_params.input_order.push_back(index);
+                ++index;
+            }
+            else
+            {
+                bc_params.input_order.push_back(input_index);
+                ++input_index;
+            }
+        }
+
         auto& kernel_selector = kernel_selector::broadcast_kernel_selector::Instance();
         auto best_kernels = kernel_selector.GetBestKernels(bc_params, bc_optional_params);
 
@@ -49,20 +68,12 @@ namespace {
         attach() {
             auto val_fw = broadcast_gpu::create;
 
-            implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), val_fw);
-            implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), val_fw);
-            implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::i8,  format::yxfb), val_fw);
-            implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::u8,  format::yxfb), val_fw);
-
             implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
             implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
             implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::i8,  format::bfyx), val_fw);
             implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::u8,  format::bfyx), val_fw);
-
-            implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), val_fw);
-            implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), val_fw);
-            implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::i8,  format::byxf), val_fw);
-            implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::u8,  format::byxf), val_fw);
+            implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);
+            implementation_map<broadcast>::add(std::make_tuple(engine_types::ocl, data_types::i64, format::bfyx), val_fw);
         }
         ~attach() = default;
     };