From: Konrad Dobros Date: Tue, 16 Jun 2020 06:07:05 +0000 (+0200) Subject: [IE CLDNN] Add resample improvements (#933) X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=db3dff36b93d28546a8a3f7bfa463d257696f1cf;p=platform%2Fupstream%2Fdldt.git [IE CLDNN] Add resample improvements (#933) This change: - extends concat in-place optimization for resample on input - adds resample primitive int8 support for bilinear mode - fixes some potential issues with offset calculations with int8 --- diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_base.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_base.cpp index 511b61e..d44b0f3 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_base.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_base.cpp @@ -89,7 +89,8 @@ bool ResampleKernelBase::Validate(const Params& p, const optional_params& o) con const auto& input = params.inputs[0]; if ((input.GetDType() == Datatype::UINT8 || input.GetDType() == Datatype::INT8) && - params.resampleType != ResampleType::NEAREST_NEIGHBOR) + params.resampleType != ResampleType::NEAREST_NEIGHBOR && + params.resampleType != ResampleType::BILINEAR_INTERP) return false; return true; @@ -154,6 +155,8 @@ JitConstants ResampleKernelBase::GetJitConstants(const resample_params& params) } } + jit.Merge(MakeTypeJitConstants(GetAccumulatorType(params), "ACCUMULATOR")); + return jit; } @@ -178,4 +181,26 @@ KernelsData ResampleKernelBase::GetCommonKernelsData(const Params& params, const return {kd}; } + +Datatype ResampleKernelBase::GetAccumulatorType(const resample_params& params) const { + auto in_dt = params.inputs[0].GetDType(); + auto out_dt = params.output.GetDType(); + + if (params.resampleType == ResampleType::NEAREST_NEIGHBOR) + return in_dt; + + auto smaller_fp_type = [](const
Datatype& current, const Datatype& candidate) -> Datatype { + if (candidate != Datatype::F32 && candidate != Datatype::F16) + return current; + + return BytesPerElement(candidate) < BytesPerElement(current) ? candidate : current; + }; + + Datatype fp_type = Datatype::F32; + fp_type = smaller_fp_type(fp_type, in_dt); + fp_type = smaller_fp_type(fp_type, out_dt); + + return fp_type; +} + } // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_base.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_base.h index 8ebbd0d..f2a3c31 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_base.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_base.h @@ -58,5 +58,6 @@ protected: virtual JitConstants GetJitConstants(const resample_params& params) const; KernelsData GetCommonKernelsData(const Params& params, const optional_params& options) const; size_t GetFeatureBlockSize(const resample_params& params) const; + virtual Datatype GetAccumulatorType(const resample_params& params) const; }; } // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_opt.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_opt.cpp index 66648f5..9a74a61 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_opt.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_opt.cpp @@ -106,8 +106,8 @@ JitConstants ResampleKernelOpt::GetJitConstants(const resample_params &params) c jit.AddConstant(MakeJitConstant("VEC_SIZE", vec_size)); if (!params.fused_ops.empty()) { - std::vector idx_order = {"b", "feature_num", "y", "(x + out_x)"}; -
FusedOpsConfiguration conf = {"", idx_order, "res", params.inputs[0].GetDType(), vec_size, LoadType::LT_ALIGNED_READ}; + std::vector idx_order = {"b", "feature_block", "y", "(x + out_x)"}; + FusedOpsConfiguration conf = {"", idx_order, "res", GetAccumulatorType(params), vec_size, LoadType::LT_ALIGNED_READ}; conf.SetVectorAxis(Tensor::DataChannelName::FEATURE); jit.Merge(MakeFusedOpsJitConstants(params, {conf})); } diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_ref.cpp index c85b044..5f3a423 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/resample/resample_kernel_ref.cpp @@ -115,7 +115,7 @@ JitConstants ResampleKernelRef::GetJitConstants(const resample_params& params) c idx_order = {"batch", "OF_ID", "oz", "oy", "ox"}; } - FusedOpsConfiguration conf = {"", idx_order, "interp_val", params.inputs[0].GetDType(), 1}; + FusedOpsConfiguration conf = {"", idx_order, "interp_val", GetAccumulatorType(params), 1}; jit.Merge(MakeFusedOpsJitConstants(params, {conf})); } diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/resample_opt.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/resample_opt.cl index db71924..a870286 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/resample_opt.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/resample_opt.cl @@ -15,17 +15,18 @@ #include "include/common.cl" #include "include/data_types.cl" #include "include/include_all.cl" -#include "include/unit_type.cl" #define unroll_for __attribute__((opencl_unroll_hint)) for -#ifdef INPUT0_LAYOUT_FS_B_YX_FSV32 - #define READ_FUNC(ptr, offset) CAT(UNIT_BLOCK_READ, VEC_SIZE)(ptr, offset) - #define 
WRITE_FUNC(ptr, offset, val) CAT(UNIT_BLOCK_WRITE, VEC_SIZE)(ptr, offset, val) -#else - #define READ_FUNC(ptr, offset) UNIT_BLOCK_READ(ptr, offset) - #define WRITE_FUNC(ptr, offset, val) UNIT_BLOCK_WRITE(ptr, offset, val) -#endif +#define READ_FUNC(ptr, offset) BLOCK_READN(INPUT0_TYPE, VEC_SIZE, ptr, offset) +#define WRITE_FUNC(ptr, offset, val) BLOCK_WRITEN(OUTPUT_TYPE, VEC_SIZE, ptr, offset, val) + +#define IN_VEC_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, VEC_SIZE) +#define TO_IN_VEC_TYPE(x) CAT(convert_, IN_VEC_TYPE)(x) +#define ACC_VEC_TYPE MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, VEC_SIZE) +#define TO_ACC_VEC_TYPE(x) CAT(convert_, ACC_VEC_TYPE)(x) +#define OUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_SIZE) +#define TO_OUT_VEC_TYPE(x) CAT(convert_, OUT_VEC_TYPE)(x) __attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) KERNEL (resample_opt)(__global INPUT0_TYPE* input, @@ -41,11 +42,10 @@ KERNEL (resample_opt)(__global INPUT0_TYPE* input, const int f_block = get_group_id(1); const int b = get_global_id(2); const int feature_num = f_block * FEATURE_SLICE_SIZE + get_sub_group_local_id(); -#ifdef INPUT0_LAYOUT_FS_B_YX_FSV32 - typedef MAKE_VECTOR_TYPE(UNIT_TYPE, VEC_SIZE) unit_t; -#else - typedef UNIT_TYPE unit_t; -#endif + const uint feature_block = f_block * FEATURE_SLICE_SIZE; + + typedef IN_VEC_TYPE in_vec_t; + typedef ACC_VEC_TYPE acc_vec_t; if (feature_num >= OUTPUT_FEATURE_NUM) return; @@ -55,46 +55,36 @@ KERNEL (resample_opt)(__global INPUT0_TYPE* input, const int ix = floor((x + out_x) * X_RATIO); const int iy = floor(y * Y_RATIO); - unit_t res = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_num, iy, ix)); + in_vec_t res = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, iy, ix)); #else - const UNIT_TYPE ix = TO_UNIT_TYPE(X_RATIO) * (x + out_x); - const UNIT_TYPE iy = TO_UNIT_TYPE(Y_RATIO) * y; + const ACCUMULATOR_TYPE ix = TO_ACCUMULATOR_TYPE(X_RATIO) * (x + out_x); + const ACCUMULATOR_TYPE iy = TO_ACCUMULATOR_TYPE(Y_RATIO) * y; const int top_y_index = 
(int)(floor(iy)); - const int bottom_y_index = (int)(min(ceil(iy), TO_UNIT_TYPE(INPUT0_SIZE_Y) - 1)); + const int bottom_y_index = min((int)ceil(iy), INPUT0_SIZE_Y - 1); const int left_x_index = (int)(floor(ix)); - const int right_x_index = (int)(min(ceil(ix), TO_UNIT_TYPE(INPUT0_SIZE_X) - 1)); + const int right_x_index = min((int)ceil(ix), INPUT0_SIZE_X - 1); - const UNIT_TYPE dx = ix - left_x_index; - const UNIT_TYPE dy = iy - top_y_index; + const ACCUMULATOR_TYPE dx = ix - left_x_index; + const ACCUMULATOR_TYPE dy = iy - top_y_index; - const unit_t top_left = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_num, top_y_index, left_x_index)); - const unit_t top_right = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_num, top_y_index, right_x_index)); - const unit_t bottom_left = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_num, bottom_y_index, left_x_index)); - const unit_t bottom_right = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_num, bottom_y_index, right_x_index)); + const in_vec_t top_left = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, top_y_index, left_x_index)); + const in_vec_t top_right = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, top_y_index, right_x_index)); + const in_vec_t bottom_left = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, bottom_y_index, left_x_index)); + const in_vec_t bottom_right = READ_FUNC(input, INPUT0_GET_INDEX(b, feature_block, bottom_y_index, right_x_index)); - const unit_t top = top_left + (top_right - top_left) * dx; - const unit_t bottom = bottom_left + (bottom_right - bottom_left) * dx; - unit_t res = top + (bottom - top) * dy; + const acc_vec_t top = TO_ACC_VEC_TYPE(top_left) + (TO_ACC_VEC_TYPE(top_right) - TO_ACC_VEC_TYPE(top_left)) * dx; + const acc_vec_t bottom = TO_ACC_VEC_TYPE(bottom_left) + (TO_ACC_VEC_TYPE(bottom_right) - TO_ACC_VEC_TYPE(bottom_left)) * dx; + acc_vec_t res = top + (bottom - top) * dy; #endif #if HAS_FUSED_OPS FUSED_OPS; - res = FUSED_OPS_RESULT; + OUT_VEC_TYPE out = FUSED_OPS_RESULT; 
#else - res = ACTIVATION(res, ACTIVATION_PARAMS); + OUT_VEC_TYPE out = TO_OUT_VEC_TYPE(ACTIVATION(res, ACTIVATION_PARAMS)); #endif -#if OUTPUT_IS_FP - WRITE_FUNC(output, OUTPUT_GET_INDEX(b, feature_num, y, (x + out_x)), res); -#else -#if VEC_SIZE > 1 - for (uint i = 0; i < VEC_SIZE; i++) - output[OUTPUT_GET_INDEX(b, feature_num + i*SUB_GROUP_SIZE, y, (x + out_x))] = res[i]; -#else - output[OUTPUT_GET_INDEX(b, feature_num, y, (x + out_x))] = res; -#endif - -#endif + WRITE_FUNC(output, OUTPUT_GET_INDEX(b, feature_block, y, (x + out_x)), out); } } diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/resample_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/resample_ref.cl index c965d04..a7372ed 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/resample_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/resample_ref.cl @@ -39,7 +39,7 @@ inline uint FUNC(get_output_index)(uint b, uint f, uint z, uint y, uint x) } -#define TRIANGLE_COEFF(x) (INPUT0_MAX_FUNC(INPUT0_VAL_ZERO, INPUT0_VAL_ONE - INPUT0_ABS_FUNC(x))) +#define TRIANGLE_COEFF(x) (ACCUMULATOR_MAX_FUNC(ACCUMULATOR_VAL_ZERO, ACCUMULATOR_VAL_ONE - ACCUMULATOR_ABS_FUNC(x))) #define unroll_for __attribute__((opencl_unroll_hint)) for KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input, @@ -54,10 +54,15 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input, typedef MAKE_VECTOR_TYPE(OUTPUT_TYPE, PACK_SIZE) out_pack_t; const int ox = get_global_id(0); - const int oy = get_global_id(1) % OUTPUT_SIZE_Y; - const int oz = get_global_id(1) / OUTPUT_SIZE_Y; - const int feature = (get_global_id(2) * PACK_SIZE) % OUTPUT_FEATURE_NUM; - const int batch = (get_global_id(2) * PACK_SIZE) / OUTPUT_FEATURE_NUM; +#if OUTPUT_DIMS <= 4 + const int oy = get_global_id(1); + const int oz = 0; +#else + const int oy = (int)get_global_id(1) % OUTPUT_SIZE_Y; + const int oz = (int)get_global_id(1) / OUTPUT_SIZE_Y; +#endif + const int 
feature = ((int)get_global_id(2) * PACK_SIZE) % OUTPUT_FEATURE_NUM; + const int batch = ((int)get_global_id(2) * PACK_SIZE) / OUTPUT_FEATURE_NUM; const int ix = floor(ox * X_RATIO); const int iy = floor(oy * Y_RATIO); const int iz = floor(oz * Z_RATIO); @@ -117,13 +122,13 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input, return; #endif - const int top_y_index = (int)(floor(iy)); - const int bottom_y_index = (int)(min(TO_INPUT0_TYPE(ceil(iy)), TO_INPUT0_TYPE(INPUT0_SIZE_Y) - 1)); - const int left_x_index = (int)(floor(ix)); - const int right_x_index = (int)(min(TO_INPUT0_TYPE(ceil(ix)), TO_INPUT0_TYPE(INPUT0_SIZE_X) - 1)); + const int top_y_index = (int)(floor(iy)); + const int bottom_y_index = min((int)ceil(iy), INPUT0_SIZE_Y - 1); + const int left_x_index = (int)(floor(ix)); + const int right_x_index = min((int)ceil(ix), INPUT0_SIZE_X - 1); - const INPUT0_TYPE dx = TO_INPUT0_TYPE(ix - left_x_index); - const INPUT0_TYPE dy = TO_INPUT0_TYPE(iy - top_y_index); + const ACCUMULATOR_TYPE dx = TO_ACCUMULATOR_TYPE(ix - left_x_index); + const ACCUMULATOR_TYPE dy = TO_ACCUMULATOR_TYPE(iy - top_y_index); unroll_for(int in_f = 0; in_f < OUTPUT_FEATURE_NUM; in_f++) { INPUT0_TYPE top_left = input[INPUT0_GET_INDEX(batch, in_f, top_y_index, left_x_index)]; @@ -131,17 +136,17 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input, INPUT0_TYPE bottom_left = input[INPUT0_GET_INDEX(batch, in_f, bottom_y_index, left_x_index)]; INPUT0_TYPE bottom_right = input[INPUT0_GET_INDEX(batch, in_f, bottom_y_index, right_x_index)]; - INPUT0_TYPE top = top_left + (top_right - top_left) * dx; - INPUT0_TYPE bottom = bottom_left + (bottom_right - bottom_left) * dx; + ACCUMULATOR_TYPE top = TO_ACCUMULATOR_TYPE(top_left) + (TO_ACCUMULATOR_TYPE(top_right) - TO_ACCUMULATOR_TYPE(top_left)) * dx; + ACCUMULATOR_TYPE bottom = TO_ACCUMULATOR_TYPE(bottom_left) + (TO_ACCUMULATOR_TYPE(bottom_right) - TO_ACCUMULATOR_TYPE(bottom_left)) * dx; - INPUT0_TYPE interp_val = top + (bottom - top) * dy; + 
ACCUMULATOR_TYPE interp_val = top + (bottom - top) * dy; #if HAS_FUSED_OPS #define OF_ID (in_f) FUSED_OPS; OUTPUT_TYPE res = FUSED_OPS_RESULT; #else - OUTPUT_TYPE res = ACTIVATION(interp_val, ACTIVATION_PARAMS); + OUTPUT_TYPE res = TO_OUTPUT_TYPE(ACTIVATION(interp_val, ACTIVATION_PARAMS)); #endif output[OUTPUT_GET_INDEX(batch, in_f, oy, ox)] = res; } @@ -158,32 +163,32 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input, const int oz = (int)get_global_id(2) / OUTPUT_BATCH_NUM; #endif - const INPUT0_TYPE ix = ox * X_RATIO + X_RATIO_HALF - 0.5f; - const INPUT0_TYPE iy = oy * Y_RATIO + Y_RATIO_HALF - 0.5f; - const INPUT0_TYPE iz = oz * Z_RATIO + Z_RATIO_HALF - 0.5f; + const ACCUMULATOR_TYPE ix = ox * X_RATIO + X_RATIO_HALF - 0.5f; + const ACCUMULATOR_TYPE iy = oy * Y_RATIO + Y_RATIO_HALF - 0.5f; + const ACCUMULATOR_TYPE iz = oz * Z_RATIO + Z_RATIO_HALF - 0.5f; const int ix_r = (int)ix; const int iy_r = (int)iy; const int iz_r = (int)iz; #if ANTIALIAS == 1 - const INPUT0_TYPE ax = 1.0f / X_RATIO; - const INPUT0_TYPE ay = 1.0f / Y_RATIO; - const INPUT0_TYPE az = 1.0f / Z_RATIO; + const ACCUMULATOR_TYPE ax = 1.0f / X_RATIO; + const ACCUMULATOR_TYPE ay = 1.0f / Y_RATIO; + const ACCUMULATOR_TYPE az = 1.0f / Z_RATIO; #else - const INPUT0_TYPE ax = 1.0f; - const INPUT0_TYPE ay = 1.0f; - const INPUT0_TYPE az = 1.0f; + const ACCUMULATOR_TYPE ax = 1.0f; + const ACCUMULATOR_TYPE ay = 1.0f; + const ACCUMULATOR_TYPE az = 1.0f; #endif - const int rx = (X_RATIO < 1.0f) ? 2 : (int)ceil(TO_INPUT0_TYPE(KERNEL_W) / ax); - const int ry = (Y_RATIO < 1.0f) ? 2 : (int)ceil(TO_INPUT0_TYPE(KERNEL_W) / ay); - const int rz = (Z_RATIO < 1.0f) ? 2 : (int)ceil(TO_INPUT0_TYPE(KERNEL_W) / az); + const int rx = (X_RATIO < 1.0f) ? 2 : (int)ceil(TO_ACCUMULATOR_TYPE(KERNEL_W) / ax); + const int ry = (Y_RATIO < 1.0f) ? 2 : (int)ceil(TO_ACCUMULATOR_TYPE(KERNEL_W) / ay); + const int rz = (Z_RATIO < 1.0f) ? 
2 : (int)ceil(TO_ACCUMULATOR_TYPE(KERNEL_W) / az); - INPUT0_TYPE sum[FEATURE_BLOCK_SIZE]; + ACCUMULATOR_TYPE sum[FEATURE_BLOCK_SIZE]; for (int i = 0; i < FEATURE_BLOCK_SIZE; i++) sum[i] = 0; - INPUT0_TYPE wsum = 0; + ACCUMULATOR_TYPE wsum = 0; int const y_init = max(0, iy_r - ry); int const x_init = max(0, ix_r - rx); @@ -195,13 +200,13 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input, unroll_for(int z = z_init; z < z_max; z++) { unroll_for(int y = y_init; y < y_max; y++) { unroll_for(int x = x_init; x < x_max; x++) { - INPUT0_TYPE dx = ix - x; - INPUT0_TYPE dy = iy - y; - INPUT0_TYPE dz = iz - z; + ACCUMULATOR_TYPE dx = ix - x; + ACCUMULATOR_TYPE dy = iy - y; + ACCUMULATOR_TYPE dz = iz - z; #if ANTIALIAS == 1 - INPUT0_TYPE w = ax * TRIANGLE_COEFF(ax * dx) * ay * TRIANGLE_COEFF(ay * dy) * az * triangleCoeff(az * dz); + ACCUMULATOR_TYPE w = ax * TRIANGLE_COEFF(ax * dx) * ay * TRIANGLE_COEFF(ay * dy) * az * triangleCoeff(az * dz); #else - INPUT0_TYPE w = TRIANGLE_COEFF(dx) * TRIANGLE_COEFF(dy) * TRIANGLE_COEFF(dz); + ACCUMULATOR_TYPE w = TRIANGLE_COEFF(dx) * TRIANGLE_COEFF(dy) * TRIANGLE_COEFF(dz); #endif #ifndef LEFTOVERS @@ -211,7 +216,7 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input, unroll_for(int f = 0; f < f_max; f++) { #endif if (w != 0) - sum[f] += w * input[FUNC_CALL(get_input_index)(batch, feature + f, z, y, x)]; + sum[f] += w * TO_ACCUMULATOR_TYPE(input[FUNC_CALL(get_input_index)(batch, feature + f, z, y, x)]); } wsum += w; } @@ -224,13 +229,13 @@ KERNEL (resample_gpu_ref)(__global INPUT0_TYPE* input, unroll_for (int f = 0; f < f_max; f++) { #endif - INPUT0_TYPE interp_val = (wsum == 0) ? 0 : (sum[f] / wsum); + ACCUMULATOR_TYPE interp_val = (wsum == 0) ? 
0 : (sum[f] / wsum); #if HAS_FUSED_OPS #define OF_ID (feature + f) FUSED_OPS; OUTPUT_TYPE res = FUSED_OPS_RESULT; #else - OUTPUT_TYPE res = ACTIVATION(interp_val, ACTIVATION_PARAMS); + OUTPUT_TYPE res = TO_OUTPUT_TYPE(ACTIVATION(interp_val, ACTIVATION_PARAMS)); #endif output[FUNC_CALL(get_output_index)(batch, feature + f, oz, oy, ox)] = res; } diff --git a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_buffer_fusing.cpp b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_buffer_fusing.cpp index 1e1ff1e..cabfa1d 100644 --- a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_buffer_fusing.cpp +++ b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_buffer_fusing.cpp @@ -27,6 +27,7 @@ #include "reshape_inst.h" #include "scale_inst.h" #include "depth_to_space_inst.h" +#include "resample_inst.h" #include "pass_manager.h" #include "program_helpers.h" @@ -136,7 +137,8 @@ void prepare_buffer_fusing::run(program_impl& p) { // todo: we need add padding support for all optimized kernels to remove this condition if (!input->is_type() && !input->is_type() && !input->is_type() && !input->is_type() && - !input->is_type() && !input->is_type() && !input->is_type()) + !input->is_type() && !input->is_type() && !input->is_type() && + !input->is_type()) return; // if an input is marked as network output, prevent optimizations diff --git a/inference-engine/thirdparty/clDNN/src/resample.cpp b/inference-engine/thirdparty/clDNN/src/resample.cpp index b7a82ad..ac33f7c 100644 --- a/inference-engine/thirdparty/clDNN/src/resample.cpp +++ b/inference-engine/thirdparty/clDNN/src/resample.cpp @@ -32,6 +32,10 @@ layout resample_inst::calc_output_layout(resample_node const& node) { auto input_layout = node.input().get_output_layout(); auto output_type = input_layout.data_type; + if ((input_layout.data_type == data_types::i8 || input_layout.data_type == data_types::u8) + && desc->operation_type != resample_type::nearest) { + output_type = 
data_types::f32; + } if (node.has_fused_primitives()) { output_type = node.get_fused_output_layout().data_type; } diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index 2323c3c..62ad18c 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -2332,6 +2332,16 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gemm_int8_2in_act_scale_quantize_eltwise_i8 #define CASE_RESAMPLE_FP16_9 {1, 16, 4, 5}, {1, 16, 7, 8}, data_types::f16, format::b_fs_yx_fsv16, resample_type::bilinear, data_types::f16, format::bfyx #define CASE_RESAMPLE_FP16_10 {2, 32, 4, 5}, {2, 32, 7, 8}, data_types::f16, format::fs_b_yx_fsv32, resample_type::bilinear, data_types::f16, format::bfyx +#define CASE_RESAMPLE_I8_1 {1, 16, 4, 5}, {1, 16, 2, 3}, data_types::i8, format::b_fs_yx_fsv16, resample_type::nearest, data_types::f32, format::bfyx +#define CASE_RESAMPLE_I8_2 {2, 32, 4, 5}, {2, 32, 2, 3}, data_types::i8, format::b_fs_yx_fsv16, resample_type::nearest, data_types::f32, format::bfyx +#define CASE_RESAMPLE_I8_3 {1, 16, 4, 5}, {1, 16, 2, 3}, data_types::i8, format::b_fs_yx_fsv16, resample_type::bilinear, data_types::f32, format::bfyx +#define CASE_RESAMPLE_I8_4 {2, 32, 4, 5}, {2, 32, 2, 3}, data_types::i8, format::b_fs_yx_fsv16, resample_type::bilinear, data_types::f32, format::bfyx + +#define CASE_RESAMPLE_U8_1 {1, 16, 4, 5}, {1, 16, 2, 3}, data_types::u8, format::b_fs_yx_fsv16, resample_type::nearest, data_types::f32, format::bfyx +#define CASE_RESAMPLE_U8_2 {2, 32, 4, 5}, {2, 32, 2, 3}, data_types::u8, format::b_fs_yx_fsv16, resample_type::nearest, data_types::f32, format::bfyx +#define CASE_RESAMPLE_U8_3 {1, 16, 4, 5}, {1, 16, 2, 3}, data_types::u8, format::b_fs_yx_fsv16, resample_type::bilinear, data_types::f32, format::bfyx +#define CASE_RESAMPLE_U8_4 {2, 32, 4, 5}, {2, 32, 2, 3}, 
data_types::u8, format::b_fs_yx_fsv16, resample_type::bilinear, data_types::f32, format::bfyx + class resample_quantize : public ResamplePrimitiveFusingTest {}; TEST_P(resample_quantize, basic) { auto p = GetParam(); @@ -2410,6 +2420,126 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_scale_activation, resample_test_params{ CASE_RESAMPLE_FP16_8, 2, 4 }, resample_test_params{ CASE_RESAMPLE_FP16_9, 2, 4 }, resample_test_params{ CASE_RESAMPLE_FP16_10, 2, 4 }, + + resample_test_params{ CASE_RESAMPLE_I8_1, 2, 4 }, + resample_test_params{ CASE_RESAMPLE_I8_2, 2, 4 }, + resample_test_params{ CASE_RESAMPLE_I8_3, 2, 4 }, + resample_test_params{ CASE_RESAMPLE_I8_4, 2, 4 }, + + resample_test_params{ CASE_RESAMPLE_U8_1, 2, 4 }, + resample_test_params{ CASE_RESAMPLE_U8_2, 2, 4 }, + resample_test_params{ CASE_RESAMPLE_U8_3, 2, 4 }, + resample_test_params{ CASE_RESAMPLE_U8_4, 2, 4 }, +}), ); + +class resample_quantize_concat : public ResamplePrimitiveFusingTest {}; +TEST_P(resample_quantize_concat, along_f) { + auto p = GetParam(); + create_topologies( + input_layout("input", get_input_layout(p)), + resample("resample1", "input", p.out_shape, p.in_shape.feature[0], p.type), + data("in_lo_1", get_mem(get_per_channel_layout(p), min_random, 0)), + data("in_hi_1", get_mem(get_per_channel_layout(p), 1, max_random)), + data("out_lo_1", get_mem(get_single_element_layout(p), -128)), + data("out_hi_1", get_mem(get_single_element_layout(p), 127)), + quantize("quant1", "resample1", "in_lo_1", "in_hi_1", "out_lo_1", "out_hi_1", 256, data_types::i8), + resample("resample2", "input", p.out_shape, p.in_shape.feature[0], p.type), + data("in_lo_2", get_mem(get_per_channel_layout(p), min_random, 0)), + data("in_hi_2", get_mem(get_per_channel_layout(p), 1, max_random)), + data("out_lo_2", get_mem(get_single_element_layout(p), -127)), + data("out_hi_2", get_mem(get_single_element_layout(p), 127)), + quantize("quant2", "resample2", "in_lo_2", "in_hi_2", "out_lo_2", "out_hi_2", 255, data_types::i8), + 
concatenation("concat", { "quant1", "quant2" }, cldnn::concatenation::along_f), + reorder("reorder_bfyx", "concat", cldnn::format::bfyx, p.default_type) + ); + + tolerance = 1.f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_quantize_concat, + ::testing::ValuesIn(std::vector{ + resample_test_params{ CASE_RESAMPLE_FP32_1, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_2, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_3, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_4, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_5, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_6, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_7, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_8, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_9, 3, 6 }, + + resample_test_params{ CASE_RESAMPLE_FP16_1, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_2, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_3, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_4, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_5, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_6, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_7, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_8, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_9, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_10, 3, 6 }, + + resample_test_params{ CASE_RESAMPLE_I8_3, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_I8_4, 3, 6 }, + + resample_test_params{ CASE_RESAMPLE_U8_3, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_U8_4, 3, 6 }, +}), ); + +class resample_scale_concat : public ResamplePrimitiveFusingTest {}; +TEST_P(resample_scale_concat, along_f) { + auto p = GetParam(); + create_topologies( + input_layout("input", get_input_layout(p)), + resample("resample1", "input", p.out_shape, p.in_shape.feature[0], p.type), + data("scale1_scale", get_mem(get_per_channel_layout(p), -10, 10)), + data("scale1_shift", get_mem(get_per_channel_layout(p), -10, 10)), + scale("scale1", "resample1", 
"scale1_scale", "scale1_shift"), + resample("resample2", "input", p.out_shape, p.in_shape.feature[0], p.type), + data("scale2_scale", get_mem(get_per_channel_layout(p), -10, 10)), + data("scale2_shift", get_mem(get_per_channel_layout(p), -10, 10)), + scale("scale2", "resample2", "scale2_scale", "scale2_shift"), + concatenation("concat", { "scale1", "scale2" }, cldnn::concatenation::along_f), + reorder("reorder_bfyx", "concat", cldnn::format::bfyx, p.default_type) + ); + + tolerance = 1e-5f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P(fusings_gpu, resample_scale_concat, + ::testing::ValuesIn(std::vector{ + resample_test_params{ CASE_RESAMPLE_FP32_1, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_2, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_3, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_4, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_5, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_6, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_7, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_8, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP32_9, 3, 6 }, + + resample_test_params{ CASE_RESAMPLE_FP16_1, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_2, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_3, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_4, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_5, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_6, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_7, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_8, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_9, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_FP16_10, 3, 6 }, + + resample_test_params{ CASE_RESAMPLE_I8_1, 3, 6}, + resample_test_params{ CASE_RESAMPLE_I8_2, 3, 6}, + resample_test_params{ CASE_RESAMPLE_I8_3, 3, 6}, + resample_test_params{ CASE_RESAMPLE_I8_4, 3, 6}, + + resample_test_params{ CASE_RESAMPLE_U8_1, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_U8_2, 3, 6 }, + resample_test_params{ CASE_RESAMPLE_U8_3, 3, 
6 }, + resample_test_params{ CASE_RESAMPLE_U8_4, 3, 6 }, }), ); /* ----------------------------------------------------------------------------------------------------- */ diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/resample_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/resample_gpu_test.cpp index 8e0b382..dfd7091 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/resample_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/resample_gpu_test.cpp @@ -530,20 +530,21 @@ struct resample_random_test_params { tensor output_size; uint32_t num_filter; resample_type operation_type; + uint32_t align_corners; format::type in_format; format::type out_format; }; struct resample_random_test : testing::TestWithParam{ template - void fill_random_typed(memory& mem, int min, int max) { + void fill_random_typed(memory& mem, int min, int max, int k) { auto size = mem.get_layout().size; size_t b = size.batch[0]; size_t f = size.feature[0]; size_t x = size.spatial[0]; size_t y = size.spatial[1]; - auto data = generate_random_4d(b, f, y, x, min, max); + auto data = generate_random_4d(b, f, y, x, min, max, k); auto ptr = mem.pointer(); for (size_t bi = 0; bi < b; ++bi) { for (size_t fi = 0; fi < f; ++fi) { @@ -562,16 +563,16 @@ struct resample_random_test : testing::TestWithParam(mem, -127, 127); + fill_random_typed(mem, -127, 127, 2); break; case data_types::f16: - fill_random_typed(mem, -127, 127); + fill_random_typed(mem, -127, 127, 2); break; case data_types::i8: - fill_random_typed(mem, -127, 127); + fill_random_typed(mem, -127, 127, 1); break; case data_types::u8: - fill_random_typed(mem, 0, 255); + fill_random_typed(mem, 0, 255, 1); break; default: break; @@ -579,14 +580,16 @@ struct resample_random_test : testing::TestWithParam - void compare_nearest_typed(const memory& input, const memory& output) { + void compare_nearest_typed(const memory& input, const memory& output, uint32_t align_corners) { auto 
output_lay = output.get_layout(); size_t b = output_lay.size.batch[0]; size_t f = output_lay.size.feature[0]; size_t x = output_lay.size.spatial[0]; size_t y = output_lay.size.spatial[1]; - float x_ratio = static_cast(input.get_layout().size.spatial[0]) / static_cast(x); - float y_ratio = static_cast(input.get_layout().size.spatial[1]) / static_cast(y); + size_t in_x = input.get_layout().size.spatial[0]; + size_t in_y = input.get_layout().size.spatial[1]; + float x_ratio = x > align_corners ? static_cast(in_x - align_corners) / static_cast(x - align_corners) : 0.f; + float y_ratio = y > align_corners ? static_cast(in_y - align_corners) / static_cast(y - align_corners) : 0.f; auto in_ptr = input.pointer(); auto out_ptr = output.pointer(); @@ -609,17 +612,88 @@ struct resample_random_test : testing::TestWithParam + void compare_bilinear_typed(const memory& input, const memory& output, uint32_t align_corners) { + auto output_lay = output.get_layout(); + size_t b = output_lay.size.batch[0]; + size_t f = output_lay.size.feature[0]; + size_t x = output_lay.size.spatial[0]; + size_t y = output_lay.size.spatial[1]; + auto input_lay = input.get_layout(); + size_t in_x = input_lay.size.spatial[0]; + size_t in_y = input_lay.size.spatial[1]; + float x_ratio = x > align_corners ? static_cast(in_x - align_corners) / static_cast(x - align_corners) : 0.f; + float y_ratio = y > align_corners ? 
static_cast(in_y - align_corners) / static_cast(y - align_corners) : 0.f; + + auto in_ptr = input.pointer(); + auto out_ptr = output.pointer(); + for (size_t bi = 0; bi < b; ++bi) { + for (size_t fi = 0; fi < f; ++fi) { + for (size_t yi = 0; yi < y; ++yi) { + for (size_t xi = 0; xi < x; ++xi) { + auto low_in_xi = static_cast(floor(x_ratio * xi)); + auto low_in_yi = static_cast(floor(y_ratio * yi)); + auto high_in_xi = static_cast(ceil(x_ratio * xi)); + auto high_in_yi = static_cast(ceil(y_ratio * yi)); + + high_in_xi = std::min(high_in_xi, static_cast(in_x - 1)); + high_in_yi = std::min(high_in_yi, static_cast(in_y - 1)); + + auto dx = x_ratio * xi - static_cast(low_in_xi); + auto dy = y_ratio * yi - static_cast(low_in_yi); + + auto top_left_coords = tensor(batch(bi), feature(fi), spatial(low_in_xi, low_in_yi, 0, 0)); + auto top_right_coords = tensor(batch(bi), feature(fi), spatial(high_in_xi, low_in_yi, 0, 0)); + auto bottom_left_coords = tensor(batch(bi), feature(fi), spatial(low_in_xi, high_in_yi, 0, 0)); + auto bottom_right_coords = tensor(batch(bi), feature(fi), spatial(high_in_xi, high_in_yi, 0, 0)); + + auto top_left_val = in_ptr[input_lay.get_linear_offset(top_left_coords)]; + auto top_right_val = in_ptr[input_lay.get_linear_offset(top_right_coords)]; + auto bottom_left_val = in_ptr[input_lay.get_linear_offset(bottom_left_coords)]; + auto bottom_right_val = in_ptr[input_lay.get_linear_offset(bottom_right_coords)]; + + auto top_val = static_cast(top_left_val) + + (static_cast(top_right_val) - static_cast(top_left_val)) * dx; + auto bottom_val = static_cast(bottom_left_val) + + (static_cast(bottom_right_val) - static_cast(bottom_left_val)) * dx; + + auto final_val = top_val + (bottom_val - top_val) * dy; + + auto output_coords = tensor(batch(bi), feature(fi), spatial(xi, yi, 0, 0)); + auto output_val = out_ptr[output_lay.get_linear_offset(output_coords)]; + + EXPECT_NEAR(static_cast(output_val), final_val, 1.e-1f) + << " at bi=" << bi << ", fi=" << fi << ", 
xi=" << xi << ", yi=" << yi; + } + } + } + } + } + + void compare(const memory& input, const memory& output, resample_type operation, uint32_t align_corners) { + auto dt = input.get_layout().data_type; if (operation == resample_type::nearest) { + // Nearest resampling implicitly ignores align_corners + if (dt == data_types::f32) { + compare_nearest_typed(input, output, 0); + } else if (dt == data_types::f16) { + compare_nearest_typed(input, output, 0); + } else if (dt == data_types::i8) { + compare_nearest_typed(input, output, 0); + } else if (dt == data_types::u8) { + compare_nearest_typed(input, output, 0); + } else { + FAIL() << "Not supported data type: " << static_cast(dt); + } + } else if (operation == resample_type::bilinear) { if (dt == data_types::f32) { - compare_nearest_typed(input, output); + compare_bilinear_typed(input, output, align_corners); } else if (dt == data_types::f16) { - compare_nearest_typed(input, output); + compare_bilinear_typed(input, output, align_corners); } else if (dt == data_types::i8) { - compare_nearest_typed(input, output); + compare_bilinear_typed(input, output, align_corners); } else if (dt == data_types::u8) { - compare_nearest_typed(input, output); + compare_bilinear_typed(input, output, align_corners); } else { FAIL() << "Not supported data type: " << static_cast(dt); } @@ -633,10 +707,11 @@ struct resample_random_test : testing::TestWithParam