[IE CLDNN] Add b_fs_fsv16 concat optimizations (#1452)
authorKonrad Dobros <konrad.dobros@intel.com>
Mon, 27 Jul 2020 11:49:22 +0000 (13:49 +0200)
committerGitHub <noreply@github.com>
Mon, 27 Jul 2020 11:49:22 +0000 (14:49 +0300)
1. Add fsv16 int8 support to optimized kernel
2. Optimize fsv16 concat kernel
3. Add graph optimization to improve concat alignment

Issue: CVS-28494

inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_b_fs_yx_fsv16.cpp
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_b_fs_yx_fsv16.h
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_base.cpp
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/concatenation/concatenation_kernel_base.h
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/concatenation_gpu_blocked.cl
inference-engine/thirdparty/clDNN/src/graph_optimizer/concat_input_order.cpp [new file with mode: 0644]
inference-engine/thirdparty/clDNN/src/include/pass_manager.h
inference-engine/thirdparty/clDNN/src/program.cpp
inference-engine/thirdparty/clDNN/tests/test_cases/concatenation_gpu_test.cpp

index 8251a42..57fc050 100644 (file)
 #include "kernel_selector_utils.h"
 
 namespace kernel_selector {
+
+namespace {
+
+size_t getTileXY(const concatenation_params& params) {
+    auto& input = params.inputs[0];
+    size_t tileXY =  1;
+    if (params.isAligned) {
+        switch (input.GetDType()) {
+        case Datatype::F16:
+        case Datatype::INT8:
+        case Datatype::UINT8:
+            tileXY = 4;
+            break;
+        default:
+            return 1;
+        }
+    } else {
+        switch (input.GetDType()) {
+        case Datatype::F32:
+            tileXY = 2;
+            break;
+        case Datatype::F16:
+            tileXY = 4;
+            break;
+        case Datatype::INT8:
+        case Datatype::UINT8:
+            tileXY = 8;
+            break;
+        default:
+            return 1;
+        }
+    }
+
+    auto tileXYMultiple = input.X().v;
+    bool noInputPad = input.X().pad.Total() == 0;
+    bool noOutputPad = params.output.X().pad.Total() == 0;
+    if (noInputPad && noOutputPad)
+        tileXYMultiple = input.X().v * input.Y().v;
+
+    while (tileXYMultiple % tileXY != 0)
+        tileXY /= 2;
+
+    return tileXY;
+}
+
+}  // namespace
+
 ParamsKey ConcatenationKernel_b_fs_yx_fsv16::GetSupportedKey() const {
     ParamsKey k;
     k.EnableInputDataType(Datatype::F16);
     k.EnableOutputDataType(Datatype::F16);
     k.EnableInputDataType(Datatype::F32);
     k.EnableOutputDataType(Datatype::F32);
+    k.EnableInputDataType(Datatype::INT8);
+    k.EnableOutputDataType(Datatype::INT8);
+    k.EnableInputDataType(Datatype::UINT8);
+    k.EnableOutputDataType(Datatype::UINT8);
     k.EnableInputLayout(DataLayout::b_fs_yx_fsv16);
     k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16);
     k.EnableTensorOffset();
@@ -60,10 +111,13 @@ bool ConcatenationKernel_b_fs_yx_fsv16::Validate(const Params& p, const optional
 ConcatenationKernelBase::DispatchData ConcatenationKernel_b_fs_yx_fsv16::SetDefault(const concatenation_params& params) const {
     DispatchData runInfo = ConcatenationKernelBase::SetDefault(params);
     const auto& input = params.inputs[0];
+    auto tileXY = getTileXY(params);
 
-    runInfo.gws0 = input.Batch().v;
-    runInfo.gws1 = Align(input.Feature().v, 16);
-    runInfo.gws2 = input.X().v * input.Y().v;
+    size_t tileF = params.misalignment == 0 ? 1 : 2;
+
+    runInfo.gws0 = CeilDiv(input.X().v * input.Y().v, tileXY);
+    runInfo.gws1 = Align(input.Feature().v, 16 * tileF) / tileF;
+    runInfo.gws2 = input.Batch().v;
 
     runInfo.lws0 = 1;
     runInfo.lws1 = 16;
@@ -77,7 +131,9 @@ ConcatenationKernelBase::DispatchData ConcatenationKernel_b_fs_yx_fsv16::SetDefa
 JitConstants ConcatenationKernel_b_fs_yx_fsv16::GetJitConstants(const concatenation_params& params) const {
     JitConstants jit = MakeBaseParamsJitConstants(params);
 
-    jit.AddConstant(MakeJitConstant("ALIGNED", params.isAligned));
+    jit.AddConstant(MakeJitConstant("ALIGNED", params.misalignment == 0));
+    jit.AddConstant(MakeJitConstant("MISALIGNMENT", params.misalignment));
+    jit.AddConstant(MakeJitConstant("TILE_XY", getTileXY(params)));
 
     return jit;
 }
@@ -85,4 +141,8 @@ JitConstants ConcatenationKernel_b_fs_yx_fsv16::GetJitConstants(const concatenat
 KernelsData ConcatenationKernel_b_fs_yx_fsv16::GetKernelsData(const Params& params, const optional_params& optParams) const {
     return GetCommonKernelsData(params, optParams);
 }
+
+size_t ConcatenationKernel_b_fs_yx_fsv16::GetAlignment(const concatenation_params& /*params*/) const {
+    return 16;
+}
 }  // namespace kernel_selector
index 9bf2af8..cf8e3f9 100644 (file)
@@ -28,5 +28,6 @@ public:
     DispatchData SetDefault(const concatenation_params& params) const override;
     JitConstants GetJitConstants(const concatenation_params& params) const override;
     bool Validate(const Params& p, const optional_params& o) const override;
+    size_t GetAlignment(const concatenation_params& params) const override;
 };
 }  // namespace kernel_selector
index d76d5e1..0eb3fb2 100644 (file)
@@ -115,7 +115,8 @@ KernelsData ConcatenationKernelBase::GetCommonKernelsData(const Params& params,
         newParams.inputs.resize(1);
         newParams.inputs[0] = input;
         size_t ifm = input.Feature().v;
-        newParams.isAligned = ifm_offset % 16 == 0 && ifm % 16 == 0;
+        newParams.isAligned = ifm_offset % GetAlignment(newParams) == 0;
+        newParams.misalignment = ifm_offset % GetAlignment(newParams);
         ifm_offset += ifm;
 
         auto& kernel = kd.kernels[i];
@@ -127,7 +128,7 @@ KernelsData ConcatenationKernelBase::GetCommonKernelsData(const Params& params,
         kernel.workGroups.global = {runInfo.gws0, runInfo.gws1, runInfo.gws2};
         kernel.workGroups.local = {runInfo.lws0, runInfo.lws1, runInfo.lws2};
         kernel.kernelString = GetKernelString(kernelName, jit, entryPoint, params.engineInfo);
-        kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, (uint32_t)i});
+        kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, (uint32_t)i });
         kernel.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0});
 
         ScalarDescriptor s;
index 1239506..645bc99 100644 (file)
@@ -26,6 +26,7 @@ struct concatenation_params : public base_params {
 
     ConcatAxis axis = ConcatAxis::FEATURE;
     bool isAligned = true;
+    size_t misalignment = 0;
 
     virtual ParamsKey GetParamsKey() const {
         auto k = base_params::GetParamsKey();
@@ -71,5 +72,8 @@ protected:
     KernelsData GetCommonKernelsData(const Params& params, const optional_params&) const;
     int32_t GetConcatChannelIndex(const concatenation_params& params) const;
     Tensor::DataChannelName GetConcatChannel(const concatenation_params& params) const;
+    virtual size_t GetAlignment(const concatenation_params& /*params*/) const {
+        return 1;
+    }
 };
 }  // namespace kernel_selector
index 0182baa..0fdbe4d 100644 (file)
 
 
 #include "include/fetch.cl"
-#include "include/unit_type.cl"
+#include "include/data_types.cl"
 
 #define WORK_GROUP_SIZE 16
 #define IC_BLOCK 16
 
+#define INPUT_VEC_TYPE                          MAKE_VECTOR_TYPE(INPUT0_TYPE, TILE_XY)
+#define OUTPUT_VEC_TYPE                         MAKE_VECTOR_TYPE(OUTPUT_TYPE, TILE_XY)
+#define TO_OUTPUT_VEC_TYPE(x)                   CAT(convert_, OUTPUT_VEC_TYPE)(x)
+#define INPUT_BLOCK_READ(ptr, offset)           MAKE_VECTOR_TYPE(DT_INPUT_BLOCK_READ, TILE_XY)(ptr, offset)
+#define OUTPUT_BLOCK_WRITE(ptr, offset, val)    MAKE_VECTOR_TYPE(DT_OUTPUT_BLOCK_WRITE, TILE_XY)(ptr, offset, val)
+
+#if !ALIGNED
+// For non-aligned case process two features together to mitigate misalignment
+#   define TILE_F 2
+#else
+#   define TILE_F 1
+#endif
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
 __attribute__((reqd_work_group_size(1, WORK_GROUP_SIZE, 1)))
 __attribute__((intel_reqd_sub_group_size(WORK_GROUP_SIZE)))
-KERNEL (concatenation_gpu_blocked)(__global UNIT_TYPE* input, __global UNIT_TYPE* output, uint output_offset_in_concat_axis)
+KERNEL (concatenation_gpu_blocked)(
+    __global INPUT0_TYPE* input,
+    __global OUTPUT_TYPE* output,
+    uint output_offset_in_concat_axis)
 {
-    const int b = get_global_id(0);
-    const int f_block = get_group_id(1);
-    const int xy = get_global_id(2);
+    const int xy = (uint)get_global_id(0) * TILE_XY;
+    const int f_block = (uint)get_group_id(1) * TILE_F;
+    const int b = get_group_id(2);
     const int lid = get_sub_group_local_id();
 
     const int x = xy % OUTPUT_SIZE_X;
     const int y = xy / OUTPUT_SIZE_X;
 
+    const uint input_offset = INPUT0_GET_INDEX(b, f_block*IC_BLOCK, y, x);
 
 #if ALIGNED
-    const uint input_offset = INPUT0_GET_INDEX(b, f_block*IC_BLOCK, y, x);
+    INPUT_VEC_TYPE src = INPUT_BLOCK_READ(input, input_offset);
     const uint dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + output_offset_in_concat_axis), y, x);
 
-    UNIT_TYPE src = UNIT_BLOCK_READ(input, input_offset);
-    src = ACTIVATION(src, ACTIVATION_PARAMS);
-    UNIT_BLOCK_WRITE(output, dst_index, src);
+    bool do_block_write = (INPUT0_FEATURE_NUM % IC_BLOCK == 0)
+                        || (f_block * IC_BLOCK + IC_BLOCK <= INPUT0_FEATURE_NUM);
+
+    if (do_block_write) {
+        OUTPUT_VEC_TYPE res = TO_OUTPUT_VEC_TYPE(ACTIVATION(src, ACTIVATION_PARAMS));
+        OUTPUT_BLOCK_WRITE(output, dst_index, res);
+    } else {
+        if (lid < INPUT0_FEATURE_NUM % IC_BLOCK) {
+            __attribute__((opencl_unroll_hint))
+            for (uint tx = 0; tx < TILE_XY; ++tx) {
+                OUTPUT_TYPE res = TO_OUTPUT_TYPE(ACTIVATION(((INPUT0_TYPE*)&src)[tx], ACTIVATION_PARAMS));
+                output[dst_index + tx * IC_BLOCK + lid] = res;
+            }
+        }
+    }
 #else
-    if (f_block*IC_BLOCK + lid >= INPUT0_FEATURE_NUM)
-        return;
 
-    const uint input_offset = INPUT0_GET_INDEX(b, f_block*IC_BLOCK + lid, y, x);
-    const uint dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + lid + output_offset_in_concat_axis), y, x);
+#if TILE_F != 1
+    bool full_write = (INPUT0_FEATURE_NUM % (IC_BLOCK * TILE_F) == 0) || (f_block * IC_BLOCK + TILE_F * IC_BLOCK <= INPUT0_FEATURE_NUM);
+    if (full_write) {
+        INPUT_VEC_TYPE src0 = INPUT_BLOCK_READ(input, input_offset + 0 * INPUT0_FEATURE_PITCH * IC_BLOCK);
+        INPUT_VEC_TYPE src1 = INPUT_BLOCK_READ(input, input_offset + 1 * INPUT0_FEATURE_PITCH * IC_BLOCK);
+    #if TILE_F == 4
+        INPUT_VEC_TYPE src2 = INPUT_BLOCK_READ(input, input_offset + 2 * INPUT0_FEATURE_PITCH * IC_BLOCK);
+        INPUT_VEC_TYPE src3 = INPUT_BLOCK_READ(input, input_offset + 3 * INPUT0_FEATURE_PITCH * IC_BLOCK);
+    #endif
 
-    UNIT_TYPE src = input[input_offset];
-    src = ACTIVATION(src, ACTIVATION_PARAMS);
-    output[dst_index] = src;
+        uint dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + (IC_BLOCK - MISALIGNMENT) + output_offset_in_concat_axis), y, x);
+
+        INPUT_VEC_TYPE src_al0 = 0;
+    #if TILE_F == 4
+        INPUT_VEC_TYPE src_al1 = 0;
+        INPUT_VEC_TYPE src_al2 = 0;
+    #endif
+        __attribute__((opencl_unroll_hint))
+        for (uint tx = 0; tx < TILE_XY; ++tx) {
+            ((INPUT0_TYPE*)&src_al0)[tx] = intel_sub_group_shuffle_down(((INPUT0_TYPE*)&src0)[tx], ((INPUT0_TYPE*)&src1)[tx], (IC_BLOCK - MISALIGNMENT));
+    #if TILE_F == 4
+            ((INPUT0_TYPE*)&src_al1)[tx] = intel_sub_group_shuffle_down(((INPUT0_TYPE*)&src1)[tx], ((INPUT0_TYPE*)&src2)[tx], (IC_BLOCK - MISALIGNMENT));
+            ((INPUT0_TYPE*)&src_al2)[tx] = intel_sub_group_shuffle_down(((INPUT0_TYPE*)&src2)[tx], ((INPUT0_TYPE*)&src3)[tx], (IC_BLOCK - MISALIGNMENT));
+    #endif
+        }
+        OUTPUT_VEC_TYPE res_al0 = TO_OUTPUT_VEC_TYPE(ACTIVATION(src_al0, ACTIVATION_PARAMS));
+        OUTPUT_BLOCK_WRITE(output, dst_index, res_al0);
+    #if TILE_F == 4
+        OUTPUT_VEC_TYPE res_al1 = TO_OUTPUT_VEC_TYPE(ACTIVATION(src_al1, ACTIVATION_PARAMS));
+        OUTPUT_BLOCK_WRITE(output, dst_index + 1 * OUTPUT_FEATURE_PITCH * IC_BLOCK, res_al1);
+        OUTPUT_VEC_TYPE res_al2 = TO_OUTPUT_VEC_TYPE(ACTIVATION(src_al2, ACTIVATION_PARAMS));
+        OUTPUT_BLOCK_WRITE(output, dst_index + 2 * OUTPUT_FEATURE_PITCH * IC_BLOCK, res_al2);
+    #endif
+        uint lid_f_offset = lid;
+        INPUT_VEC_TYPE src_unal = 0;
+
+        lid_f_offset += lid < (IC_BLOCK - MISALIGNMENT) ? 0 : IC_BLOCK * (TILE_F - 1);
+    #if TILE_F == 2
+        src_unal = lid < (IC_BLOCK - MISALIGNMENT) ? src0 : src1;
+    #elif TILE_F == 4
+        src_unal = lid < (IC_BLOCK - MISALIGNMENT) ? src0 : src3;
+    #endif
+
+        dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + lid_f_offset + output_offset_in_concat_axis), y, x);
+        __attribute__((opencl_unroll_hint))
+        for (uint tx = 0; tx < TILE_XY; ++tx) {
+            OUTPUT_TYPE res_unal = TO_OUTPUT_TYPE(ACTIVATION(((INPUT0_TYPE*)&src_unal)[tx], ACTIVATION_PARAMS));
+            output[dst_index + tx * IC_BLOCK] = res_unal;
+        }
+    } else
+#endif  // TILE_F != 1
+    {
+        const uint dst_index = OUTPUT_GET_INDEX(b, (f_block*IC_BLOCK + lid + output_offset_in_concat_axis), y, x);
+
+        __attribute__((opencl_unroll_hint))
+        for (uint fw = 0; fw < TILE_F; ++fw) {
+            if (TILE_F != 1 && CEIL_DIV(INPUT0_FEATURE_NUM, IC_BLOCK) % TILE_F != 0 && CEIL_DIV(INPUT0_FEATURE_NUM, IC_BLOCK) % TILE_F == fw)
+                break;
+
+            bool do_leftover_write = INPUT0_FEATURE_NUM % IC_BLOCK == 0 || f_block * IC_BLOCK + fw * IC_BLOCK + lid < INPUT0_FEATURE_NUM;
+            if (do_leftover_write) {
+                __attribute__((opencl_unroll_hint))
+                for (uint tx = 0; tx < TILE_XY; ++tx) {
+                    INPUT0_TYPE src = input[input_offset + lid + tx * IC_BLOCK + fw * INPUT0_FEATURE_PITCH * IC_BLOCK];
+                    OUTPUT_TYPE res = TO_OUTPUT_TYPE(ACTIVATION(src, ACTIVATION_PARAMS));
+                    output[dst_index + tx * IC_BLOCK + fw * OUTPUT_FEATURE_PITCH * IC_BLOCK] = res;
+                }
+            }
+        }
+    }
 #endif
 }
 
 #undef WORK_GROUP_SIZE
 #undef IC_BLOCK
+
+#undef INPUT_VEC_TYPE
+#undef OUTPUT_VEC_TYPE
+#undef TO_OUTPUT_VEC_TYPE
+#undef INPUT_BLOCK_READ
+#undef OUTPUT_BLOCK_WRITE
+
+#undef TILE_F
+#undef CEIL_DIV
diff --git a/inference-engine/thirdparty/clDNN/src/graph_optimizer/concat_input_order.cpp b/inference-engine/thirdparty/clDNN/src/graph_optimizer/concat_input_order.cpp
new file mode 100644 (file)
index 0000000..8559432
--- /dev/null
@@ -0,0 +1,224 @@
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+#include "pass_manager.h"
+#include "pooling_inst.h"
+#include "convolution_inst.h"
+#include "fully_connected_inst.h"
+#include "data_inst.h"
+#include "memory_impl.h"
+#include "program_impl.h"
+
+#include <vector>
+#include <tuple>
+
+using namespace cldnn;
+
+namespace {
+
+using shuffle_range = std::pair<int32_t, int32_t>;
+
+bool can_shuffle_features(program_node& node) {
+    if (node.is_type<convolution>()) {
+        auto& conv_node = node.as<convolution>();
+        auto& wei_node = conv_node.weights();
+
+        return conv_node.get_groups() == 1 && conv_node.get_split() == 1 &&
+            conv_node.get_deformable_groups() == 1 && !conv_node.get_transposed() &&
+            !conv_node.activations_zero_points_term() &&
+            wei_node.is_type<data>() && wei_node.is_constant() && !wei_node.is_output();
+    }
+    if (node.is_type<fully_connected>()) {
+        auto& fc_node = node.as<fully_connected>();
+        auto& wei_node = fc_node.weights();
+
+        return wei_node.is_type<data>() && wei_node.is_constant() && !wei_node.is_output();
+    }
+
+    bool pass_through = false;
+    pass_through |= node.is_type<activation>();
+    pass_through |= node.is_type<pooling>();
+    // General conditions for pass-through layers
+    pass_through &= !node.is_output() && node.get_dependencies().size() == 1 && !node.has_fused_primitives();
+    if (pass_through) {
+        // Primitives that are feature order invariant, pass-through shuffled features to users
+        for (auto& user : node.get_users()) {
+            if (!can_shuffle_features(*user))
+                return false;
+        }
+        return true;
+    }
+
+    return false;
+}
+
+void shuffle_weights(data_node& node, const std::vector<shuffle_range>& ranges) {
+    // Correct for shuffled features by shuffling input feature dimension in weights.
+    // This allows to restore correct feature order on output and only changes calculation order.
+    auto wei_layout = node.get_output_layout();
+    auto& old_weights_memory = node.get_attached_memory();
+    bool need_reset = static_cast<bool>(wei_layout.data_padding) || wei_layout.format.is_blocked();
+    auto new_weights_memory = old_weights_memory.get_engine()->allocate_memory(wei_layout, old_weights_memory.get_net_id(), need_reset);
+
+    auto bytes_per_elem = data_type_traits::size_of(wei_layout.data_type);
+    auto old_ptr = static_cast<char*>(old_weights_memory.lock());
+    auto new_ptr = static_cast<char*>(new_weights_memory->lock());
+    for (int32_t ofi = 0; ofi < wei_layout.size.batch[0]; ++ofi) {
+        int32_t new_ifi = 0;
+        for (auto& range : ranges) {
+            for (int32_t ifi = range.first; ifi < range.second; ++ifi, ++new_ifi) {
+                for (int32_t wi = 0; wi < wei_layout.size.spatial[3]; ++wi) {
+                    for (int32_t zi = 0; zi < wei_layout.size.spatial[2]; ++zi) {
+                        for (int32_t yi = 0; yi < wei_layout.size.spatial[1]; ++yi) {
+                            for (int32_t xi = 0; xi < wei_layout.size.spatial[0]; ++xi) {
+                                auto old_coords = tensor(batch(ofi), feature(ifi), spatial(xi, yi, zi, wi));
+                                auto new_coords = tensor(batch(ofi), feature(new_ifi), spatial(xi, yi, zi, wi));
+                                auto old_offset = wei_layout.get_linear_offset(old_coords);
+                                auto new_offset = wei_layout.get_linear_offset(new_coords);
+                                for (size_t byte = 0; byte < bytes_per_elem; ++byte) {
+                                    new_ptr[new_offset * bytes_per_elem + byte] = old_ptr[old_offset * bytes_per_elem + byte];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+    old_weights_memory.unlock();
+    new_weights_memory->unlock();
+
+    node.attach_memory(*new_weights_memory, false);
+}
+
+void shuffle_features(program_node& node, const std::vector<shuffle_range>& ranges) {
+    if (node.is_type<convolution>()) {
+        auto& conv = node.as<convolution>();
+        shuffle_weights(conv.weights().as<data>(), ranges);
+    } else if (node.is_type<fully_connected>()) {
+        auto& fc = node.as<fully_connected>();
+        shuffle_weights(fc.weights().as<data>(), ranges);
+    } else {
+        // General case for pass-through layers
+        for (auto& user : node.get_users()) {
+            shuffle_features(*user, ranges);
+        }
+    }
+}
+
+}  // namespace
+
+void concat_input_order::run(program_impl& p) {
+    for (auto node : p.get_processing_order()) {
+        // Check that optimization can be performed:
+        // 1. Not an output
+        // 2. Concatenation along features
+        // 3. Currently only fsv16 format on input/output
+        // 4. Not already aligned
+        // 5. Users can accept shuffled features
+        // 6. No fused primitives
+        if (!node->is_type<concatenation>() || node->is_output())
+            continue;
+
+        auto& concat_node = node->as<concatenation>();
+        auto prim = concat_node.get_primitive();
+
+        bool along_f = prim->axis == concatenation::along_f;
+        size_t inputs_count = prim->input_size();
+        bool no_fusing = !concat_node.has_fused_primitives() && concat_node.get_dependencies().size() == inputs_count;
+
+        auto out_format = concat_node.get_output_layout().format;
+        bool correct_format = out_format == format::b_fs_yx_fsv16;
+        tensor::value_type alignment = 1;
+        if (out_format == format::b_fs_yx_fsv16)
+            alignment = 16;
+
+        bool single_format = true;
+        std::vector<tensor::value_type> feature_sizes;
+        feature_sizes.reserve(inputs_count);
+        for (size_t input_idx = 0; input_idx < inputs_count; ++input_idx) {
+            auto& dep = concat_node.get_dependency(input_idx);
+            auto dep_layout = dep.get_output_layout();
+            single_format &= dep_layout.format == out_format;
+            feature_sizes.push_back(dep_layout.size.feature[0]);
+        }
+        // Alignment is not optimal if aligned input follows unaligned one
+        bool already_aligned = true;
+        for (size_t i = 1; i < feature_sizes.size(); ++i) {
+            bool current_aligned = feature_sizes[i] % alignment == 0;
+            bool previous_aligned = feature_sizes[i - 1] % alignment == 0;
+            already_aligned &= previous_aligned || !current_aligned;
+        }
+        // Check that we can fuse shuffling to users
+        bool can_shuffle_users = true;
+        for (auto user : concat_node.get_users()) {
+            can_shuffle_users &= can_shuffle_features(*user);
+        }
+
+        if (!along_f || !no_fusing || !correct_format || !single_format || already_aligned || !can_shuffle_users)
+            continue;
+
+        // Perform the optimization
+        // Calculate new input order - first inputs preserving alignment, then rest
+        std::vector<size_t> new_order;
+        new_order.reserve(inputs_count);
+        for (size_t i = 0; i < feature_sizes.size(); ++i) {
+            if (feature_sizes[i] % alignment == 0)
+                new_order.push_back(i);
+        }
+        for (size_t i = 0; i < feature_sizes.size(); ++i) {
+            if (feature_sizes[i] % alignment != 0)
+                new_order.push_back(i);
+        }
+        // Calculate new ranges
+        int32_t current_offset = 0;
+        std::vector<shuffle_range> original_ranges;
+        original_ranges.reserve(inputs_count);
+        for (auto& feature_size : feature_sizes) {
+            original_ranges.emplace_back(current_offset, current_offset + feature_size);
+            current_offset += feature_size;
+        }
+        std::vector<shuffle_range> shuffled_ranges;
+        shuffled_ranges.reserve(inputs_count);
+        for (auto& ord : new_order) {
+            shuffled_ranges.push_back(original_ranges[ord]);
+        }
+        // Change input order
+        std::vector<program_node*> new_dependencies = {};
+        new_dependencies.reserve(inputs_count);
+        for (auto& ord : new_order) {
+            new_dependencies.push_back(&concat_node.get_dependency(ord));
+        }
+        // Update in place with const cast instead of replacing
+        auto& dependencies = concat_node.get_dependencies();
+        auto& mutable_dependencies = const_cast<std::vector<program_node*>&>(dependencies);
+        for (size_t i = 0; i < new_dependencies.size(); ++i) {
+            mutable_dependencies[i] = new_dependencies[i];
+        }
+        std::vector<primitive_id> new_input_ids;
+        new_input_ids.reserve(inputs_count);
+        for (auto& ord : new_order) {
+            new_input_ids.push_back(prim->input[ord]);
+        }
+        auto mutable_prim = std::const_pointer_cast<concatenation>(prim);
+        mutable_prim->input = new_input_ids;
+        // Correct users for shuffled features
+        for (auto& user : concat_node.get_users()) {
+            shuffle_features(*user, shuffled_ranges);
+        }
+    }
+}
+
index 8bd7cfe..bc620bf 100644 (file)
@@ -332,6 +332,28 @@ public:
     void run(program_impl& p) override;
 };
 
+class concat_input_order : public base_pass {
+    // This optimization changes order of inputs for concatenation to provide
+    // better alignment for execution and allow for optimizing out in some cases.
+    // For example concatenation along features with inputs [13, 1024] in format fsv16
+    // has only first input aligned to feature blocks, blocking performant implementation
+    // for second one.
+    // This can be fixed by chaning order to [1024, 13] and fusing reshuffling of those features
+    // into following layers, such as convolution or fully connected, where it can be
+    // implemented as compile-time weights shuffling.
+    //
+    // Requirements - may work incorrectly if not fullfiled:
+    // - formats are selected
+    // - implementations aren't selected
+    //
+    // Soft requirements - reduce applicability if not fullfiled:
+    // - constant primitives are reduced to data nodes
+    // - no fused primitives
+public:
+    concat_input_order() : base_pass("concat_input_order") {}
+    void run(program_impl& p) override;
+};
+
 class memory_dependency_pass : public base_pass {
 public:
     explicit memory_dependency_pass(const std::string& pass_name) : base_pass(pass_name) {}
index 2a1733c..398eeaa 100644 (file)
@@ -420,6 +420,10 @@ void program_impl::pre_optimize_graph(bool is_internal) {
         apply_opt_pass<prepare_primitive_fusing>(lo);
 
         apply_opt_pass<reorder_inputs>(lo, rf);
+        // Ideally this should be done before fusing to simplify logic and make the pass more powerful,
+        // but after format selection to select correct alignment.
+        // Unfortunately those passes currently happen in reverse order.
+        apply_opt_pass<concat_input_order>();
 
         // TODO this code should be moved to post compilation after kernel selector will support handling reorder bias
         apply_opt_pass<pre_optimize_bias>(rf);
index 535cc85..48f4dca 100644 (file)
@@ -638,6 +638,10 @@ TEST_P(concat_gpu_4d_i8, b_fs_yx_fsv32) {
     ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv32));
 }
 
+TEST_P(concat_gpu_4d_i8, b_fs_yx_fsv16) {
+    ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv16));
+}
+
 INSTANTIATE_TEST_CASE_P(smoke_low_precision,
                         concat_gpu_4d_i8,
                         concat_gpu_all_params,
@@ -651,3 +655,140 @@ INSTANTIATE_TEST_CASE_P(smoke_low_precision,
                         concat_gpu_4d_u8,
                         concat_gpu_all_params,
                         concat_gpu::PrintToStringParamName);
+
+template <typename Type, typename OutputT>
+struct concat_id_conv_gpu_4d : public concat_gpu {
+public:
+
+    void test(format::type fmt) {
+        auto data_type = type_to_data_type<Type>::value;
+
+        const auto& engine = get_test_engine();
+        const size_t batch_num = testing::get<0>(GetParam());
+        const std::vector<size_t> in_features = testing::get<1>(GetParam());
+        const size_t input_y = testing::get<2>(GetParam());
+        const size_t input_x = testing::get<3>(GetParam());
+        size_t output_f = 0;
+        for (auto& f : in_features)
+            output_f += f;
+
+        topology topology;
+
+        std::vector<VVVVF<Type>> in_data;
+        std::vector<memory> in_memory;
+        std::vector<primitive_id> input_ids;
+        for (size_t i = 0; i < in_features.size(); i++) {
+            auto size = tensor(static_cast<int32_t>(batch_num),
+                               static_cast<int32_t>(in_features[i]),
+                               static_cast<int32_t>(input_x),
+                               static_cast<int32_t>(input_y));
+            auto data = generate_random_4d<Type>(batch_num, in_features[i], input_y, input_x, -128, 128);
+            auto in_lay = layout(data_type, fmt, size);
+            auto data_flat = std::vector<Type>(in_lay.get_linear_size(), 0);
+
+            for (size_t bi = 0; bi < batch_num; ++bi) {
+                for (size_t fi = 0; fi < in_features[i]; ++fi) {
+                    for (size_t yi = 0; yi < input_y; ++yi) {
+                        for (size_t xi = 0; xi < input_x; ++xi) {
+                            auto coords = tensor(batch(bi), feature(fi), spatial(xi, yi, 0, 0));
+                            auto in_offset = in_lay.get_linear_offset(coords);
+
+                            data_flat[in_offset] = data[bi][fi][yi][xi];
+                        }
+                    }
+                }
+            }
+
+            auto in_mem = memory::allocate(engine, in_lay);
+            set_values(in_mem, data_flat);
+            in_memory.push_back(in_mem);
+
+            topology.add(input_layout("input" + std::to_string(i), in_lay));
+            in_data.emplace_back(std::move(data));
+            input_ids.push_back("input" + std::to_string(i));
+        }
+
+        topology.add(concatenation("concat", input_ids, concatenation::concatenation_axis::along_f));
+        // Add identity convolution
+        auto weights_lay = cldnn::layout(data_type, cldnn::format::bfyx, tensor(batch(output_f), feature(output_f)));
+        auto weights_mem = cldnn::memory::allocate(engine, weights_lay);
+        {
+            auto weights_ptr = weights_mem.pointer<Type>();
+            for (size_t fi = 0; fi < output_f; ++fi) {
+                auto coords = tensor(batch(fi), feature(fi), spatial(0, 0, 0, 0));
+                auto offset = weights_lay.get_linear_offset(coords);
+                weights_ptr[offset] = static_cast<Type>(1.f);
+            }
+        }
+        topology.add(data("weights", weights_mem));
+        topology.add(convolution("conv", "concat", { "weights" }));
+
+        build_options options;
+        options.set_option(build_option::optimize_data(true));
+        auto conv_forcing = implementation_desc{ fmt, std::string() };
+        options.set_option(build_option::force_implementations({ {primitive_id("conv"), conv_forcing} }));
+        network network(engine, topology, options);
+
+        for (size_t i = 0; i < in_features.size(); i++) {
+            network.set_input_data(input_ids[i], in_memory[i]);
+        }
+
+        network.execute();
+
+        auto out_mem = network.get_output("conv").get_memory();
+        auto out_ptr = out_mem.pointer<OutputT>();
+
+        for (size_t bi = 0; bi < batch_num; bi++) {
+            size_t f_sum = 0;
+            for (size_t in_i = 0; in_i < in_features.size(); in_i++) {
+                for (size_t fi = 0; fi < in_features[in_i]; fi++) {
+                    for (size_t yi = 0; yi < input_y; yi++) {
+                        for (size_t xi = 0; xi < input_x; xi++) {
+                            auto output_coords = tensor(batch(bi), feature(f_sum + fi), spatial(xi, yi, 0, 0));
+                            auto output_offset = out_mem.get_layout().get_linear_offset(output_coords);
+
+                            auto ref_val = in_data[in_i][bi][fi][yi][xi];
+                            auto actual_val = static_cast<Type>(out_ptr[output_offset]);
+                            EXPECT_EQ(ref_val, actual_val)
+                                << " b=" << bi << ", f=" << f_sum + fi << "(input " << in_i << "), y=" << yi << ", x=" << xi;
+                        }
+                    }
+                }
+                f_sum += in_features[in_i];
+            }
+        }
+    }
+};
+
+using concat_id_conv_gpu_4d_f16 = concat_id_conv_gpu_4d<FLOAT16, FLOAT16>;
+using concat_id_conv_gpu_4d_i8 = concat_id_conv_gpu_4d<int8_t, float>;
+
+TEST_P(concat_id_conv_gpu_4d_f16, input_order_opt_b_fs_yx_fsv16) {
+    ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv16));
+}
+
+INSTANTIATE_TEST_CASE_P(smoke_low_precision,
+                        concat_id_conv_gpu_4d_f16,
+                        ::testing::Values(
+                            TestParamType_concat(2, { 2, 32 }, 2, 1),
+                            TestParamType_concat(2, { 31, 64 }, 2, 2),
+                            TestParamType_concat(2, { 15, 15, 16 }, 2, 1),
+                            TestParamType_concat(2, { 16, 15, 16 }, 2, 2),
+                            TestParamType_concat(2, { 15, 2, 16, 64 }, 1, 2)
+                        ),
+                        concat_gpu::PrintToStringParamName);
+
+TEST_P(concat_id_conv_gpu_4d_i8, input_order_opt_b_fs_yx_fsv16) {
+    ASSERT_NO_FATAL_FAILURE(test(format::b_fs_yx_fsv16));
+}
+
+INSTANTIATE_TEST_CASE_P(smoke_low_precision,
+                        concat_id_conv_gpu_4d_i8,
+                        ::testing::Values(
+                            TestParamType_concat(2, { 2, 32 }, 2, 1),
+                            TestParamType_concat(2, { 31, 64 }, 2, 2),
+                            TestParamType_concat(2, { 15, 15, 16 }, 2, 1),
+                            TestParamType_concat(2, { 16, 15, 16 }, 2, 2),
+                            TestParamType_concat(2, { 15, 2, 16, 64 }, 1, 2)
+                        ),
+                        concat_gpu::PrintToStringParamName);