-// Copyright (c) 2016-2018 Intel Corporation
+// Copyright (c) 2016-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// limitations under the License.
#include "kernel_selector_helper.h"
+#include "kernel_selector_params.h"
+
+#include "gpu/ocl_toolkit.h"
+
+#include "program_node.h"
+#include "program_impl.h"
+
+#include "training_params.h"
kernel_selector::data_type to_data_type(data_types dt)
{
switch (dt)
{
case cldnn::data_types::i8: return kernel_selector::weights_type::INT8;
+ case cldnn::data_types::u8: return kernel_selector::weights_type::UINT8;
case cldnn::data_types::f16: return kernel_selector::weights_type::F16;
case cldnn::data_types::f32: return kernel_selector::weights_type::F32;
default:
switch (dt)
{
case kernel_selector::weights_type::INT8: return data_types::i8;
+ case kernel_selector::weights_type::UINT8: return data_types::u8;
case kernel_selector::weights_type::F16: return data_types::f16;
case kernel_selector::weights_type::F32: return data_types::f32;
default:
case format::bf8_xy16: return kernel_selector::data_layout::bf8_xy16;
case format::winograd_2x3_s1_data: return kernel_selector::data_layout::winograd_2x3_s1_data;
case format::byxf_af32: return kernel_selector::data_layout::byxf_af32;
+ case format::byx8_f4: return kernel_selector::data_layout::byx8_f4;
case format::fs_bs_yx_bsv4_fsv32: return kernel_selector::data_layout::fs_bs_yx_bsv4_fsv32;
// case format::brfyx: return kernel_selector::data_layout::brfyx;
+ case format::b_fs_yx_fsv4: return kernel_selector::data_layout::b_fs_yx_fsv4;
default:
return kernel_selector::data_layout::bfyx;
}
case kernel_selector::data_layout::brfyx: return cldnn::format::bfyx;
case kernel_selector::data_layout::winograd_2x3_s1_data: return cldnn::format::winograd_2x3_s1_data;
case kernel_selector::data_layout::byxf_af32: return cldnn::format::byxf_af32;
+ case kernel_selector::data_layout::byx8_f4: return cldnn::format::byx8_f4;
case kernel_selector::data_layout::fs_bs_yx_bsv4_fsv32: return cldnn::format::fs_bs_yx_bsv4_fsv32;
default:
return cldnn::format::bfyx;
case format::byxf: return kernel_selector::weights_layout::oyxi;
case format::yxfb: return kernel_selector::weights_layout::yxio;
case format::os_iyx_osv16: return kernel_selector::weights_layout::os_iyx_osv16;
+ case format::os_iyx_osv32: return kernel_selector::weights_layout::os_iyx_osv32;
+ case format::os_iyx_osv64: return kernel_selector::weights_layout::os_iyx_osv64;
case format::bs_xs_xsv8_bsv8: return kernel_selector::weights_layout::os_i_osv8__ai8;
case format::bs_xs_xsv8_bsv16: return kernel_selector::weights_layout::os_i_osv16__ai8;
case format::bs_x_bsv16: return kernel_selector::weights_layout::os_i_osv16;
case format::winograd_6x3_s1_fused_weights: return kernel_selector::weights_layout::winograd_6x3_s1_fused_weights;
case format::image_2d_weights_winograd_6x3_s1_fbxyb: return kernel_selector::weights_layout::image_2d_weights_winograd_6x3_s1_fbxyb;
case format::image_2d_weights_winograd_6x3_s1_xfbyb: return kernel_selector::weights_layout::image_2d_weights_winograd_6x3_s1_xfbyb;
- case format::os_is_yx_isa8_osv8_isv4: return kernel_selector::weights_layout::os_is_yx_isa8_osv8_isv4;
+ case format::os_is_yx_isa8_osv8_isv4: return kernel_selector::weights_layout::os_is_yx_isa8_osv8_isv4;
+ case format::os_is_yx_isa8_osv8_isv4_swizzled_by_4: return kernel_selector::weights_layout::os_is_yx_isa8_osv8_isv4_swizzled_by_4;
case format::is_o_yx_isv32: return kernel_selector::weights_layout::is_o_yx_isv32;
+ case format::is_o32_yx_isv32_swizzled_by_4: return kernel_selector::weights_layout::is_o32_yx_isv32_swizzled_by_4;
+ case format::os_is_y_x8_osv8_isv4: return kernel_selector::weights_layout::os_is_y_x8_osv8_isv4;
+ case format::bf_lyx_yx: return kernel_selector::weights_layout::bf_lyx_yx;
+ case format::os_is_yx_osv16_isv4: return kernel_selector::weights_layout::os_is_yx_osv16_isv4;
default:
return kernel_selector::weights_layout::oi;
}
switch (l)
{
case kernel_selector::weights_layout::oi:
- case kernel_selector::weights_layout::oiyx: return cldnn::format::bfyx;
- case kernel_selector::weights_layout::oyxi: return cldnn::format::byxf;
+ case kernel_selector::weights_layout::oiyx: return cldnn::format::bfyx;
+ case kernel_selector::weights_layout::oyxi: return cldnn::format::byxf;
case kernel_selector::weights_layout::io:
- case kernel_selector::weights_layout::iyxo: return cldnn::format::fyxb;
- case kernel_selector::weights_layout::yxio: return cldnn::format::yxfb;
- case kernel_selector::weights_layout::os_iyx_osv16: return cldnn::format::os_iyx_osv16;
- case kernel_selector::weights_layout::os_i_osv16: return cldnn::format::bs_x_bsv16;
- case kernel_selector::weights_layout::os_i_osv8__ai8: return cldnn::format::bs_xs_xsv8_bsv8;
- case kernel_selector::weights_layout::os_i_osv16__ai8: return cldnn::format::bs_xs_xsv8_bsv16;
- case kernel_selector::weights_layout::image_2d_weights_c4_fyx_b: return cldnn::format::image_2d_weights_c4_fyx_b;
- case kernel_selector::weights_layout::image_2d_weights_c1_b_fyx: return cldnn::format::image_2d_weights_c1_b_fyx;
- case kernel_selector::weights_layout::winograd_2x3_s1_weights: return cldnn::format::winograd_2x3_s1_weights;
- case kernel_selector::weights_layout::winograd_2x3_s1_fused_weights: return cldnn::format::winograd_2x3_s1_fused_weights;
- case kernel_selector::weights_layout::winograd_6x3_s1_fused_weights: return cldnn::format::winograd_6x3_s1_fused_weights;
- case kernel_selector::weights_layout::image_2d_weights_winograd_6x3_s1_fbxyb: return cldnn::format::image_2d_weights_winograd_6x3_s1_fbxyb;
- case kernel_selector::weights_layout::image_2d_weights_winograd_6x3_s1_xfbyb: return cldnn::format::image_2d_weights_winograd_6x3_s1_xfbyb;
- case kernel_selector::weights_layout::os_is_yx_isa8_osv8_isv4: return cldnn::format::os_is_yx_isa8_osv8_isv4;
- case kernel_selector::weights_layout::is_o_yx_isv32: return cldnn::format::is_o_yx_isv32;
+ case kernel_selector::weights_layout::iyxo: return cldnn::format::fyxb;
+ case kernel_selector::weights_layout::yxio: return cldnn::format::yxfb;
+ case kernel_selector::weights_layout::os_iyx_osv16: return cldnn::format::os_iyx_osv16;
+ case kernel_selector::weights_layout::os_iyx_osv32: return cldnn::format::os_iyx_osv32;
+ case kernel_selector::weights_layout::os_iyx_osv64: return cldnn::format::os_iyx_osv64;
+ case kernel_selector::weights_layout::os_i_osv16: return cldnn::format::bs_x_bsv16;
+ case kernel_selector::weights_layout::os_i_osv8__ai8: return cldnn::format::bs_xs_xsv8_bsv8;
+ case kernel_selector::weights_layout::os_i_osv16__ai8: return cldnn::format::bs_xs_xsv8_bsv16;
+ case kernel_selector::weights_layout::image_2d_weights_c4_fyx_b: return cldnn::format::image_2d_weights_c4_fyx_b;
+ case kernel_selector::weights_layout::image_2d_weights_c1_b_fyx: return cldnn::format::image_2d_weights_c1_b_fyx;
+ case kernel_selector::weights_layout::winograd_2x3_s1_weights: return cldnn::format::winograd_2x3_s1_weights;
+ case kernel_selector::weights_layout::winograd_2x3_s1_fused_weights: return cldnn::format::winograd_2x3_s1_fused_weights;
+ case kernel_selector::weights_layout::winograd_6x3_s1_fused_weights: return cldnn::format::winograd_6x3_s1_fused_weights;
+ case kernel_selector::weights_layout::image_2d_weights_winograd_6x3_s1_fbxyb: return cldnn::format::image_2d_weights_winograd_6x3_s1_fbxyb;
+ case kernel_selector::weights_layout::image_2d_weights_winograd_6x3_s1_xfbyb: return cldnn::format::image_2d_weights_winograd_6x3_s1_xfbyb;
+ case kernel_selector::weights_layout::os_is_yx_isa8_osv8_isv4: return cldnn::format::os_is_yx_isa8_osv8_isv4;
+ case kernel_selector::weights_layout::os_is_yx_isa8_osv8_isv4_swizzled_by_4: return cldnn::format::os_is_yx_isa8_osv8_isv4_swizzled_by_4;
+ case kernel_selector::weights_layout::is_o_yx_isv32: return cldnn::format::is_o_yx_isv32;
+ case kernel_selector::weights_layout::is_o32_yx_isv32_swizzled_by_4: return cldnn::format::is_o32_yx_isv32_swizzled_by_4;
+ case kernel_selector::weights_layout::os_is_y_x8_osv8_isv4: return cldnn::format::os_is_y_x8_osv8_isv4;
+ case kernel_selector::weights_layout::bf_lyx_yx: return cldnn::format::bf_lyx_yx;
default:
return cldnn::format::bfyx;
}
new_vals[3] = align_to(vals[3], 32);
new_vals[2] = align_to(vals[2], 4);
}
+ if (ks_layout == kernel_selector::Tensor::byx8_f4)
+ {
+ new_vals[3] = align_to(vals[3], 4);
+ new_vals[2] = align_to(vals[2], 8);
+ }
for (size_t i = 0; i < vec.size(); i++)
{
kernel_selector::weights_tensor convert_weights_tensor(const layout& l)
{
- assert(l.format.dimension() == 4);
- const auto& t = l.size.sizes(format::bfyx);
- const auto base_layout = kernel_selector::weights_layout::oiyx;
+ const auto& t = l.size.sizes(l.format);
+ const auto base_layout = to_weights_layout(l.format);
const auto ks_type = to_weights_type(l.data_type);
const auto ks_layout = to_weights_layout(l.format);
std::vector<size_t> vec(kernel_selector::WeightsTensor::ChannelsCount(base_layout));
return kernel_selector::activation_function::COSH;
case activation_log:
return kernel_selector::activation_function::LOG;
- case activation_log2:
- return kernel_selector::activation_function::LOG2;
+ case activation_log2:
+ return kernel_selector::activation_function::LOG2;
case activation_exp:
return kernel_selector::activation_function::EXP;
+ case activation_not:
+ return kernel_selector::activation_function::NOT;
default:
throw std::runtime_error("Unknown activation function");
break;
throw std::runtime_error("Unknown activation_grad function");
break;
}
+}
+
+// Fills params.engineInfo from the engine that owns the program `node` belongs to.
+// Extension flags are queried from the engine context by extension name; the
+// remaining fields are copied from the context's engine-info structure, and the
+// host version is derived from cldnn::get_version().
+void set_params(const program_node& node, kernel_selector::params& params)
+{
+    const auto& gpu_context = node.get_program().get_engine().get_context();
+    const auto& device_info = gpu_context->get_engine_info();
+    auto& info = params.engineInfo;
+
+    // Capabilities reported as extensions by the context.
+    info.bSubGroupSupport      = gpu_context->extension_supported("cl_intel_subgroups");
+    info.bSubGroupShortSupport = gpu_context->extension_supported("cl_intel_subgroups_short");
+    info.bFP16Support          = gpu_context->extension_supported("cl_khr_fp16");
+    info.bFP64Support          = gpu_context->extension_supported("cl_khr_fp64");
+
+    // Capabilities stored as integers in the engine info; any non-zero value means "supported".
+    info.bIMADSupport  = device_info.supports_imad != 0;
+    info.bIMMADSupport = device_info.supports_immad != 0;
+    info.bImageSupport = device_info.supports_image != 0;
+
+    // Limits and identification data, copied verbatim.
+    info.maxWorkGroupSize  = device_info.max_work_group_size;
+    info.maxLocalMemSize   = device_info.max_local_mem_size;
+    info.maxImage2dWidth   = device_info.max_image2d_width;
+    info.maxImage2dHeight  = device_info.max_image2d_height;
+    info.deviceId          = device_info.dev_id;
+    info.computeUnitsCount = device_info.compute_units_count;
+    info.deviceCache       = device_info.device_cache;
+    info.driverVersion     = device_info.driver_version;
+    info.hostVersion       = to_host_version(cldnn::get_version());
+}
+
+// Copies the momentum factor and weights decay from the program's
+// learning_config build option into `params`.
+// `params.use_momentum` is switched on when `use_momentum` is true; when it is
+// false the field is left exactly as the caller passed it in.
+void set_learning_params(const program_node& node, kernel_selector::training_params& params, bool use_momentum)
+{
+    // Taken by value, so the settings are a local snapshot of the build option.
+    const auto config = node.get_program().get_options().get<build_option_type::learning_config>()->params;
+
+    // Only ever raises the flag; it is never cleared here.
+    if (use_momentum)
+        params.use_momentum = true;
+
+    params.momentum_factor = config.momentum;
+    params.weights_decay   = config.weights_decay;
+}
+
+// Initializes the optional kernel-selection parameters from the program's
+// engine configuration and build options:
+//  - kernel naming follows the engine configuration,
+//  - static input reordering follows the optimize_data build option,
+//  - input and output reordering are always disabled,
+//  - tuning mode and cache file path come from the tuning_config build option.
+void set_optional_params(const program_impl& program, kernel_selector::optional_params& params)
+{
+    const auto& gpu_context = program.get_engine().get_context();
+
+    params.meaningfulKernelsNames     = gpu_context->get_configuration().meaningful_kernels_names;
+    params.allowStaticInputReordering = program.get_options().get<build_option_type::optimize_data>()->enabled();
+    params.allowInputReordering       = false;
+    params.allowOutputReordering      = false;
+
+    const auto& tuning_option   = program.get_options().get<build_option_type::tuning_config>();
+    const auto& tuning_settings = tuning_option->config;
+    auto& tuning                = params.tuningParams;
+
+    tuning.mode          = to_tuning_mode(tuning_settings.mode);
+    tuning.cacheFilePath = tuning_settings.cache_file_path;
+}
\ No newline at end of file