From 3c99c13febb619974a33b1be542a8497cd9ccb0c Mon Sep 17 00:00:00 2001 From: Vladimir Paramuzov Date: Mon, 27 Jul 2020 11:52:18 +0300 Subject: [PATCH] [IE CLDNN] Improvements for SpaceToDepth (#1454) --- inference-engine/src/cldnn_engine/cldnn_engine.cpp | 9 +- inference-engine/src/cldnn_engine/cldnn_graph.cpp | 3 +- .../src/cldnn_engine/cldnn_program.cpp | 2 +- .../single_layer_tests/space_to_depth_tests.cpp | 18 --- .../space_to_depth/space_to_depth_kernel_ref.cpp | 147 +++++++++++++-------- .../space_to_depth/space_to_depth_kernel_ref.h | 15 ++- .../core/cl_kernels/space_to_depth_ref.cl | 55 ++++++-- .../clDNN/src/gpu/space_to_depth_gpu.cpp | 24 +++- .../graph_optimizer/prepare_primitive_fusing.cpp | 11 +- .../thirdparty/clDNN/src/space_to_depth.cpp | 39 +++++- .../clDNN/tests/test_cases/fusings_gpu_test.cpp | 140 ++++++++++++++++++++ .../tests/test_cases/space_to_depth_gpu_test.cpp | 109 +++++++++++++++ 12 files changed, 470 insertions(+), 102 deletions(-) delete mode 100644 inference-engine/tests_deprecated/functional/cldnn/shared_tests_instance/single_layer_tests/space_to_depth_tests.cpp diff --git a/inference-engine/src/cldnn_engine/cldnn_engine.cpp b/inference-engine/src/cldnn_engine/cldnn_engine.cpp index 01d10c0..a189b3c 100644 --- a/inference-engine/src/cldnn_engine/cldnn_engine.cpp +++ b/inference-engine/src/cldnn_engine/cldnn_engine.cpp @@ -73,13 +73,18 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneNetwork(const InferenceEngin std::shared_ptr clonedNetwork = cloneNetwork(network); if (clonedNetwork->getFunction()) { const auto transformations_callback = [](const std::shared_ptr &node) -> bool { - // DepthToSpace node implementation supports only equal input/output tensors with rank <= 5 // Reshape->Permute->Reshape pattern in theory can change output rank, so this check is added to be sure - // that DepthToSpace impl will handle fused case + // that the following primitives will be handled correctly + // DepthToSpace node implementation 
supports only equal input/output tensors with rank <= 5 if (auto dtsOp = std::dynamic_pointer_cast(node)) { return dtsOp->input_value(0).get_shape().size() <= 5lu && dtsOp->input_value(0).get_shape().size() == dtsOp->get_output_shape(0).size(); } + // SpaceToDepth node implementation supports only equal input/output tensors with rank <= 5 + if (auto stdOp = std::dynamic_pointer_cast(node)) { + return stdOp->input_value(0).get_shape().size() <= 5lu && stdOp->input_value(0).get_shape().size() == stdOp->get_output_shape(0).size(); + } + return std::dynamic_pointer_cast(node) || std::dynamic_pointer_cast(node) || std::dynamic_pointer_cast(node) || diff --git a/inference-engine/src/cldnn_engine/cldnn_graph.cpp b/inference-engine/src/cldnn_engine/cldnn_graph.cpp index dfc6512..10c23f2 100644 --- a/inference-engine/src/cldnn_engine/cldnn_graph.cpp +++ b/inference-engine/src/cldnn_engine/cldnn_graph.cpp @@ -186,7 +186,8 @@ InferenceEngine::ICNNNetwork::Ptr CLDNNGraph::GetExecGraphInfoByPrimitivesInfo(s { "reduce_l1", "ReduceL1" }, { "reduce_l2", "ReduceL2" }, { "reduce_log_sum", "ReduceLogSum" }, - { "reduce_log_sum_exp", "ReduceLogSumExp" } + { "reduce_log_sum_exp", "ReduceLogSumExp" }, + { "space_to_depth", "SpaceToDepth" }, }; if (type_n2l.find(cldnn_name) != type_n2l.end()) diff --git a/inference-engine/src/cldnn_engine/cldnn_program.cpp b/inference-engine/src/cldnn_engine/cldnn_program.cpp index ec846cf..45fad3a 100644 --- a/inference-engine/src/cldnn_engine/cldnn_program.cpp +++ b/inference-engine/src/cldnn_engine/cldnn_program.cpp @@ -3979,7 +3979,7 @@ void Program::CreateSpaceToDepthPrimitive(cldnn::topology& topology, InferenceEn auto spaceToDepth = as (layer); size_t blockSize = static_cast(spaceToDepth->GetParamAsUInt("block_size", 1)); - std::string modeAsString = spaceToDepth->GetParamAsString("depth_mode", "blocks_first"); + std::string modeAsString = spaceToDepth->GetParamAsString("mode", "blocks_first"); cldnn::space_to_depth::depth_mode mode; mode = 
(modeAsString == "blocks_first") ? cldnn::space_to_depth::blocks_first : cldnn::space_to_depth::depth_first; diff --git a/inference-engine/tests_deprecated/functional/cldnn/shared_tests_instance/single_layer_tests/space_to_depth_tests.cpp b/inference-engine/tests_deprecated/functional/cldnn/shared_tests_instance/single_layer_tests/space_to_depth_tests.cpp deleted file mode 100644 index a827f16..0000000 --- a/inference-engine/tests_deprecated/functional/cldnn/shared_tests_instance/single_layer_tests/space_to_depth_tests.cpp +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (C) 2018-2020 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "space_to_depth_tests.hpp" - -INSTANTIATE_TEST_CASE_P( - smoke_GPU_TestsSpaceToDepth, SpaceToDepthTests, - ::testing::Values( - space_to_depth_test_params{ "GPU", "FP32", { 1, 1, 6, 4 }, "blocks_first", 2, { 1, 4, 3, 2 } }, - space_to_depth_test_params{ "GPU", "FP32", { 1, 1, 9, 9 }, "blocks_first", 3, { 1, 9, 3, 3 } }, - space_to_depth_test_params{ "GPU", "FP32", { 1, 2, 9, 9 }, "blocks_first", 3, { 1, 18, 3, 3 } }, - space_to_depth_test_params{ "GPU", "FP32", { 1, 10, 4096, 1024 }, "blocks_first", 4, { 1, 160, 1024, 256 } }, - space_to_depth_test_params{ "GPU", "FP32", { 1, 1, 6, 4 }, "depth_first", 2, { 1, 4, 3, 2 } }, - space_to_depth_test_params{ "GPU", "FP32", { 1, 1, 9, 9 }, "depth_first", 3, { 1, 9, 3, 3 } }, - space_to_depth_test_params{ "GPU", "FP32", { 1, 2, 9, 9 }, "depth_first", 3, { 1, 18, 3, 3 } }, - space_to_depth_test_params{ "GPU", "FP32", { 1, 10, 4096, 1024 }, "depth_first", 4, { 1, 160, 1024, 256 } } -)); diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/space_to_depth/space_to_depth_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/space_to_depth/space_to_depth_kernel_ref.cpp index 66b7243..8a0b228 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/space_to_depth/space_to_depth_kernel_ref.cpp 
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/space_to_depth/space_to_depth_kernel_ref.cpp @@ -20,72 +20,109 @@ #include namespace kernel_selector { - ParamsKey SpaceToDepthKernelRef::GetSupportedKey() const { - ParamsKey k; - k.EnableInputDataType(Datatype::F16); - k.EnableInputDataType(Datatype::F32); - k.EnableOutputDataType(Datatype::F16); - k.EnableOutputDataType(Datatype::F32); - k.EnableAllInputLayout(); - k.EnableAllOutputLayout(); - k.EnableTensorOffset(); - k.EnableTensorPitches(); - k.EnableBatching(); - return k; +ParamsKey SpaceToDepthKernelRef::GetSupportedKey() const { + ParamsKey k; + k.EnableInputDataType(Datatype::INT8); + k.EnableInputDataType(Datatype::UINT8); + k.EnableInputDataType(Datatype::F16); + k.EnableInputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::INT8); + k.EnableOutputDataType(Datatype::UINT8); + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + k.EnableAllInputLayout(); + k.EnableAllOutputLayout(); + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + k.EnableDifferentTypes(); + return k; +} + +bool SpaceToDepthKernelRef::Validate(const Params& p, const optional_params& o) const { + if (p.GetType() != KernelType::SPACE_TO_DEPTH || + o.GetType() != KernelType::SPACE_TO_DEPTH) { + return false; } - CommonDispatchData SpaceToDepthKernelRef::SetDefault(const space_to_depth_params& params, - const optional_params&) const { - CommonDispatchData runInfo; - - std::vector global = {params.output.Batch().v, - params.output.Feature().v, - params.output.Y().v * params.output.X().v}; - - auto local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo); - - runInfo.gws0 = global[0]; - runInfo.gws1 = global[1]; - runInfo.gws2 = global[2]; - - runInfo.lws0 = local[0]; - runInfo.lws1 = local[1]; - runInfo.lws2 = local[2]; - - return runInfo; + const space_to_depth_params& params = static_cast(p); + for (auto& fused_op : params.fused_ops) { + if 
(!IsFusedPrimitiveSupported(fused_op)) + return false; } - JitConstants SpaceToDepthKernelRef::GetJitConstants(const space_to_depth_params& params) const { - JitConstants jit = MakeBaseParamsJitConstants(params); + if (params.inputs[0].Dimentions() > 5) + return false; + + return true; +} + +CommonDispatchData SpaceToDepthKernelRef::SetDefault(const space_to_depth_params& params, + const optional_params&) const { + CommonDispatchData runInfo; + + std::vector global = {params.output.Batch().v, + params.output.Feature().v, + params.output.Z().v * params.output.Y().v * params.output.X().v}; + + auto local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo); + + runInfo.gws0 = global[0]; + runInfo.gws1 = global[1]; + runInfo.gws2 = global[2]; + + runInfo.lws0 = local[0]; + runInfo.lws1 = local[1]; + runInfo.lws2 = local[2]; + + return runInfo; +} + +JitConstants SpaceToDepthKernelRef::GetJitConstants(const space_to_depth_params& params) const { + JitConstants jit = MakeBaseParamsJitConstants(params); + + jit.AddConstant(MakeJitConstant("BLOCK_SIZE", params.block_size)); + if (params.depth_mode == SpaceToDepthMode::BLOCKS_FIRST) + jit.AddConstant(MakeJitConstant("BLOCKS_FIRST_MODE", true)); + else + jit.AddConstant(MakeJitConstant("DEPTH_FIRST_MODE", true)); + + auto input = params.inputs[0]; + auto input_dt = input.GetDType(); + if (!params.fused_ops.empty()) { + std::vector idx_order; + if (input.Dimentions() == 5) { + idx_order = {"batch", "feature", "z", "y", "x"}; + } else if (input.Dimentions() == 4) { + idx_order = {"batch", "feature", "y", "x"}; + } + FusedOpsConfiguration conf = {"", idx_order, "in_val", input_dt, 1}; + jit.Merge(MakeFusedOpsJitConstants(params, {conf})); + } - const size_t block_size = params.block_size; - const size_t squared_block_size = params.block_size * params.block_size; - const size_t blocks_first_mode = (size_t)params.depth_mode; + return jit; +} - jit.AddConstant(MakeJitConstant("BLOCK_SIZE", block_size)); - 
jit.AddConstant(MakeJitConstant("SQUARED_BLOCK_SIZE", squared_block_size)); - jit.AddConstant(MakeJitConstant("BLOCKS_FIRST_MODE", blocks_first_mode)); +KernelsData SpaceToDepthKernelRef::GetKernelsData(const Params& params, const optional_params& options) const { + KernelData kd = KernelData::Default(params); + space_to_depth_params& newParams = *static_cast(kd.params.get()); - return jit; + if (!Validate(params, options)) { + return {}; } - KernelsData SpaceToDepthKernelRef::GetKernelsData(const Params& params, const optional_params& options) const { - KernelData kd = KernelData::Default(params); - space_to_depth_params& newParams = *static_cast(kd.params.get()); - - assert(params.GetType() == KernelType::SPACE_TO_DEPTH); + auto runInfo = SetDefault(newParams, options); + auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options); + auto cldnn_jit = GetJitConstants(newParams); + std::string jit = CreateJit(kernelName, cldnn_jit, entry_point); - auto runInfo = SetDefault(newParams, options); - auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options); - auto cldnn_jit = GetJitConstants(newParams); - std::string jit = CreateJit(kernelName, cldnn_jit, entry_point); + auto& kernel = kd.kernels[0]; - auto& kernel = kd.kernels[0]; + FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point, + DEFAULT, false, false, 1, GetFusedPrimitiveInputsCount(params)); - FillCLKernelData(kernel, runInfo, params.engineInfo, kernelName, jit, entry_point); + kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE; - kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE; - - return {kd}; - } + return {kd}; +} } // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/space_to_depth/space_to_depth_kernel_ref.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/space_to_depth/space_to_depth_kernel_ref.h index b704593..29af4c5 100644 --- 
a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/space_to_depth/space_to_depth_kernel_ref.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/space_to_depth/space_to_depth_kernel_ref.h @@ -43,10 +43,19 @@ struct space_to_depth_optional_params : optional_params { class SpaceToDepthKernelRef : public common_kernel_base { public: SpaceToDepthKernelRef() : common_kernel_base("space_to_depth_ref") {} - virtual ~SpaceToDepthKernelRef() {} - virtual JitConstants GetJitConstants(const space_to_depth_params& params) const; - virtual CommonDispatchData SetDefault(const space_to_depth_params& params, const optional_params&) const; + virtual ~SpaceToDepthKernelRef() = default; KernelsData GetKernelsData(const Params& params, const optional_params& options) const override; ParamsKey GetSupportedKey() const override; + +protected: + virtual CommonDispatchData SetDefault(const space_to_depth_params& params, const optional_params&) const; + virtual JitConstants GetJitConstants(const space_to_depth_params& params) const; + virtual bool Validate(const Params& p, const optional_params& o) const; + std::vector GetSupportedFusedOps() const override { + return { FusedOpType::ELTWISE, + FusedOpType::QUANTIZE, + FusedOpType::SCALE, + FusedOpType::ACTIVATION }; + } }; } // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/space_to_depth_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/space_to_depth_ref.cl index a91475a..028db55 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/space_to_depth_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/space_to_depth_ref.cl @@ -14,22 +14,59 @@ #include "include/include_all.cl" -KERNEL(space_to_depth_ref)(const __global UNIT_TYPE* input, __global UNIT_TYPE* output) +#if OUTPUT_DIMS == 5 +#define SPATIAL_BLOCK_SIZE (BLOCK_SIZE*BLOCK_SIZE*BLOCK_SIZE) +#else +#define 
SPATIAL_BLOCK_SIZE (BLOCK_SIZE*BLOCK_SIZE) +#endif + +KERNEL(space_to_depth_ref)(const __global INPUT0_TYPE* input, + __global OUTPUT_TYPE* output +#if HAS_FUSED_OPS_DECLS + , FUSED_OPS_DECLS +#endif +) { const uint batch = get_global_id(0); const uint feature = get_global_id(1); + +#if OUTPUT_DIMS == 5 + const uint z = ((uint)get_global_id(2) / OUTPUT_SIZE_X) / OUTPUT_SIZE_Y; + const uint y = ((uint)get_global_id(2) / OUTPUT_SIZE_X) % OUTPUT_SIZE_Y; + const uint x = (uint)get_global_id(2) % OUTPUT_SIZE_X; +#else + const uint z = 0; const uint y = (uint)get_global_id(2) / OUTPUT_SIZE_X; const uint x = (uint)get_global_id(2) % OUTPUT_SIZE_X; +#endif + +#if BLOCKS_FIRST_MODE + const uint input_offset = feature / INPUT0_FEATURE_NUM; + const uint input_feature = feature % INPUT0_FEATURE_NUM; +#else + const uint input_offset = feature % SPATIAL_BLOCK_SIZE; + const uint input_feature = feature / SPATIAL_BLOCK_SIZE; +#endif - const uint input_offset = BLOCKS_FIRST_MODE * (feature / INPUT0_FEATURE_NUM) + (!BLOCKS_FIRST_MODE) * (feature % SQUARED_BLOCK_SIZE); +#if OUTPUT_DIMS == 5 + const uint input_z = (z * BLOCK_SIZE) + ((input_offset / BLOCK_SIZE) / BLOCK_SIZE); + const uint input_y = (y * BLOCK_SIZE) + ((input_offset / BLOCK_SIZE) % BLOCK_SIZE); + const uint input_x = (x * BLOCK_SIZE) + (input_offset % BLOCK_SIZE); + const uint input_index = INPUT0_GET_INDEX(batch, input_feature, input_z, input_y, input_x); + const uint output_index = OUTPUT_GET_INDEX(batch, feature, z, y, x); +#else + const uint input_z = 0; const uint input_y = (y * BLOCK_SIZE) + (input_offset / BLOCK_SIZE); const uint input_x = (x * BLOCK_SIZE) + (input_offset % BLOCK_SIZE); + const uint input_index = INPUT0_GET_INDEX(batch, input_feature, input_y, input_x); + const uint output_index = OUTPUT_GET_INDEX(batch, feature, y, x); +#endif - const uint input_feature = BLOCKS_FIRST_MODE * (feature % INPUT0_FEATURE_NUM) + (!BLOCKS_FIRST_MODE) * (feature / SQUARED_BLOCK_SIZE); - const uint input_feature_offset 
= (input_y * INPUT0_Y_PITCH) + input_x; - - const uint input_index = INPUT0_OFFSET + (batch * INPUT0_BATCH_PITCH) + (input_feature * INPUT0_FEATURE_PITCH) + input_feature_offset; - const uint output_index = OUTPUT_OFFSET + (batch * OUTPUT_BATCH_PITCH) + (feature * OUTPUT_FEATURE_PITCH) + (y * OUTPUT_Y_PITCH) + x; - - output[output_index] = ACTIVATION(input[input_index], ACTIVATION_PARAMS); + INPUT0_TYPE in_val = input[input_index]; +#if HAS_FUSED_OPS + FUSED_OPS; + output[output_index] = FUSED_OPS_RESULT; +#else + output[output_index] = ACTIVATION(in_val, ACTIVATION_PARAMS); +#endif } diff --git a/inference-engine/thirdparty/clDNN/src/gpu/space_to_depth_gpu.cpp b/inference-engine/thirdparty/clDNN/src/gpu/space_to_depth_gpu.cpp index 98e5d53..b4fea78 100644 --- a/inference-engine/thirdparty/clDNN/src/gpu/space_to_depth_gpu.cpp +++ b/inference-engine/thirdparty/clDNN/src/gpu/space_to_depth_gpu.cpp @@ -60,10 +60,26 @@ namespace detail { attach_space_to_depth_gpu::attach_space_to_depth_gpu() { auto val_fw = space_to_depth_gpu::create; - implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), - val_fw); - implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), - val_fw); + + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfzyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfzyx), val_fw); + + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfyx), val_fw); + 
implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), val_fw); + + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_yx_fsv16), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_yx_fsv16), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv16), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv16), val_fw); + + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_yx_fsv4), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_yx_fsv4), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv4), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv4), val_fw); } } // namespace detail diff --git a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp index a609cc5..1915300 100644 --- a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp +++ b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp @@ -42,6 +42,7 @@ #include "scale_inst.h" #include "resample_inst.h" #include "depth_to_space_inst.h" +#include "space_to_depth_inst.h" #include "gather_inst.h" #include "reverse_sequence_inst.h" #include "shuffle_channels_inst.h" @@ -375,6 +376,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); + if (!should_fuse) return; @@ -420,6 +423,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); + if 
(!should_fuse) return; @@ -496,6 +501,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); + should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); + if (!should_fuse) return; @@ -517,12 +524,12 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { bool can_fuse_parent1 = (parent1->is_type() && conv_supports_fusings(parent1->as())) || (parent1->is_type() && mvn_supports_fusings(parent1->as())) || (parent1->is_type()) || (parent1->is_type()) || - (parent1->is_type()) || (parent1->is_type()); + (parent1->is_type()) || (parent1->is_type()) || (parent1->is_type()); bool can_fuse_parent2 = (parent2->is_type() && conv_supports_fusings(parent2->as())) || (parent2->is_type() && mvn_supports_fusings(parent2->as())) || (parent2->is_type()) || (parent2->is_type()) || - (parent1->is_type()) || (parent2->is_type()); + (parent2->is_type()) || (parent2->is_type()) || (parent2->is_type()); std::vector can_fuse_parents = { can_fuse_parent1, can_fuse_parent2 }; diff --git a/inference-engine/thirdparty/clDNN/src/space_to_depth.cpp b/inference-engine/thirdparty/clDNN/src/space_to_depth.cpp index c2593d0..287f2cf 100644 --- a/inference-engine/thirdparty/clDNN/src/space_to_depth.cpp +++ b/inference-engine/thirdparty/clDNN/src/space_to_depth.cpp @@ -36,6 +36,11 @@ layout space_to_depth_inst::calc_output_layout(space_to_depth_node const& node) const size_t block_size = desc->block_size; auto depth_mode = desc->mode; + auto output_type = input_layout.data_type; + if (node.has_fused_primitives()) { + output_type = node.get_fused_output_layout().data_type; + } + if (depth_mode != space_to_depth::depth_first && depth_mode != space_to_depth::blocks_first) CLDNN_ERROR_MESSAGE(node.id(), "Invalid mode for spaceToDepth: must be \"blocks_first\" or \"depth_first\" only"); @@ -52,14 +57,34 @@ layout 
space_to_depth_inst::calc_output_layout(space_to_depth_node const& node) std::to_string(input_layout.size.spatial[0]) + ", " + std::to_string(input_layout.size.spatial[1]) + " (x, y). Actual block size is " + std::to_string(block_size)); - const size_t feature = input_layout.size.feature[0] * block_size * block_size; - const size_t y = input_layout.size.spatial[1] / block_size; - const size_t x = input_layout.size.spatial[0] / block_size; - return layout{ - input_layout.data_type, - input_format, - tensor(TensorValue(input_layout.size.batch[0]), TensorValue(feature), TensorValue(x), TensorValue(y))}; + if (input_layout.format.dimension() == 5) { + if (input_layout.size.spatial[2] % block_size != 0) + CLDNN_ERROR_MESSAGE( + node.id(), + "Sizes of spatials z must be divisible by block size. Actual spatial sizes are " + + std::to_string(input_layout.size.spatial[2]) + + " (z). Block size is " + std::to_string(block_size)); + + const size_t feature = input_layout.size.feature[0] * block_size * block_size * block_size; + const size_t z = input_layout.size.spatial[2] / block_size; + const size_t y = input_layout.size.spatial[1] / block_size; + const size_t x = input_layout.size.spatial[0] / block_size; + + return layout{ + output_type, + input_format, + tensor(TensorValue(input_layout.size.batch[0]), TensorValue(feature), TensorValue(x), TensorValue(y), TensorValue(z))}; + } else { + const size_t feature = input_layout.size.feature[0] * block_size * block_size; + const size_t y = input_layout.size.spatial[1] / block_size; + const size_t x = input_layout.size.spatial[0] / block_size; + + return layout{ + output_type, + input_format, + tensor(TensorValue(input_layout.size.batch[0]), TensorValue(feature), TensorValue(x), TensorValue(y))}; + } } std::string space_to_depth_inst::to_string(space_to_depth_node const& node) { diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp 
b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index 25a0978..b48e04d 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -35,6 +35,7 @@ #include "api/permute.hpp" #include "api/gather.hpp" #include "api/depth_to_space.hpp" +#include "api/space_to_depth.hpp" #include "test_utils/test_utils.h" @@ -4696,6 +4697,145 @@ INSTANTIATE_TEST_CASE_P( }), ); /* ----------------------------------------------------------------------------------------------------- */ +/* -------------------------------- SpaceToDepth cases ------------------------------------------------- */ +/* ----------------------------------------------------------------------------------------------------- */ +struct space_to_depth_params { + tensor input_size; + tensor output_size; + space_to_depth::depth_mode mode; + data_types input_type; + format input_format; + size_t block_size; + data_types default_type; + format default_format; + size_t expected_fused_primitives; + size_t expected_not_fused_primitives; +}; + +#define CASE_SPACE_TO_DEPTH_F32_1 {2, 2, 8, 10}, {2, 8, 4, 5}, space_to_depth::depth_mode::blocks_first, data_types::f32, format::bfyx, 2, data_types::f32, format::bfyx +#define CASE_SPACE_TO_DEPTH_F32_2 {1, 2, 6, 6, 6}, {1, 54, 2, 2, 2}, space_to_depth::depth_mode::depth_first, data_types::f32, format::bfzyx, 3, data_types::f32, format::bfyx +#define CASE_SPACE_TO_DEPTH_F16_1 {1, 3, 6, 6}, {1, 12, 3, 3}, space_to_depth::depth_mode::blocks_first, data_types::f16, format::bfyx, 2, data_types::f32, format::bfyx +#define CASE_SPACE_TO_DEPTH_F16_2 {2, 1, 3, 3}, {2, 9, 1, 1}, space_to_depth::depth_mode::blocks_first, data_types::f16, format::b_fs_yx_fsv16, 3, data_types::f32, format::bfyx +#define CASE_SPACE_TO_DEPTH_U8_1 {2, 2, 8, 10}, {2, 8, 4, 5}, space_to_depth::depth_mode::blocks_first, data_types::u8, format::bfyx, 2, data_types::f32, format::bfyx 
+#define CASE_SPACE_TO_DEPTH_U8_2 {1, 2, 6, 6, 6}, {1, 54, 2, 2, 2}, space_to_depth::depth_mode::depth_first, data_types::u8, format::bfzyx, 3, data_types::f32, format::bfyx +#define CASE_SPACE_TO_DEPTH_I8_1 {1, 3, 6, 6}, {1, 12, 3, 3}, space_to_depth::depth_mode::blocks_first, data_types::i8, format::bfyx, 2, data_types::f32, format::bfyx +#define CASE_SPACE_TO_DEPTH_I8_2 {2, 1, 3, 3}, {2, 9, 1, 1}, space_to_depth::depth_mode::blocks_first, data_types::i8, format::b_fs_yx_fsv16, 3, data_types::f32, format::bfyx + +class SpaceToDepthFusingsTest : public ::BaseFusingTest { +public: + void execute(space_to_depth_params& p) { + auto input_prim = get_mem(get_input_layout(p)); + + network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused); + network network_fused(this->engine, this->topology_fused, bo_fused); + + network_fused.set_input_data("input", input_prim); + network_not_fused.set_input_data("input", input_prim); + + compare(network_not_fused, network_fused, p); + } + + layout get_input_layout(space_to_depth_params& p) { return layout{p.input_type, p.input_format, p.input_size}; } + + layout get_per_channel_layout(space_to_depth_params& p) { + return layout{p.default_type, p.default_format, tensor{1, p.output_size.feature[0], 1, 1}}; + } + format get_input_format(space_to_depth_params &p) { return p.input_format; } +}; + +class space_to_depth_quantize_i8 : public SpaceToDepthFusingsTest {}; +TEST_P(space_to_depth_quantize_i8, basic) { + auto p = GetParam(); + create_topologies(input_layout("input", get_input_layout(p)), + space_to_depth("space_to_depth", "input", p.mode, p.block_size), + data("in_low", get_mem(get_per_channel_layout(p), min_random, 0)), + data("in_high", get_mem(get_per_channel_layout(p), 1, max_random)), + data("out_low", get_mem(get_single_element_layout(p), -128)), + data("out_high", get_mem(get_single_element_layout(p), 127)), + quantize("quant", "space_to_depth", "in_low", "in_high", "out_low", "out_high", 256, 
data_types::i8), + reorder("reorder_bfyx", "quant", format::bfyx, data_types::f32)); + + tolerance = 1.f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P( + fusings_gpu, + space_to_depth_quantize_i8, + ::testing::ValuesIn(std::vector{ + space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_1, 2, 3}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_2, 2, 3}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_1, 2, 3}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_2, 2, 3}, + }), ); + +class space_to_depth_scale_act_eltwise_quantize_u8 : public SpaceToDepthFusingsTest {}; +TEST_P(space_to_depth_scale_act_eltwise_quantize_u8, basic) { + auto p = GetParam(); + create_topologies(input_layout("input", get_input_layout(p)), + space_to_depth("space_to_depth", "input", p.mode, p.block_size), + data("scale1_data", get_mem(get_per_channel_layout(p), -0.125f)), + scale("scale1", "space_to_depth", "scale1_data"), + activation("actv1", "scale1", activation_func::relu), + data("eltw_data", get_mem(layout(p.default_type, p.input_format, p.output_size))), + eltwise("eltw", {"actv1", "eltw_data"}, eltwise_mode::sum, p.default_type), + data("in_low", get_mem(get_per_channel_layout(p), min_random, 0)), + data("in_high", get_mem(get_per_channel_layout(p), 1, max_random)), + data("out_low", get_mem(get_single_element_layout(p), 0)), + data("out_high", get_mem(get_single_element_layout(p), 255)), + quantize("quant", "eltw", "in_low", "in_high", "out_low", "out_high", 256, data_types::u8), + reorder("reorder_bfyx", "quant", format::bfyx, data_types::f32)); + + tolerance = 1.f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P( + fusings_gpu, + space_to_depth_scale_act_eltwise_quantize_u8, + ::testing::ValuesIn(std::vector{ + space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_1, 2, 6}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_2, 2, 6}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_1, 2, 6}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_2, 2, 6}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_U8_1, 2, 6}, + 
space_to_depth_params{CASE_SPACE_TO_DEPTH_U8_2, 2, 6}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_I8_1, 2, 6}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_I8_2, 2, 6}, + }), ); + + +class space_to_depth_scale_act_eltw : public SpaceToDepthFusingsTest {}; +TEST_P(space_to_depth_scale_act_eltw, basic) { + auto p = GetParam(); + create_topologies(input_layout("input", get_input_layout(p)), + space_to_depth("space_to_depth", "input", p.mode, p.block_size), + data("scale1_data", get_mem(get_per_channel_layout(p), -0.125f)), + scale("scale1", "space_to_depth", "scale1_data"), + activation("actv1", "scale1", activation_func::relu), + data("eltw_data", get_mem(layout(p.default_type, p.input_format, p.output_size))), + eltwise("eltw", {"actv1", "eltw_data"}, eltwise_mode::sum, p.default_type), + reorder("reorder_bfyx", "eltw", format::bfyx, data_types::f32)); + + tolerance = 1e-5f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P( + fusings_gpu, + space_to_depth_scale_act_eltw, + ::testing::ValuesIn(std::vector{ + space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_1, 2, 5}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_F32_2, 2, 5}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_1, 2, 5}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_F16_2, 2, 5}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_U8_1, 2, 5}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_U8_2, 2, 5}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_I8_1, 2, 5}, + space_to_depth_params{CASE_SPACE_TO_DEPTH_I8_2, 2, 5}, + }), ); + +/* ----------------------------------------------------------------------------------------------------- */ /* ------------------------------------------ Gather cases --------------------------------------------- */ /* ----------------------------------------------------------------------------------------------------- */ struct gather_test_params { diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/space_to_depth_gpu_test.cpp 
b/inference-engine/thirdparty/clDNN/tests/test_cases/space_to_depth_gpu_test.cpp index 7d8007e..5b6935b 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/space_to_depth_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/space_to_depth_gpu_test.cpp @@ -811,3 +811,112 @@ TEST(space_to_depth_fp32_gpu, d1199_bs3_mdf) { } } +TEST(space_to_depth_fp32_gpu, d1199_bs3_mdf_fsv16) { + // Input : 1x1x9x9 (allocated as bfyx, reordered to b_fs_yx_fsv16 before space_to_depth) + // Block size : 3, mode : depth_first + // Output : 1x9x3x3 (reordered back to bfyx/f32 before comparison) + // Input values in fp32 + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f32, format::bfyx, { 1, 1, 9, 9 } }); + size_t block_size = 3; + + set_values(input1, { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, + 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, + 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, + 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, + 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, + 60.0f, 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f, + 70.0f, 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f, + 80.0f + }); + + topology topology; + topology.add(input_layout("Input0", input1.get_layout())); + topology.add(reorder("reorder", "Input0", format::b_fs_yx_fsv16, data_types::f32)); + topology.add(space_to_depth("space_to_depth", "reorder", space_to_depth::depth_first, block_size)); + topology.add(reorder("reorder_out", "space_to_depth", format::bfyx, data_types::f32)); + + network network(engine, topology); + + network.set_input_data("Input0", input1); + + auto outputs = network.execute(); + + auto output = outputs.at("reorder_out").get_memory(); + auto output_ptr = output.pointer<float>(); + + std::vector<float> expected_results = { + 0.0f, 3.0f, 6.0f, 27.0f, 30.0f, 33.0f, 54.0f, 57.0f, 60.0f, 1.0f, + 4.0f, 7.0f, 28.0f, 31.0f, 34.0f, 55.0f, 58.0f, 61.0f, 2.0f, 5.0f, +
8.0f, 29.0f, 32.0f, 35.0f, 56.0f, 59.0f, 62.0f, 9.0f, 12.0f, 15.0f, + 36.0f, 39.0f, 42.0f, 63.0f, 66.0f, 69.0f, 10.0f, 13.0f, 16.0f, 37.0f, + 40.0f, 43.0f, 64.0f, 67.0f, 70.0f, 11.0f, 14.0f, 17.0f, 38.0f, 41.0f, + 44.0f, 65.0f, 68.0f, 71.0f, 18.0f, 21.0f, 24.0f, 45.0f, 48.0f, 51.0f, + 72.0f, 75.0f, 78.0f, 19.0f, 22.0f, 25.0f, 46.0f, 49.0f, 52.0f, 73.0f, + 76.0f, 79.0f, 20.0f, 23.0f, 26.0f, 47.0f, 50.0f, 53.0f, 74.0f, 77.0f, + 80.0f + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], output_ptr[i]); + } +} + +TEST(space_to_depth_fp32_gpu, d1199_bs3_mdf_fsv4) { + // Input : 1x1x9x9 (allocated as bfyx, reordered to b_fs_yx_fsv4 before space_to_depth) + // Block size : 3, mode : depth_first + // Output : 1x9x3x3 (reordered back to bfyx/f32 before comparison) + // Input values in fp32 + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f32, format::bfyx, { 1, 1, 9, 9 } }); + size_t block_size = 3; + + set_values(input1, { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, + 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, + 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, + 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, + 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, + 60.0f, 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f, 67.0f, 68.0f, 69.0f, + 70.0f, 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, 79.0f, + 80.0f + }); + + topology topology; + topology.add(input_layout("Input0", input1.get_layout())); + topology.add(reorder("reorder", "Input0", format::b_fs_yx_fsv4, data_types::f32)); + topology.add(space_to_depth("space_to_depth", "reorder", space_to_depth::depth_first, block_size)); + topology.add(reorder("reorder_out", "space_to_depth", format::bfyx, data_types::f32)); + + network network(engine, topology); + + network.set_input_data("Input0", input1); + + auto outputs = network.execute(); + + auto output = outputs.at("reorder_out").get_memory(); + auto
output_ptr = output.pointer<float>(); + + std::vector<float> expected_results = { + 0.0f, 3.0f, 6.0f, 27.0f, 30.0f, 33.0f, 54.0f, 57.0f, 60.0f, 1.0f, + 4.0f, 7.0f, 28.0f, 31.0f, 34.0f, 55.0f, 58.0f, 61.0f, 2.0f, 5.0f, + 8.0f, 29.0f, 32.0f, 35.0f, 56.0f, 59.0f, 62.0f, 9.0f, 12.0f, 15.0f, + 36.0f, 39.0f, 42.0f, 63.0f, 66.0f, 69.0f, 10.0f, 13.0f, 16.0f, 37.0f, + 40.0f, 43.0f, 64.0f, 67.0f, 70.0f, 11.0f, 14.0f, 17.0f, 38.0f, 41.0f, + 44.0f, 65.0f, 68.0f, 71.0f, 18.0f, 21.0f, 24.0f, 45.0f, 48.0f, 51.0f, + 72.0f, 75.0f, 78.0f, 19.0f, 22.0f, 25.0f, 46.0f, 49.0f, 52.0f, 73.0f, + 76.0f, 79.0f, 20.0f, 23.0f, 26.0f, 47.0f, 50.0f, 53.0f, 74.0f, 77.0f, + 80.0f + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], output_ptr[i]); + } +} -- 2.7.4