Feature/ar24 int8 optimizations (#1208)
author Marcin Penkowski <marcin.penkowski@intel.com>
Tue, 4 Aug 2020 09:09:23 +0000 (11:09 +0200)
committer GitHub <noreply@github.com>
Tue, 4 Aug 2020 09:09:23 +0000 (12:09 +0300)
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv4_int8.cpp [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv4_int8.h [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_selector.cpp
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/convolution_gpu_b_fs_yx_fsv4_int8.cl [new file with mode: 0644]
inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_buffer_fusing.cpp
inference-engine/thirdparty/clDNN/src/layout_optimizer.cpp
inference-engine/thirdparty/clDNN/tests/test_cases/convolution_gpu_test.cpp
inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp

diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv4_int8.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv4_int8.cpp
new file mode 100644 (file)
index 0000000..3a9457b
--- /dev/null
@@ -0,0 +1,111 @@
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "convolution_kernel_b_fs_yx_fsv4_int8.h"
+#include <vector>
+#include <utility>
+#include <algorithm>
+
+namespace kernel_selector {
+constexpr size_t sub_group_size = 16;  // Sub-group width: used for gws2/lws2 in SetDefault, which GetJitConstants exports as SUB_GROUP_SIZE.
+
+ParamsKey ConvolutionKernel_b_fs_yx_fsv4_int8::GetSupportedKey() const {
+    // Collects the capabilities this kernel advertises to the kernel selector.
+    ParamsKey key;
+
+    // Data types: 8-bit signed/unsigned input, signed 8-bit weights, INT8 or F32 output.
+    key.EnableInputDataType(Datatype::INT8);
+    key.EnableInputDataType(Datatype::UINT8);
+    key.EnableInputWeightsType(WeightsType::INT8);
+    key.EnableOutputDataType(Datatype::INT8);
+    key.EnableOutputDataType(Datatype::F32);
+    key.EnableDifferentTypes();
+    key.EnableDifferentInputWeightsTypes();
+
+    // Memory layout: b_fs_yx_fsv4 on both sides, offsets and pitches allowed.
+    key.EnableInputLayout(DataLayout::b_fs_yx_fsv4);
+    key.EnableOutputLayout(DataLayout::b_fs_yx_fsv4);
+    key.EnableTensorOffset();
+    key.EnableTensorPitches();
+
+    // Remaining features: bias is optional, symmetric quantization, sub-groups.
+    key.EnableBiasPerFeature();
+    key.EnableNonBiasTerm();
+    key.EnableQuantization(QuantizationType::SYMMETRIC);
+    key.EnableSubGroup();
+
+    return key;
+}
+
+ConvolutionKernelBase::DispatchData ConvolutionKernel_b_fs_yx_fsv4_int8::SetDefault(const convolution_params& cp, int) const {
+    DispatchData runInfo = ConvolutionKernelBase::SetDefault(cp);
+
+    runInfo.efficiency = FORCE_PRIORITY_9;
+    if (cp.output.X().v > 512 && cp.filterSize.x == 5 && cp.filterSize.y == 5)
+        runInfo.efficiency = FORCE_PRIORITY_2;
+    runInfo.gws0 = CeilDiv(cp.output.X().v, sub_group_size) / 2;
+    runInfo.gws1 = cp.output.Y().v;
+    runInfo.gws2 = sub_group_size;
+
+    runInfo.lws0 = 1;
+    runInfo.lws1 = 1;
+    runInfo.lws2 = sub_group_size;
+
+    return runInfo;
+}
+
+bool ConvolutionKernel_b_fs_yx_fsv4_int8::Validate(const Params& p, const optional_params& o) const {
+    if (!ConvolutionKernelBase::Validate(p, o) || !CovolutionCheckInput(p, o)) {
+        return false;
+    }
+
+    const auto& params = static_cast<const convolution_params&>(p);
+    if (params.inputs[0].X().v % 64)
+        return false;
+
+    bool bFilterSize = (params.filterSize.x == 5 && params.filterSize.y == 5) ||
+                       (params.filterSize.x == 3 && params.filterSize.y == 3 && (params.inputs[0].Feature().v % 4) == 0) ||
+                       (params.filterSize.x == 1 && params.filterSize.y == 1);
+
+    bool bStride = (params.stride.x == 1 && params.stride.y == 1);
+
+    if (!bFilterSize || !bStride || (params.output.Feature().v % 4) != 0 || (params.output.Batch().v != 1)) {
+        return false;
+    }
+
+    return true;
+}
+
+JitConstants ConvolutionKernel_b_fs_yx_fsv4_int8::GetJitConstants(const convolution_params& params, const DispatchData& runInfo) const {
+    auto jit = Parent::GetJitConstants(params, runInfo);
+
+    jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", runInfo.lws2));
+
+    jit.Merge(MakeTypeJitConstants(GetAccumulatorType(params), "ACCUMULATOR"));
+    jit.Merge(MakeTypeJitConstants(GetActivationType(params), "ACTIVATION"));
+
+    if (!params.fused_ops.empty()) {
+        auto input_dt = GetActivationType(params);
+        FusedOpsConfiguration conf0 = { "_0", {"batch", "FILTER_OFM_MAX * iter + ofm + 0", "idy", "idx"}, "res0", input_dt, 1 };
+        FusedOpsConfiguration conf1 = { "_1", {"batch", "FILTER_OFM_MAX * iter + ofm + 1", "idy", "idx"}, "res1", input_dt, 1 };
+        FusedOpsConfiguration conf2 = { "_2", {"batch", "FILTER_OFM_MAX * iter + ofm + 2", "idy", "idx"}, "res2", input_dt, 1 };
+        FusedOpsConfiguration conf3 = { "_3", {"batch", "FILTER_OFM_MAX * iter + ofm + 3", "idy", "idx"}, "res3", input_dt, 1 };
+        FusedOpsConfiguration conf4 = { "_4", {"batch", "FILTER_OFM_MAX * iter + ofm + 0", "idy", "idx"}, "res4", input_dt, 1 };
+        FusedOpsConfiguration conf5 = { "_5", {"batch", "FILTER_OFM_MAX * iter + ofm + 1", "idy", "idx"}, "res5", input_dt, 1 };
+        FusedOpsConfiguration conf6 = { "_6", {"batch", "FILTER_OFM_MAX * iter + ofm + 2", "idy", "idx"}, "res6", input_dt, 1 };
+        FusedOpsConfiguration conf7 = { "_7", {"batch", "FILTER_OFM_MAX * iter + ofm + 3", "idy", "idx"}, "res7", input_dt, 1 };
+        jit.Merge(MakeFusedOpsJitConstants(params, { conf0, conf1, conf2, conf3, conf4, conf5, conf6, conf7 }));
+    }
+
+    return jit;
+}
+
+KernelsData ConvolutionKernel_b_fs_yx_fsv4_int8::GetKernelsData(const Params& params, const optional_params& options) const {
+    // Nothing kernel-specific here: delegate to the shared helper without an explicit tuning index.
+    KernelsData kernels_data = GetTunedKernelsDataByIndex(params, options);
+    return kernels_data;
+}
+
+}  // namespace kernel_selector
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv4_int8.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv4_int8.h
new file mode 100644 (file)
index 0000000..9cbc775
--- /dev/null
@@ -0,0 +1,48 @@
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#pragma once
+
+#include "convolution_kernel_base.h"
+#include <string>
+
+namespace kernel_selector {
+
+class ConvolutionKernel_b_fs_yx_fsv4_int8 : public ConvolutionKernelBase {
+public:
+    using Parent = ConvolutionKernelBase;
+    ConvolutionKernel_b_fs_yx_fsv4_int8() : Parent("convolution_gpu_b_fs_yx_fsv4_int8") {}
+    virtual ~ConvolutionKernel_b_fs_yx_fsv4_int8() {}
+
+    KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
+    ParamsKey GetSupportedKey() const override;
+
+protected:
+    WeightsLayout GetPreferredWeightsLayout(const convolution_params&) const override {
+        return WeightsLayout::os_is_yx_osv16_isv4;
+    }
+
+    JitConstants GetJitConstants(const convolution_params& params, const DispatchData& kd) const override;
+    bool Validate(const Params& p, const optional_params& o) const override;
+    bool NeedPaddedInput() const override { return true; }
+    DispatchData SetDefault(const convolution_params& arg, int autoTuneIndex = -1) const override;
+    std::vector<FusedOpType> GetSupportedFusedOps() const override {
+        return { FusedOpType::ELTWISE,
+                 FusedOpType::QUANTIZE,
+                 FusedOpType::SCALE,
+                 FusedOpType::ACTIVATION };
+    }
+};
+}  // namespace kernel_selector
index 558b226..e87157e 100644 (file)
@@ -64,6 +64,7 @@
 #include "convolution_kernel_b_fs_yx_fsv_16_32_imad_dw.hpp"
 #include "convolution_kernel_imad_bs_fs_yx_bsv16_fsv16_1x1.h"
 #include "convolution_kernel_imad_bs_fs_yx_bsv16_fsv16_3x3.h"
+#include "convolution_kernel_b_fs_yx_fsv4_int8.h"
 
 namespace kernel_selector {
 convolution_kernel_selector::convolution_kernel_selector() {
@@ -129,6 +130,7 @@ convolution_kernel_selector::convolution_kernel_selector() {
     Attach<ConvolutionKernel_imad_b_fs_yx_fsv4_1x1>();
     Attach<ConvolutionKernel_mmad_bfyx_to_b_fs_yx_fsv4>();
     Attach<ConvolutionKernel_imad_b_fs_yx_fsv4_dw>();
+    Attach<ConvolutionKernel_b_fs_yx_fsv4_int8>();
 
     // b_fs_yx_fsv32 kernels
     Attach<ConvolutionKernel_mmad_b_fs_yx_fsv32>();
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/convolution_gpu_b_fs_yx_fsv4_int8.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/convolution_gpu_b_fs_yx_fsv4_int8.cl
new file mode 100644 (file)
index 0000000..ac3988c
--- /dev/null
@@ -0,0 +1,163 @@
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "include/common.cl"
+#include "include/data_types.cl"
+#include "include/fetch.cl"
+#include "include/imad.cl"
+
+#define INPUT0_PACKED_TYPE uint
+
+__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE)))  // int8 convolution on 4-feature packed input; each work-item computes two adjacent x positions
+__attribute__((reqd_work_group_size(1, 1, SUB_GROUP_SIZE)))  // a work-group is exactly one sub-group, laid out along dimension 2
+KERNEL(convolution_gpu_b_fs_yx_fsv4_int8)(
+    const __global INPUT0_PACKED_TYPE* input,  // input addressed as packed 32-bit words (INPUT0_PACKED_TYPE is uint)
+    __global OUTPUT_TYPE* output,
+    const __global FILTER_TYPE* weights,
+#if BIAS_TERM
+    const __global BIAS_TYPE* bias,
+#endif
+#if HAS_FUSED_OPS_DECLS
+    FUSED_OPS_DECLS,
+#endif
+    uint split_idx)  // not referenced in the kernel body
+{
+#define AS_TYPE_N_(type, n, x) as_##type##n(x)
+#define AS_TYPE_N(type, n, x) AS_TYPE_N_(type, n, x)
+#define AS_INPUT0_TYPE_4(x) AS_TYPE_N(INPUT0_TYPE, 4, x)
+#define AS_FILTER_TYPE_4(x) AS_TYPE_N(FILTER_TYPE, 4, x)
+
+    const uint items_per_EU = 2;  // x positions produced per work-item (accumulated in out1 / out2 below)
+    const uint idx = items_per_EU * ((uint)get_global_id(0) * SUB_GROUP_SIZE + (uint)get_global_id(2));  // first of the two x positions of this work-item
+    const uint idy = (uint)get_global_id(1);  // output row
+    uint filter_idx = 0;
+    uint output_idx = 0;
+    uint input_idx = 0;  // never read in this kernel
+    int batch = 0;  // not used directly here; the host-side fused-op configs name "batch" as a coordinate
+    const uint packed_values = 4;  // features stored together per packed element
+#if FILTER_OFM_NUM > 8
+#define FILTER_OFM_MAX 8
+#else
+#define FILTER_OFM_MAX FILTER_OFM_NUM
+#endif
+    __attribute__((opencl_unroll_hint(FILTER_OFM_NUM / FILTER_OFM_MAX)))
+        for (int iter = 0; iter < FILTER_OFM_NUM / FILTER_OFM_MAX + (FILTER_OFM_NUM % FILTER_OFM_MAX != 0); iter++) {  // one pass per group of up to FILTER_OFM_MAX output features (ceil division)
+            int out1[FILTER_OFM_MAX] = { 0 };  // accumulators for x = idx
+            int out2[FILTER_OFM_MAX] = { 0 };  // accumulators for x = idx + 1
+            filter_idx = FILTER_OFM_MAX * iter * packed_values;  // weight offset of the first output feature of this group
+
+            __attribute__((opencl_unroll_hint(1)))
+                for (int ifm = 0; ifm < (FILTER_IFM_NUM + 3) / 4; ifm++) {  // input features are consumed four at a time
+                    __attribute__((opencl_unroll_hint(FILTER_SIZE_Y)))
+                        for (int yy = 0; yy < FILTER_SIZE_Y; yy++) {
+                            uint inp_idx = ifm * (INPUT0_FEATURE_PITCH)+(idy + yy) * (INPUT0_Y_PITCH)+idx;  // packed-word index for feature slice ifm, row idy + yy, column idx
+
+                            __attribute__((opencl_unroll_hint(FILTER_SIZE_X)))
+                                for (int xx = 0; xx < FILTER_SIZE_X; xx++) {
+                                    char8 tmp = as_char8(vload2(0, (__global uint*)(input + inp_idx + xx)));  // two adjacent packed pixels: bytes 0-3 feed out1, bytes 4-7 feed out2
+
+                                    __attribute__((opencl_unroll_hint(FILTER_OFM_MAX)))
+                                        for (int ofm = 0; ofm < FILTER_OFM_MAX; ofm++) {
+                                            __global uint* www = weights + (filter_idx + ofm * packed_values);  // NOTE(review): drops const and converts a FILTER_TYPE pointer to uint* without a cast - confirm this is intended
+                                            char4 w = AS_FILTER_TYPE_4(www[0]);  // four weights for output feature ofm
+                                            out1[ofm] = IMAD(out1[ofm], AS_INPUT0_TYPE_4((char4)(tmp[0], tmp[1], tmp[2], tmp[3])), AS_FILTER_TYPE_4(w));
+                                            out2[ofm] = IMAD(out2[ofm], AS_INPUT0_TYPE_4((char4)(tmp[4], tmp[5], tmp[6], tmp[7])), AS_FILTER_TYPE_4(w));
+                                        }
+                                    filter_idx += (packed_values * SUB_GROUP_SIZE);  // step to the next filter tap
+                                }
+                        }
+                }
+
+            __attribute__((opencl_unroll_hint(FILTER_OFM_MAX / 4)))
+                for (int ofm = 0; ofm < FILTER_OFM_MAX && (ofm + iter * FILTER_OFM_MAX) < FILTER_OFM_NUM; ofm += packed_values) {  // emit four output features per step, skipping features past FILTER_OFM_NUM
+
+#if BIAS_TERM
+                    ACTIVATION_TYPE res0 = TO_ACTIVATION_TYPE(out1[ofm + 0]) + TO_ACTIVATION_TYPE(bias[(iter * FILTER_OFM_MAX) + ofm + 0]);
+                    ACTIVATION_TYPE res1 = TO_ACTIVATION_TYPE(out1[ofm + 1]) + TO_ACTIVATION_TYPE(bias[(iter * FILTER_OFM_MAX) + ofm + 1]);
+                    ACTIVATION_TYPE res2 = TO_ACTIVATION_TYPE(out1[ofm + 2]) + TO_ACTIVATION_TYPE(bias[(iter * FILTER_OFM_MAX) + ofm + 2]);
+                    ACTIVATION_TYPE res3 = TO_ACTIVATION_TYPE(out1[ofm + 3]) + TO_ACTIVATION_TYPE(bias[(iter * FILTER_OFM_MAX) + ofm + 3]);
+                    ACTIVATION_TYPE res4 = TO_ACTIVATION_TYPE(out2[ofm + 0]) + TO_ACTIVATION_TYPE(bias[(iter * FILTER_OFM_MAX) + ofm + 0]);
+                    ACTIVATION_TYPE res5 = TO_ACTIVATION_TYPE(out2[ofm + 1]) + TO_ACTIVATION_TYPE(bias[(iter * FILTER_OFM_MAX) + ofm + 1]);
+                    ACTIVATION_TYPE res6 = TO_ACTIVATION_TYPE(out2[ofm + 2]) + TO_ACTIVATION_TYPE(bias[(iter * FILTER_OFM_MAX) + ofm + 2]);
+                    ACTIVATION_TYPE res7 = TO_ACTIVATION_TYPE(out2[ofm + 3]) + TO_ACTIVATION_TYPE(bias[(iter * FILTER_OFM_MAX) + ofm + 3]);
+#else
+                    ACTIVATION_TYPE res0 = TO_ACTIVATION_TYPE(out1[ofm + 0]);
+                    ACTIVATION_TYPE res1 = TO_ACTIVATION_TYPE(out1[ofm + 1]);
+                    ACTIVATION_TYPE res2 = TO_ACTIVATION_TYPE(out1[ofm + 2]);
+                    ACTIVATION_TYPE res3 = TO_ACTIVATION_TYPE(out1[ofm + 3]);
+                    ACTIVATION_TYPE res4 = TO_ACTIVATION_TYPE(out2[ofm + 0]);
+                    ACTIVATION_TYPE res5 = TO_ACTIVATION_TYPE(out2[ofm + 1]);
+                    ACTIVATION_TYPE res6 = TO_ACTIVATION_TYPE(out2[ofm + 2]);
+                    ACTIVATION_TYPE res7 = TO_ACTIVATION_TYPE(out2[ofm + 3]);
+#endif
+
+                    if (OUTPUT_PAD_BEFORE_FEATURE_NUM > 0) {  // NOTE(review): this branch adds OUTPUT_OFFSET unscaled plus 3x the spatial padding offset, the other scales OUTPUT_OFFSET by packed_values - confirm both are intended
+                        uint output_feature_specific_offset = OUTPUT_Y_PITCH * OUTPUT_PAD_BEFORE_SIZE_Y +
+                            (OUTPUT_PAD_BEFORE_SIZE_X * OUTPUT_X_PITCH);
+                        output_idx = (iter * FILTER_OFM_MAX * OUTPUT_FEATURE_PITCH) + ofm * OUTPUT_FEATURE_PITCH +
+                            idy * OUTPUT_Y_PITCH * packed_values + idx * packed_values + OUTPUT_OFFSET + output_feature_specific_offset * 3;
+                    }
+                    else {
+                        output_idx = (iter * FILTER_OFM_MAX * OUTPUT_FEATURE_PITCH) + ofm * OUTPUT_FEATURE_PITCH +
+                            idy * OUTPUT_Y_PITCH * packed_values + idx * packed_values + packed_values * OUTPUT_OFFSET;
+                    }
+
+                    MAKE_VECTOR_TYPE(OUTPUT_TYPE, 4) pack1;  // four output features for x = idx
+                    MAKE_VECTOR_TYPE(OUTPUT_TYPE, 4) pack2;  // four output features for x = idx + 1
+#if HAS_FUSED_OPS
+                    { FUSED_OPS_0; pack1[0] = FUSED_OPS_RESULT_0; };  // fused ops _0.._3 post-process res0..res3 (x = idx)
+                    { FUSED_OPS_1; pack1[1] = FUSED_OPS_RESULT_1; };
+                    { FUSED_OPS_2; pack1[2] = FUSED_OPS_RESULT_2; };
+                    { FUSED_OPS_3; pack1[3] = FUSED_OPS_RESULT_3; };
+
+                    { FUSED_OPS_4; pack2[0] = FUSED_OPS_RESULT_4; };  // fused ops _4.._7 post-process res4..res7 (x = idx + 1)
+                    { FUSED_OPS_5; pack2[1] = FUSED_OPS_RESULT_5; };
+                    { FUSED_OPS_6; pack2[2] = FUSED_OPS_RESULT_6; };
+                    { FUSED_OPS_7; pack2[3] = FUSED_OPS_RESULT_7; };
+#else
+                    pack1[0] = TO_OUTPUT_TYPE(res0);
+                    pack1[1] = TO_OUTPUT_TYPE(res1);
+                    pack1[2] = TO_OUTPUT_TYPE(res2);
+                    pack1[3] = TO_OUTPUT_TYPE(res3);
+                    pack2[0] = TO_OUTPUT_TYPE(res4);
+                    pack2[1] = TO_OUTPUT_TYPE(res5);
+                    pack2[2] = TO_OUTPUT_TYPE(res6);
+                    pack2[3] = TO_OUTPUT_TYPE(res7);
+#endif
+
+#if OUTPUT_TYPE_SIZE == 1
+                    vstore2((float2)(  // 1-byte output: both 4-byte packs are bit-cast to float and written with a single 8-byte store
+                        as_float(pack1),
+                        as_float(pack2)
+                        ), 0, (__global float*)(output + output_idx));
+#else
+#if OUTPUT_OFFSET % 4
+                    output[output_idx + 0] = TO_OUTPUT_TYPE(pack1[0]);  // OUTPUT_OFFSET not a multiple of 4: element-wise stores
+                    output[output_idx + 1] = TO_OUTPUT_TYPE(pack1[1]);
+                    output[output_idx + 2] = TO_OUTPUT_TYPE(pack1[2]);
+                    output[output_idx + 3] = TO_OUTPUT_TYPE(pack1[3]);
+                    output[output_idx + 4] = TO_OUTPUT_TYPE(pack2[0]);
+                    output[output_idx + 5] = TO_OUTPUT_TYPE(pack2[1]);
+                    output[output_idx + 6] = TO_OUTPUT_TYPE(pack2[2]);
+                    output[output_idx + 7] = TO_OUTPUT_TYPE(pack2[3]);
+#else
+                    vstore4((float4)(TO_OUTPUT_TYPE(pack1[0]), TO_OUTPUT_TYPE(pack1[1]), TO_OUTPUT_TYPE(pack1[2]),  // NOTE(review): stores through float* whatever OUTPUT_TYPE is - assumes the non-1-byte output type is float
+                        TO_OUTPUT_TYPE(pack1[3])), 0, (__global float*)(output + output_idx));
+                    vstore4((float4)(TO_OUTPUT_TYPE(pack2[0]), TO_OUTPUT_TYPE(pack2[1]), TO_OUTPUT_TYPE(pack2[2]),
+                        TO_OUTPUT_TYPE(pack2[3])), 0, (__global float*)(output + output_idx + 4));
+#endif
+#endif
+                }
+        }
+}
index cabfa1d..d2ce48e 100644 (file)
@@ -97,7 +97,10 @@ void prepare_buffer_fusing::run(program_impl& p) {
                 if (l.format == format::byxf_af32 && (l.size.feature[0] % 32 != 0 || node.get_primitive()->axis != concatenation::along_f))
                     return;
 
-                if (l.format == format::b_fs_yx_fsv4 || l.format == format::bs_fs_yx_bsv16_fsv16)
+                if (l.format == format::bs_fs_yx_bsv16_fsv16)
+                    return;
+
+                if (l.format == format::b_fs_yx_fsv4 && (l.size.feature[0] != 8 || node.get_primitive()->axis != concatenation::along_f))
                     return;
             }
 
index d8d1a38..c7494d1 100644 (file)
@@ -625,6 +625,10 @@ format layout_optimizer::imad_case(convolution_node const& node) const {
         return format::bfzyx;
     }
 
+    if ((out_size.feature[0] == 8 || out_size.feature[0] == 12) && out_size.spatial[1] > 512) {
+        return format::b_fs_yx_fsv4;
+    }
+
     if (stride.spatial[0] != stride.spatial[1] || out_size.spatial[0] != out_size.spatial[1] ||
         (weights_dt != data_types::u8 && weights_dt != data_types::i8)) {
         return format::byxf_af32;
index 0cd0a4e..33b5e55 100644 (file)
@@ -5353,7 +5353,99 @@ TEST_P(convolution_gpu_fs_byx_fsv32_crop, fs_byx_fsv32_crop)
                 }
 }
 
+TEST(convolution_f32_fw_gpu, convolution_int8_b_fs_yx_fsv4_to_bfyx) {
 
+    const int batch_num = 1;
+    const int output_f = 12;
+    const int input_f = 16;
+    const int filter_xy = 5;
+    const int output_padding = 2;
+    const int input_size_x = 1280;
+    const int input_size_y = 720;
+
+    const auto& engine = get_test_engine();
+
+    auto input_size = tensor(batch_num, input_f, input_size_x, input_size_y);
+    auto input_data = generate_random_4d<float>(batch_num, input_f, input_size_y, input_size_x, -10, 10);
+
+    auto input_data_bfyx = flatten_4d(format::bfyx, input_data);
+    auto input = memory::allocate(engine, { data_types::f32, format::bfyx, input_size });
+    set_values(input, input_data_bfyx);
+
+    auto weights_size = tensor(output_f, input_f, filter_xy, filter_xy);
+    auto weights_data = generate_random_4d<int8_t>(output_f, input_f, filter_xy, filter_xy, -10, 10);
+    auto weights_data_bfyx = flatten_4d(format::bfyx, weights_data);
+    auto weights = memory::allocate(engine, { data_types::i8, format::bfyx, weights_size });
+    set_values(weights, weights_data_bfyx);
+
+    auto biases_size = tensor(1, output_f, 1, 1);
+    auto biases_data = generate_random_4d<int8_t>(1, output_f, 1, 1, -10, 10);
+    auto biases_data_bfyx = flatten_4d(format::bfyx, biases_data);
+    auto biases = memory::allocate(engine, { data_types::i8, format::bfyx, biases_size });
+    set_values(biases, biases_data_bfyx);
+
+    topology topology_ref(
+        input_layout("input", input.get_layout()),
+        reorder("to_int", "input", { data_types::i8,format::bfyx,{ batch_num, input_f, input_size_x, input_size_y } }),
+        data("weights", weights),
+        data("biases", biases),
+        convolution("conv", "to_int", { "weights" }, { "biases" }, { 1, 1, 1, 1 }, { 0, 0, -2, -2 }, { 1,1,1,1 },
+            padding{ { 0, 0, output_padding, output_padding }, 0 }),
+        reorder("output", "conv", { data_types::f32,format::bfyx,{ batch_num, input_f, input_size_x, input_size_y } }));
+
+    build_options build_opt;
+
+    network network_ref(engine, topology_ref, build_opt);
+    network_ref.set_input_data("input", input);
+
+    auto outputs = network_ref.execute();
+    EXPECT_EQ(outputs.size(), size_t(1));
+    EXPECT_EQ(outputs.begin()->first, "output");
+
+    auto output_memory = outputs.at("output").get_memory();
+    auto output_layout = output_memory.get_layout();
+    auto output_ptr = output_memory.pointer<float>();
+
+    topology topology_act(
+        input_layout("input", input.get_layout()),
+        reorder("to_int", "input", { data_types::i8,format::b_fs_yx_fsv4,{ batch_num, input_f, input_size_x, input_size_y } }),
+        data("weights", weights),
+        data("biases", biases),
+        convolution("conv", "to_int", { "weights" }, { "biases" }, { 1, 1, 1, 1 }, { 0, 0, -2, -2 }, { 1,1,1,1 },
+            padding{ { 0, 0, output_padding, output_padding }, 0 }),
+        reorder("output", "conv", { data_types::f32,format::bfyx,{ batch_num, input_f, input_size_x, input_size_y } }));
+
+    build_options build_opt_act;
+
+    build_opt_act.set_option(build_option::optimize_data(true));
+
+    network network_act(engine, topology_act, build_opt_act);
+    network_act.set_input_data("input", input);
+
+    auto outputs_act = network_act.execute();
+    EXPECT_EQ(outputs_act.size(), size_t(1));
+    EXPECT_EQ(outputs_act.begin()->first, "output");
+
+    auto output_memory_act = outputs_act.at("output").get_memory();
+    auto output_act_ptr = output_memory_act.pointer<float>();
+
+    int y_size = output_layout.size.spatial[1];
+    int x_size = output_layout.size.spatial[0];
+    int f_size = output_layout.size.feature[0];
+    int b_size = output_layout.size.batch[0];
+    EXPECT_EQ(output_layout.format, format::bfyx);
+    EXPECT_EQ(y_size, 720);
+    EXPECT_EQ(x_size, 1280);
+    EXPECT_EQ(f_size, output_f);
+    EXPECT_EQ(b_size, 1);
+    for (int o = 0; o < f_size; ++o) {
+        for (int y = 0; y < y_size; ++y) {
+            for (int x = 0; x < x_size; ++x) {
+                EXPECT_EQ(output_act_ptr[o * x_size * y_size + y * x_size + x], output_ptr[o * x_size * y_size + y * x_size + x]);
+            }
+        }
+    }
+}
 
 TEST(convolution_gpu, bfyx_iyxo_5x5_fp16)
 {
index c30f3ae..4278d78 100644 (file)
@@ -464,6 +464,7 @@ public:
 #define CASE_CONV_S8S8_8 {1, 3, 4, 5}, {1, 32, 4, 5}, {1, 1, 3, 3}, tensor{1}, tensor{0, 0, -1, -1, 0, 0}, tensor{1}, 1, data_types::i8, format::bfyx, data_types::i8, format::bfyx, data_types::f32, format::bfyx
 #define CASE_CONV_S8S8_9 {16, 32, 5, 5}, {16, 32, 3, 3}, {1, 1, 1, 1}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bs_fs_yx_bsv16_fsv16, data_types::i8, format::os_is_yx_osv16_isv16, data_types::f32, format::bfyx
 #define CASE_CONV_S8S8_10 {16, 32, 5, 5}, {16, 32, 3, 3}, {1, 1, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::i8, format::bs_fs_yx_bsv16_fsv16, data_types::i8, format::os_is_yx_osv16_isv16, data_types::f32, format::bfyx
+#define CASE_CONV_S8S8_11 {1, 4, 1280, 720}, {1, 4, 1280, 720}, {1, 1, 5, 5}, tensor{1}, tensor{0, 0, -2, -2}, tensor{1}, 1, data_types::i8, format::b_fs_yx_fsv4, data_types::i8, format::os_is_yx_osv16_isv4, data_types::f32, format::bfyx
 
 #define CASE_CONV3D_U8S8_1 {1, 15, 5, 4, 5}, {1, 30, 3, 2, 3}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfzyx, data_types::i8, format::bfzyx, data_types::f32, format::bfzyx
 #define CASE_CONV3D_U8S8_2 {1, 15, 5, 5, 5}, {1, 30, 3, 3, 3}, {1, 1, 3, 3, 3}, tensor{1}, tensor{0}, tensor{1}, 1, data_types::u8, format::bfzyx, data_types::i8, format::bfzyx, data_types::f32, format::bfzyx
@@ -1724,6 +1725,34 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, conv_int8_scale_quantize_i8,
                                 bc_test_params{CASE_CONV3D_S8S8_4, 2, 4},
                         }), );
 
+class conv_int8_scale_quantize_i8_conv_b_fs_yx_fsv4_int8 : public ConvFusingTest {};
+TEST_P(conv_int8_scale_quantize_i8_conv_b_fs_yx_fsv4_int8, basic) {
+    auto p = GetParam();
+    create_topologies(input_layout("input", get_input_layout(p)),
+                 data("weights", get_mem(get_weights_layout(p))),
+                 data("bias", get_mem(get_bias_layout(p))),
+                 data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
+                 data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
+                 data("out_lo", get_mem(get_single_element_layout(p), -127)),
+                 data("out_hi", get_mem(get_single_element_layout(p), 127)),
+                 data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / p.kernel.count() / 255)),
+                 convolution("conv_prim", "input", { "weights" }, { "bias" }, p.groups, p.stride, p.pad, p.dilation),
+                 scale("scale", "conv_prim", "scale_data"),
+                 quantize("quantize", "scale", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
+                 reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
+    );
+    implementation_desc conv_impl = { format::b_fs_yx_fsv4, "convolution_gpu_b_fs_yx_fsv4_int8" };
+    bo_fused.set_option(build_option::force_implementations({ {"conv_prim", conv_impl} }));
+
+    tolerance = 1.0f;
+    execute(p);
+}
+
+INSTANTIATE_TEST_CASE_P(fusings_gpu, conv_int8_scale_quantize_i8_conv_b_fs_yx_fsv4_int8,
+                        ::testing::ValuesIn(std::vector<bc_test_params>{
+                                bc_test_params{ CASE_CONV_S8S8_11, 2, 4 },
+                        }), );
+
 class conv_int8_relu_quantize : public ConvFusingTest {};
 TEST_P(conv_int8_relu_quantize, i8) {
     auto p = GetParam();