[IE CLDNN] Gather 5d/6d support (#1553)
authorLukasz Debski <Lukasz.Debski@intel.com>
Mon, 3 Aug 2020 07:05:53 +0000 (09:05 +0200)
committerGitHub <noreply@github.com>
Mon, 3 Aug 2020 07:05:53 +0000 (10:05 +0300)
inference-engine/src/cldnn_engine/cldnn_program.cpp
inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather.cpp
inference-engine/tests_deprecated/functional/cldnn/shared_tests_instance/single_layer_tests/gather_ftests.cpp
inference-engine/thirdparty/clDNN/api/gather.hpp
inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_kernel_ref.cpp
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_ref.cl
inference-engine/thirdparty/clDNN/src/gather.cpp
inference-engine/thirdparty/clDNN/src/gpu/gather_gpu.cpp
inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp
inference-engine/thirdparty/clDNN/tests/test_cases/gather_gpu_test.cpp

index 10c3a0e..5897b94 100644 (file)
@@ -3760,16 +3760,48 @@ void Program::CreateGatherPrimitive(cldnn::topology& topology, InferenceEngine::
     int axis = gatherLayer->GetParamAsInt("axis", 0);
 
     // Be careful, TensorFlow consist negative axis interpretation bug. Here: -3 = b, -2 = f, -1 = y, but must be -3 = f, -2 = y, -1 = x
-    auto cldnnAxisFromIE = [](int axis) {
-        switch (axis) {
-            case 0: return cldnn::gather::gather_axis::along_b;
-            case 1: return cldnn::gather::gather_axis::along_f;
-            case 2: return cldnn::gather::gather_axis::along_y;
-            case 3: return cldnn::gather::gather_axis::along_x;
-            case -1: return cldnn::gather::gather_axis::along_y;
-            case -2: return cldnn::gather::gather_axis::along_f;
-            case -3: return cldnn::gather::gather_axis::along_b;
-            default: THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
+    auto cldnnAxisFromIE = [](int axis, cldnn::format inputFormat) {
+        if (axis == 0) {
+            return cldnn::gather::gather_axis::along_b;
+        } else if (axis == 1) {
+            return cldnn::gather::gather_axis::along_f;
+        }
+
+        if (inputFormat == cldnn::format::bfyx) {
+            switch (axis) {
+                case 2: return cldnn::gather::gather_axis::along_y;
+                case 3: return cldnn::gather::gather_axis::along_x;
+                case -1: return cldnn::gather::gather_axis::along_y;
+                case -2: return cldnn::gather::gather_axis::along_f;
+                case -3: return cldnn::gather::gather_axis::along_b;
+                default: THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
+            }
+        } else if (inputFormat == cldnn::format::bfzyx) {
+            switch (axis) {
+                case 2: return cldnn::gather::gather_axis::along_z;
+                case 3: return cldnn::gather::gather_axis::along_y;
+                case 4: return cldnn::gather::gather_axis::along_x;
+                case -1: return cldnn::gather::gather_axis::along_y;
+                case -2: return cldnn::gather::gather_axis::along_z;
+                case -3: return cldnn::gather::gather_axis::along_f;
+                case -4: return cldnn::gather::gather_axis::along_b;
+                default: THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
+            }
+        } else if (inputFormat == cldnn::format::bfwzyx) {
+            switch (axis) {
+                case 2: return cldnn::gather::gather_axis::along_w;
+                case 3: return cldnn::gather::gather_axis::along_z;
+                case 4: return cldnn::gather::gather_axis::along_y;
+                case 5: return cldnn::gather::gather_axis::along_x;
+                case -1: return cldnn::gather::gather_axis::along_y;
+                case -2: return cldnn::gather::gather_axis::along_z;
+                case -3: return cldnn::gather::gather_axis::along_w;
+                case -4: return cldnn::gather::gather_axis::along_f;
+                case -5: return cldnn::gather::gather_axis::along_b;
+                default: THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
+            }
+        } else {
+            THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
         }
     };
 
@@ -3798,127 +3830,18 @@ void Program::CreateGatherPrimitive(cldnn::topology& topology, InferenceEngine::
         }
     }
 
-    auto indicesDims = layer->insData[1].lock()->getTensorDesc().getDims();
-    auto indicesLayout = layer->insData[1].lock()->getTensorDesc().getLayout();
-    auto indicesFormat = FormatFromLayout(indicesLayout);
-
-    auto inputDims = layer->insData[0].lock()->getTensorDesc().getDims();
     auto inputLayout = layer->insData[0].lock()->getTensorDesc().getLayout();
-    auto inputFormat = FormatFromLayout(inputLayout);
-
-    auto outDimsOriginal = layer->outData[0]->getTensorDesc().getDims();
-    auto outputLayoutOriginal = layer->outData[0]->getTensorDesc().getLayout();
-    auto outputFormatOriginal = FormatFromLayout(outputLayoutOriginal);
-
-    auto outDims = outDimsOriginal;
-    auto targetDatatype = DataTypeFromPrecision(layer->precision);
-
-    auto nonNegativeAxis = (axis >= 0) ? axis : axis + 3;
-
-    // following vector is needed just to check if we can apply bfyx WA
-    SizeVector originalRequiredDims;
-    for (size_t d = 0; d < inputDims.size(); d++) {
-        if ((d == nonNegativeAxis) || (inputDims[d] > 1)) {
-            originalRequiredDims.push_back(d);
-        }
-    }
-
-    if (originalRequiredDims.size() < 4) {
-        // make sure that we will have at least 4 required dimensions
-        auto originalAxesIt = originalRequiredDims.begin();
-        for (size_t i = 0; i < 4; i++) {
-            int dimFoundAtIndex = -1;
-            for (size_t j = 0; j < originalRequiredDims.size(); j++) {
-                if (originalRequiredDims[j] == i) {
-                    dimFoundAtIndex = j;
-                }
-            }
-            if (dimFoundAtIndex == -1) {
-                originalAxesIt = originalRequiredDims.insert(originalAxesIt, i);
-            }
-            originalAxesIt++;
-        }
-    }
-
-    // clDNN primitive is missing proper support of 5d/6d inputs
-    // but we can still fall back to bfyx format in some cases
-    bool bfyx_wa = ((inputFormat == cldnn::format::bfzyx || inputFormat == cldnn::format::bfwzyx) &&
-                    (originalRequiredDims.size() == 4) &&
-                    (indicesFormat == cldnn::format::bfyx));
-
-    if (bfyx_wa) {
-        if (indicesDims.size() > 1) {
-            // reshape the indices dims to 1D (along batch axis)
-            size_t indDimAcc = std::accumulate(indicesDims.begin(), indicesDims.end(), 1, std::multiplies<size_t>());
-            SizeVector targetIndDims{ indDimAcc, 1, 1, 1 };
-
-            auto reshapeName = reorderedInputs[1] + "_" + layer->name + "_reshape";
-            auto targetTensor = CldnnTensorFromIEDims(targetIndDims);
-            auto reshapePrim = cldnn::reshape(reshapeName, reorderedInputs[1], CldnnTensorFromIEDims(targetIndDims));
-            topology.add(reshapePrim);
-            AddInnerPrimitiveToProfiler(reshapeName, gatherLayerName, layer);
-            reorderedInputs[1] = reshapeName;
-
-            // adjust expected output dims
-            outDims[nonNegativeAxis] = indDimAcc;
-            outDims.erase(outDims.begin() + nonNegativeAxis + 1, outDims.begin() + nonNegativeAxis + indicesDims.size());
-        }
-
-        // reorder input to bfyx
-        auto reorderName = reorderedInputs[0] + "_" + layer->name + "_format_reorder";
-        auto reorderPrim = cldnn::reorder(reorderName, reorderedInputs[0], cldnn::format::bfyx, targetDatatype);
-        topology.add(reorderPrim);
-        AddInnerPrimitiveToProfiler(reorderName, gatherLayerName, layer);
-        reorderedInputs[0] = reorderName;
-
-        // calculate new input/output dims in bfyx format
-        SizeVector targetInDims(4);
-        SizeVector targetOutDims(4);
-        for (size_t d = 0; d < 4; d++) {
-            targetInDims[d] = inputDims[originalRequiredDims[d]];
-            targetOutDims[d] = outDims[originalRequiredDims[d]];
-        }
-        outDims = targetOutDims;
-
-        // calculate new axis in bfyx format
-        for (size_t d = 0; d < originalRequiredDims.size(); d++) {
-            if (originalRequiredDims[d] == nonNegativeAxis) {
-                axis = d;
-            }
-        }
-
-        // reshape the input dims to the ones expected in bfyx format
-        auto reshapeName = reorderedInputs[0] + "_" + layer->name + "_reshape";
-        auto targetTensor = CldnnTensorFromIEDims(targetInDims);
-        auto reshapePrim = cldnn::reshape(reshapeName, reorderedInputs[0], CldnnTensorFromIEDims(targetInDims));
-        topology.add(reshapePrim);
-        AddInnerPrimitiveToProfiler(reshapeName, gatherLayerName, layer);
-        reorderedInputs[0] = reshapeName;
-    }
+    auto outDims = layer->outData[0]->getTensorDesc().getDims();
 
     auto gatherPrim = cldnn::gather(
         gatherLayerName,
         reorderedInputs[0],
         reorderedInputs[1],
-        cldnnAxisFromIE(axis),
+        cldnnAxisFromIE(axis, FormatFromLayout(inputLayout)),
         CldnnTensorFromIEDims(outDims));
 
     topology.add(gatherPrim);
     AddPrimitiveToProfiler(gatherLayerName, layer);
-
-    if (bfyx_wa) {
-        // reorder output back to original format
-        auto reorderName = gatherLayerName + "_" + layer->name + "_format_reorder";
-        auto reorderPrim = cldnn::reorder(reorderName, gatherPrim, outputFormatOriginal, targetDatatype);
-        topology.add(reorderPrim);
-        AddInnerPrimitiveToProfiler(reorderName, gatherLayerName, layer);
-
-        // reshape output back to original dims
-        auto reshapeName = gatherLayerName + "_" + layer->name + "_reshape";
-        auto reshapePrim = cldnn::reshape(reshapeName, reorderName, CldnnTensorFromIEDims(outDimsOriginal));
-        topology.add(reshapePrim);
-        AddInnerPrimitiveToProfiler(reshapeName, gatherLayerName, layer);
-    }
 }
 
 void CLDNNPlugin::Program::CreateGatherTreePrimitive(cldnn::topology & topology, InferenceEngine::CNNLayerPtr & layer) {
index 55d4bc7..2a66563 100644 (file)
@@ -11,39 +11,348 @@ using namespace LayerTestsDefinitions;
 
 namespace {
 
-const std::vector<InferenceEngine::Precision> netPrecisions = {
+const std::vector<InferenceEngine::Precision> netPrecisionsFP32 = {
         InferenceEngine::Precision::FP32,
 };
 
-const std::vector<std::vector<size_t>> inputShapes = {
-        std::vector<size_t>{10, 20, 30, 40},
+const std::vector<InferenceEngine::Precision> netPrecisionsI32 = {
+        InferenceEngine::Precision::I32,
+};
+
+const std::vector<InferenceEngine::Precision> netPrecisionsFP16 = {
+        InferenceEngine::Precision::FP16,
 };
 
 const std::vector<std::vector<int>> indices = {
         std::vector<int>{0, 3, 2, 1},
 };
-const std::vector<std::vector<size_t>> indicesShapes = {
-        std::vector<size_t>{4}
-        // 5d output not supported yet
-        // std::vector<size_t>{2, 2}
+const std::vector<std::vector<size_t>> indicesShapes12 = {
+        std::vector<size_t>{4},
+        std::vector<size_t>{2, 2}
+};
+
+const std::vector<std::vector<size_t>> indicesShapes1 = {
+        std::vector<size_t>{4},
+};
+
+const std::vector<std::vector<size_t>> inputShapes6DAxes5 = {
+        std::vector<size_t>{5, 6, 7, 8, 9, 10},
+        std::vector<size_t>{1, 1, 7, 8, 9, 10},
+        std::vector<size_t>{5, 1, 1, 8, 9, 10},
+        std::vector<size_t>{5, 6, 1, 1, 9, 10},
+        std::vector<size_t>{5, 6, 7, 1, 1, 10},
+        std::vector<size_t>{1, 6, 1, 8, 9, 10},
+        std::vector<size_t>{5, 1, 7, 1, 9, 10},
+        std::vector<size_t>{5, 6, 1, 8, 1, 10},
+        std::vector<size_t>{1, 6, 7, 1, 9, 10},
+        std::vector<size_t>{5, 1, 7, 8, 1, 10},
+        std::vector<size_t>{1, 6, 7, 8, 1, 10},
+};
+
+const std::vector<int> axes5 = {5};
+
+const auto Gather6dAxes5 = testing::Combine(
+        testing::ValuesIn(indices),
+        testing::ValuesIn(indicesShapes1),
+        testing::ValuesIn(axes5),
+        testing::ValuesIn(inputShapes6DAxes5),
+        testing::ValuesIn(netPrecisionsFP32),
+        testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+const std::vector<std::vector<size_t>> inputShapesAxes4 = {
+        std::vector<size_t>{5, 6, 7, 8, 9},
+        std::vector<size_t>{1, 6, 7, 8, 9},
+        std::vector<size_t>{5, 1, 7, 8, 9},
+        std::vector<size_t>{5, 6, 1, 8, 9},
+        std::vector<size_t>{5, 6, 7, 1, 9},
+};
+
+const std::vector<std::vector<size_t>> inputShapes6DAxes4 = {
+        std::vector<size_t>{5, 6, 7, 8, 9, 10},
+        std::vector<size_t>{1, 1, 7, 8, 9, 10},
+        std::vector<size_t>{5, 1, 1, 8, 9, 10},
+        std::vector<size_t>{5, 6, 1, 1, 9, 10},
+        std::vector<size_t>{5, 6, 7, 1, 9, 1},
+        std::vector<size_t>{1, 6, 1, 8, 9, 10},
+        std::vector<size_t>{5, 1, 7, 1, 9, 10},
+        std::vector<size_t>{5, 6, 1, 8, 9, 1},
+        std::vector<size_t>{1, 6, 7, 1, 9, 10},
+        std::vector<size_t>{5, 1, 7, 8, 9, 1},
+        std::vector<size_t>{1, 6, 7, 8, 9, 1},
+};
+
+const std::vector<int> axes4 = {4};
+
+const auto GatherAxes4 = testing::Combine(
+        testing::ValuesIn(indices),
+        testing::ValuesIn(indicesShapes12),
+        testing::ValuesIn(axes4),
+        testing::ValuesIn(inputShapesAxes4),
+        testing::ValuesIn(netPrecisionsFP16),
+        testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+        GatherAxes4,
+        GatherLayerTest,
+        GatherAxes4,
+        GatherLayerTest::getTestCaseName
+);
+
+const auto Gather6dAxes4 = testing::Combine(
+        testing::ValuesIn(indices),
+        testing::ValuesIn(indicesShapes1),
+        testing::ValuesIn(axes4),
+        testing::ValuesIn(inputShapes6DAxes4),
+        testing::ValuesIn(netPrecisionsFP32),
+        testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+        Gather6dAxes4,
+        GatherLayerTest,
+        Gather6dAxes4,
+        GatherLayerTest::getTestCaseName
+);
+
+const std::vector<std::vector<size_t>> inputShapesAxes3 = {
+        std::vector<size_t>{5, 6, 7, 8},
+        std::vector<size_t>{1, 6, 7, 8},
+        std::vector<size_t>{5, 1, 7, 8},
+        std::vector<size_t>{5, 6, 1, 8},
+        std::vector<size_t>{5, 6, 7, 8, 9},
+        std::vector<size_t>{1, 6, 7, 8, 9},
+        std::vector<size_t>{5, 1, 7, 8, 9},
+        std::vector<size_t>{5, 6, 1, 8, 9},
+        std::vector<size_t>{5, 6, 7, 8, 1},
+};
+
+const std::vector<std::vector<size_t>> inputShapes6DAxes3 = {
+        std::vector<size_t>{5, 6, 7, 8, 9, 10},
+        std::vector<size_t>{1, 1, 7, 8, 9, 10},
+        std::vector<size_t>{5, 1, 1, 8, 9, 10},
+        std::vector<size_t>{5, 6, 1, 8, 1, 10},
+        std::vector<size_t>{5, 6, 7, 8, 1, 1},
+        std::vector<size_t>{1, 6, 1, 8, 9, 10},
+        std::vector<size_t>{5, 1, 7, 8, 1, 10},
+        std::vector<size_t>{5, 6, 1, 8, 9, 1},
+        std::vector<size_t>{1, 6, 7, 8, 1, 10},
+        std::vector<size_t>{5, 1, 7, 8, 9, 1},
+        std::vector<size_t>{1, 6, 7, 8, 9, 1},
 };
 
-const std::vector<int> axes = {0, 1, 2, 3};
+const std::vector<int> axes3 = {3};
+
+const auto GatherAxes3 = testing::Combine(
+        testing::ValuesIn(indices),
+        testing::ValuesIn(indicesShapes12),
+        testing::ValuesIn(axes3),
+        testing::ValuesIn(inputShapesAxes3),
+        testing::ValuesIn(netPrecisionsFP32),
+        testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+        GatherAxes3,
+        GatherLayerTest,
+        GatherAxes3,
+        GatherLayerTest::getTestCaseName
+);
+
+const auto Gather6dAxes3 = testing::Combine(
+        testing::ValuesIn(indices),
+        testing::ValuesIn(indicesShapes1),
+        testing::ValuesIn(axes3),
+        testing::ValuesIn(inputShapes6DAxes3),
+        testing::ValuesIn(netPrecisionsI32),
+        testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+        Gather6dAxes3,
+        GatherLayerTest,
+        Gather6dAxes3,
+        GatherLayerTest::getTestCaseName
+);
+
+const std::vector<std::vector<size_t>> inputShapesAxes2 = {
+        std::vector<size_t>{5, 6, 7, 8},
+        std::vector<size_t>{1, 6, 7, 8},
+        std::vector<size_t>{5, 1, 7, 8},
+        std::vector<size_t>{5, 6, 7, 1},
+        std::vector<size_t>{5, 6, 7, 8, 9},
+        std::vector<size_t>{1, 6, 7, 8, 9},
+        std::vector<size_t>{5, 1, 7, 8, 9},
+        std::vector<size_t>{5, 6, 7, 1, 9},
+        std::vector<size_t>{5, 6, 7, 8, 1},
+};
 
+const std::vector<std::vector<size_t>> inputShapes6DAxes2 = {
+        std::vector<size_t>{5, 6, 7, 8, 9, 10},
+        std::vector<size_t>{1, 1, 7, 8, 9, 10},
+        std::vector<size_t>{5, 1, 7, 1, 9, 10},
+        std::vector<size_t>{5, 6, 7, 1, 1, 10},
+        std::vector<size_t>{5, 6, 7, 8, 1, 1},
+        std::vector<size_t>{1, 6, 7, 1, 9, 10},
+        std::vector<size_t>{5, 1, 7, 8, 1, 10},
+        std::vector<size_t>{5, 6, 7, 1, 9, 1},
+        std::vector<size_t>{1, 6, 7, 8, 1, 10},
+        std::vector<size_t>{5, 1, 7, 8, 9, 1},
+        std::vector<size_t>{1, 6, 7, 8, 9, 1},
+};
+
+const std::vector<int> axes2 = {2};
+
+const auto GatherAxes2 = testing::Combine(
+        testing::ValuesIn(indices),
+        testing::ValuesIn(indicesShapes12),
+        testing::ValuesIn(axes2),
+        testing::ValuesIn(inputShapesAxes2),
+        testing::ValuesIn(netPrecisionsFP32),
+        testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+        GatherAxes2,
+        GatherLayerTest,
+        GatherAxes2,
+        GatherLayerTest::getTestCaseName
+);
+
+const auto Gather6dAxes2 = testing::Combine(
+        testing::ValuesIn(indices),
+        testing::ValuesIn(indicesShapes1),
+        testing::ValuesIn(axes2),
+        testing::ValuesIn(inputShapes6DAxes2),
+        testing::ValuesIn(netPrecisionsFP16),
+        testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+        Gather6dAxes2,
+        GatherLayerTest,
+        Gather6dAxes2,
+        GatherLayerTest::getTestCaseName
+);
+
+const std::vector<std::vector<size_t>> inputShapesAxes1 = {
+        std::vector<size_t>{5, 6, 7, 8},
+        std::vector<size_t>{1, 6, 7, 8},
+        std::vector<size_t>{5, 6, 1, 8},
+        std::vector<size_t>{5, 6, 7, 1},
+        std::vector<size_t>{5, 6, 7, 8, 9},
+        std::vector<size_t>{1, 6, 7, 8, 9},
+        std::vector<size_t>{5, 6, 1, 8, 9},
+        std::vector<size_t>{5, 6, 7, 1, 9},
+        std::vector<size_t>{5, 6, 7, 8, 1},
+};
+
+const std::vector<std::vector<size_t>> inputShapes6DAxes1 = {
+        std::vector<size_t>{5, 6, 7, 8, 9, 10},
+        std::vector<size_t>{1, 6, 1, 8, 9, 10},
+        std::vector<size_t>{5, 6, 1, 1, 9, 10},
+        std::vector<size_t>{5, 6, 7, 1, 1, 10},
+        std::vector<size_t>{5, 6, 7, 8, 1, 1},
+        std::vector<size_t>{1, 6, 7, 1, 9, 10},
+        std::vector<size_t>{5, 6, 1, 8, 1, 10},
+        std::vector<size_t>{5, 6, 1, 8, 9, 1},
+        std::vector<size_t>{1, 6, 7, 8, 1, 10},
+        std::vector<size_t>{1, 6, 7, 8, 9, 1},
+        std::vector<size_t>{5, 6, 7, 1, 9, 1},
+};
+
+const std::vector<int> axes1 = {1};
+
+const auto GatherAxes1 = testing::Combine(
+        testing::ValuesIn(indices),
+        testing::ValuesIn(indicesShapes12),
+        testing::ValuesIn(axes1),
+        testing::ValuesIn(inputShapesAxes1),
+        testing::ValuesIn(netPrecisionsI32),
+        testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+        GatherAxes1,
+        GatherLayerTest,
+        GatherAxes1,
+        GatherLayerTest::getTestCaseName
+);
+
+const auto Gather6dAxes1 = testing::Combine(
+        testing::ValuesIn(indices),
+        testing::ValuesIn(indicesShapes1),
+        testing::ValuesIn(axes1),
+        testing::ValuesIn(inputShapes6DAxes1),
+        testing::ValuesIn(netPrecisionsFP32),
+        testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+        Gather6dAxes1,
+        GatherLayerTest,
+        Gather6dAxes1,
+        GatherLayerTest::getTestCaseName
+);
+
+const std::vector<std::vector<size_t>> inputShapesAxes0 = {
+        std::vector<size_t>{5, 6, 7, 8},
+        std::vector<size_t>{5, 1, 7, 8},
+        std::vector<size_t>{5, 6, 1, 8},
+        std::vector<size_t>{5, 6, 7, 1},
+        std::vector<size_t>{5, 6, 7, 8, 9},
+        std::vector<size_t>{5, 1, 7, 8, 9},
+        std::vector<size_t>{5, 6, 1, 8, 9},
+        std::vector<size_t>{5, 6, 7, 1, 9},
+        std::vector<size_t>{5, 6, 7, 8, 1},
+};
+
+const std::vector<std::vector<size_t>> inputShapes6DAxes0 = {
+        std::vector<size_t>{5, 6, 7, 8, 9, 10},
+        std::vector<size_t>{5, 1, 1, 8, 9, 10},
+        std::vector<size_t>{5, 6, 1, 1, 9, 10},
+        std::vector<size_t>{5, 6, 7, 1, 1, 10},
+        std::vector<size_t>{5, 6, 7, 8, 1, 1},
+        std::vector<size_t>{5, 1, 7, 1, 9, 10},
+        std::vector<size_t>{5, 6, 1, 8, 1, 10},
+        std::vector<size_t>{5, 6, 1, 8, 9, 1},
+        std::vector<size_t>{5, 1, 7, 8, 1, 10},
+        std::vector<size_t>{5, 1, 7, 8, 9, 1},
+        std::vector<size_t>{5, 6, 7, 1, 9, 1},
+};
+
+const std::vector<int> axes0 = {0};
+
+const auto GatherAxes0 = testing::Combine(
+        testing::ValuesIn(indices),
+        testing::ValuesIn(indicesShapes12),
+        testing::ValuesIn(axes0),
+        testing::ValuesIn(inputShapesAxes0),
+        testing::ValuesIn(netPrecisionsFP32),
+        testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+        GatherAxes0,
+        GatherLayerTest,
+        GatherAxes0,
+        GatherLayerTest::getTestCaseName
+);
 
-const auto params = testing::Combine(
+const auto Gather6dAxes0 = testing::Combine(
         testing::ValuesIn(indices),
-        testing::ValuesIn(indicesShapes),
-        testing::ValuesIn(axes),
-        testing::ValuesIn(inputShapes),
-        testing::ValuesIn(netPrecisions),
+        testing::ValuesIn(indicesShapes1),
+        testing::ValuesIn(axes0),
+        testing::ValuesIn(inputShapes6DAxes0),
+        testing::ValuesIn(netPrecisionsFP32),
         testing::Values(CommonTestUtils::DEVICE_GPU)
 );
 
 INSTANTIATE_TEST_CASE_P(
-        Gather,
+        Gather6dAxes0,
         GatherLayerTest,
-        params,
+        Gather6dAxes0,
         GatherLayerTest::getTestCaseName
 );
 
index 37b61ee..437b6ac 100644 (file)
@@ -7,18 +7,18 @@
 INSTANTIATE_TEST_CASE_P(
         smoke_GPU_TestsGather, GatherTFTests,
         ::testing::Values(
-        gatherTF_test_params{ "GPU", "FP32", { 1, 4 }, in0,{ 2, 2 }, dict2D, 0, { 1, 4, 2 }, ref_in0_a0_d22 },
-        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict, 0, { 2, 2, 2, 3 }, ref_in0_a0_d223 },
-        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict,-3, { 2, 2, 2, 3 }, ref_in0_a0_d223 },
+        gatherTF_test_params{ "GPU", "FP32", { 1, 4 }, in0,{ 2, 2 }, dict2D, 0, { 1, 4, 1, 2 }, ref_in0_a0_d22 },
+        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict, 0, { 2, 2, 1, 2, 3 }, ref_in0_a0_d223 },
+        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict,-3, { 2, 2, 1, 2, 3 }, ref_in0_a0_d223 },
 
-        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict, 0, { 2, 2, 2, 2 }, ref_in1_a0_d322 },
-        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict,-3, { 2, 2, 2, 2 }, ref_in1_a0_d322 },
-        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict, 1, { 2, 2, 2, 2 }, ref_in1_a1_d232 },
-        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict,-2, { 2, 2, 2, 2 }, ref_in1_a1_d232 },
+        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict, 0, { 2, 2, 1, 2, 2 }, ref_in1_a0_d322 },
+        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict,-3, { 2, 2, 1, 2, 2 }, ref_in1_a0_d322 },
+        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict, 1, { 2, 2, 2, 2, 1 }, ref_in1_a1_d232 },
+        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict,-2, { 2, 2, 2, 2, 1 }, ref_in1_a1_d232 },
 
-        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict, 2, { 2, 2, 2, 2 }, ref_in1_a2_d223 },
-        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict,-1, { 2, 2, 2, 2 }, ref_in1_a2_d223 },
-        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2 }, ref_in0_a2_d232 },
-        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict,-1, { 2, 3, 2, 2 }, ref_in0_a2_d232 },
-        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2 }, ref_in0_a2_d232 }
+        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict, 2, { 2, 2, 2, 2, 1 }, ref_in1_a2_d223 },
+        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict,-1, { 2, 2, 2, 2, 1 }, ref_in1_a2_d223 },
+        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2, 1 }, ref_in0_a2_d232 },
+        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict,-1, { 2, 3, 2, 2, 1 }, ref_in0_a2_d232 },
+        gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2, 1 }, ref_in0_a2_d232 }
 ));
index 9b22d5c..4b8874d 100644 (file)
@@ -35,7 +35,9 @@ struct gather : public primitive_base<gather> {
         along_b,
         along_f,
         along_x,
-        along_y
+        along_y,
+        along_z,
+        along_w
     };
 
     /// @brief Constructs gather primitive.
index b4f0f4c..78f248d 100644 (file)
@@ -23,10 +23,16 @@ namespace kernel_selector {
 static size_t GetGatherChannelIndex(const gather_params& params) {
     Tensor::DataChannelName name = Tensor::DataChannelName::X;
 
+    size_t inputSize = params.inputs[0].GetDims().size();
+
     switch (params.axis) {
         case GatherAxis::X:
-            return 3;
+            return inputSize - 1;
         case GatherAxis::Y:
+            return inputSize - 2;
+        case GatherAxis::Z:
+            return inputSize - 3;
+        case GatherAxis::W:
             return 2;
         case GatherAxis::FEATURE:
             return 1;
@@ -51,6 +57,10 @@ ParamsKey GatherKernelRef::GetSupportedKey() const {
     k.EnableOutputDataType(Datatype::UINT8);
     k.EnableInputLayout(DataLayout::bfyx);
     k.EnableOutputLayout(DataLayout::bfyx);
+    k.EnableInputLayout(DataLayout::bfzyx);
+    k.EnableOutputLayout(DataLayout::bfzyx);
+    k.EnableInputLayout(DataLayout::bfwzyx);
+    k.EnableOutputLayout(DataLayout::bfwzyx);
     k.EnableTensorOffset();
     k.EnableTensorPitches();
     k.EnableBatching();
@@ -78,12 +88,26 @@ static inline std::string GetOrderString(std::vector<std::string>& order) {
     std::string order_str = order[0];
     for (size_t i = 1; i < order.size(); i++)
         order_str += ", " + order[i];
-    
+
     return order_str;
 }
 
+static inline std::vector<std::string> GetOrder(size_t size) {
+    std::vector<std::string> idx_order;
+    if (size <= 4) {
+        idx_order = {"b", "f", "y", "x"};
+    } else if (size == 5) {
+        idx_order = {"b", "f", "z", "y", "x"};
+    } else if (size == 6) {
+        idx_order = {"b", "f", "w", "z", "y", "x"};
+    }
+    
+    return idx_order;
+}
+
 static std::string GetDictionaryIndexOrder(const gather_params& params, size_t axis) {
-    std::vector<std::string> default_order = { "b", "f", "y", "x" };
+    std::vector<std::string> idx_order = GetOrder(params.output.GetDims().size());
+
     const std::string input_axis_index_macro = "INPUT_AXIS_INDEX";
     const std::string zeroVal = "0";
 
@@ -92,38 +116,57 @@ static std::string GetDictionaryIndexOrder(const gather_params& params, size_t a
 
     // Shift indices of Gather dictionary input related to output dims
     for (size_t i = axis + 1; i < dictionary_dims_num; i++)
-        default_order[i] = default_order[i + indices_dims_num - 1];
+        idx_order[i] = idx_order[i + indices_dims_num - 1];
 
-    for (size_t i = dictionary_dims_num; i < default_order.size(); i++)
-        default_order[i] = zeroVal;
+    for (size_t i = dictionary_dims_num; i < idx_order.size(); i++)
+        idx_order[i] = zeroVal;
+    
+    // Fix size to inputs[0] dims size
+    for (size_t i = 0; i < params.output.GetDims().size() - params.inputs[0].GetDims().size(); i++)
+        idx_order.pop_back();
 
-    default_order[axis] = input_axis_index_macro;
+    idx_order[axis] = input_axis_index_macro;
 
-    return GetOrderString(default_order);
+    return GetOrderString(idx_order);
 }
 
 static std::string GetIndecesIdxOrder(const gather_params& params, size_t axis) {
-    std::vector<std::string> default_order = { "b", "f", "y", "x" };
+    std::vector<std::string> idx_order = GetOrder(params.output.GetDims().size());
+
     const std::string zero_val = "0";
 
     size_t indices_dims_num = GetNonEmptyDimsNumber(params.inputs[1]);
 
     // Shift indices of Gather indices input related to output dims
     for (size_t i = 0; i < indices_dims_num; i++)
-        default_order[i] = default_order[axis + i];
+        idx_order[i] = idx_order[axis + i];
 
-    for (size_t i = indices_dims_num; i < default_order.size(); i++)
-        default_order[i] = zero_val;
+    for (size_t i = indices_dims_num; i < idx_order.size(); i++)
+        idx_order[i] = zero_val;
 
-    return GetOrderString(default_order);
+    // Fix size to inputs[1] dims size
+    for (size_t i = 0; i < params.output.GetDims().size() - params.inputs[1].GetDims().size(); i++)
+        idx_order.pop_back();
+
+    return GetOrderString(idx_order);
 }
 
 CommonDispatchData GatherKernelRef::SetDefault(const gather_params& params, const optional_params&) const {
     CommonDispatchData runInfo;
     const auto& output = params.output;
 
-    std::vector<size_t> global = {output.Batch().v, output.Feature().v,output.X().v * output.Y().v};
-    std::vector<size_t> local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
+    std::vector<size_t> global;
+    std::vector<size_t> local;
+
+    if (output.GetLayout() == DataLayout::bfyx) {
+        global = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v};
+    } else if (output.GetLayout() == DataLayout::bfzyx) {
+        global = {output.X().v, output.Y().v * output.Z().v, output.Feature().v * output.Batch().v};
+    } else {
+        global = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
+    }
+
+    local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
 
     runInfo.gws0 = global[0];
     runInfo.gws1 = global[1];
@@ -145,7 +188,9 @@ JitConstants GatherKernelRef::GetJitConstants(const gather_params& params) const
     jit.AddConstant(MakeJitConstant("INDICES_INDEX_ORDER", GetIndecesIdxOrder(params, GetGatherChannelIndex(params))));
 
     if (!params.fused_ops.empty()) {
-        FusedOpsConfiguration conf = { "", {"b", "f", "y", "x"}, "val", params.inputs[0].GetDType() };
+        std::vector<std::string> idx_order = GetOrder(params.inputs[0].GetDims().size());
+
+        FusedOpsConfiguration conf = { "", idx_order, "val", params.inputs[0].GetDType() };
         jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
     }
 
index 1b205a6..e7bfaa9 100644 (file)
@@ -18,6 +18,7 @@
 #define INPUT_AXIS_INDEX (uint)indices[indices_idx]
 #define GET_DICTIONARY_INDEX(idx_order) INPUT0_GET_INDEX(idx_order)
 #define GET_INDICES_INDEX(idx_order) INPUT1_GET_INDEX(idx_order)
+#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX)(idx_order)
 
 KERNEL(gather_ref)(const __global INPUT0_TYPE* dictionary,
                    const __global INPUT1_TYPE* indices,
@@ -27,15 +28,32 @@ KERNEL(gather_ref)(const __global INPUT0_TYPE* dictionary,
 #endif
 )
 {
-    const uint b = get_global_id(0);
-    const uint f = get_global_id(1);
-    const uint yx = get_global_id(2);
-    const uint y = yx / OUTPUT_SIZE_X;
-    const uint x = yx % OUTPUT_SIZE_X;
+    #if OUTPUT_DIMS == 6
+        #define ORDER b,f,w,z,y,x
+        const uint x = (uint)get_global_id(0) % OUTPUT_SIZE_X;
+        const uint y = (uint)get_global_id(0) / OUTPUT_SIZE_X;
+        const uint z = (uint)get_global_id(1) % OUTPUT_SIZE_Z;
+        const uint w = (uint)get_global_id(1) / OUTPUT_SIZE_Z;
+        const uint f = (uint)get_global_id(2) % OUTPUT_FEATURE_NUM;
+        const uint b = (uint)get_global_id(2) / OUTPUT_FEATURE_NUM;
+    #elif OUTPUT_DIMS == 5
+        #define ORDER b,f,z,y,x
+        const uint x = (uint)get_global_id(0);
+        const uint y = (uint)get_global_id(1) % OUTPUT_SIZE_Y;
+        const uint z = (uint)get_global_id(1) / OUTPUT_SIZE_Y;
+        const uint f = (uint)get_global_id(2) % OUTPUT_FEATURE_NUM;
+        const uint b = (uint)get_global_id(2) / OUTPUT_FEATURE_NUM;
+    #elif OUTPUT_DIMS == 4
+        #define ORDER b,f,y,x
+        const uint x = (uint)get_global_id(0);
+        const uint y = (uint)get_global_id(1);
+        const uint f = (uint)get_global_id(2) % OUTPUT_FEATURE_NUM;
+        const uint b = (uint)get_global_id(2) / OUTPUT_FEATURE_NUM;
+    #endif
 
     const uint indices_idx = GET_INDICES_INDEX(INDICES_INDEX_ORDER);
     const uint dictionary_idx = GET_DICTIONARY_INDEX(DICTIONARY_INDEX_ORDER);
-    const uint output_idx = OUTPUT_GET_INDEX(b, f, y, x);
+    const uint output_idx = GET_INDEX(OUTPUT,,ORDER);
 
     INPUT0_TYPE val = dictionary[dictionary_idx];
 
index 4f47706..02e7824 100644 (file)
@@ -31,16 +31,28 @@ layout gather_inst::calc_output_layout(gather_node const& node) {
     auto desc = node.get_primitive();
 
     auto input_layout = node.input(0).get_output_layout();
-    auto input_format = input_layout.format;
-
     auto output_shape = desc->output_shape;
+    auto output_format = input_layout.format;
+
+    int spatialNum = 0;
+    for (auto i : node.input(1).get_output_layout().size.raw)
+         spatialNum += (i > 1) ? 1 : 0;
+
+    // change output format if input indeces > 1
+    if (spatialNum == 2 && output_format == cldnn::format::bfzyx) {
+        output_format = cldnn::format::bfwzyx;
+    } else if (spatialNum == 2 && output_format == cldnn::format::bfyx) {
+        output_format = cldnn::format::bfzyx;
+    } else if (spatialNum == 3 && output_format == cldnn::format::bfyx) {
+        output_format = cldnn::format::bfwzyx;
+    }
 
     auto output_type = input_layout.data_type;
     if (node.has_fused_primitives()) {
         output_type = node.get_fused_output_layout().data_type;
     }
 
-    return layout{output_type, input_format, output_shape};
+    return layout{output_type, output_format, output_shape};
 }
 
 std::string gather_inst::to_string(gather_node const& node) {
index 9280c20..c6f7f11 100644 (file)
@@ -32,6 +32,10 @@ kernel_selector::gather_axis convert_axis(gather::gather_axis axis) {
             return kernel_selector::gather_axis::X;
         case gather::along_y:
             return kernel_selector::gather_axis::Y;
+        case gather::along_z:
+            return kernel_selector::gather_axis::Z;
+        case gather::along_w:
+            return kernel_selector::gather_axis::W;
         case gather::along_f:
             return kernel_selector::gather_axis::FEATURE;
         case gather::along_b:
@@ -76,6 +80,14 @@ attach_gather_gpu::attach_gather_gpu() {
     implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
     implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
     implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);
+    
+    implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
+    implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
+    implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw);
+    
+    implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw);
+    implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw);
+    implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw);
 }
 
 }  // namespace detail
index a65832a..0266e5b 100644 (file)
@@ -4959,6 +4959,18 @@ struct gather_test_params {
 #define CASE_GATHER_FP16_4 {5, 3, 2, 2}, {3, 1, 1, 1}, {5, 2, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
 #define CASE_GATHER_FP16_5 {2, 3, 1, 2}, {1, 3, 1, 1}, {2, 3, 3, 1}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
 
+#define CASE_GATHER_5D_FP32_1 {2, 3, 1, 4, 1}, {4, 1, 1, 1}, {4, 3, 1, 4, 1}, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_GATHER_5D_FP32_2 {2, 3, 2, 2, 2}, {2, 1, 1, 1}, {2, 2, 2, 2, 2}, cldnn::gather::gather_axis::along_f, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_GATHER_5D_FP32_3 {5, 3, 2, 2, 2}, {3, 1, 1, 1}, {5, 3, 2, 3, 2}, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_GATHER_5D_FP32_4 {2, 3, 1, 4, 4}, {2, 1, 1, 1}, {2, 3, 1, 4, 2}, cldnn::gather::gather_axis::along_z, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_GATHER_5D_FP32_5 {3, 1, 5, 2, 1}, {2, 1, 1, 1}, {3, 1, 2, 2, 1}, cldnn::gather::gather_axis::along_x, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+
+#define CASE_GATHER_5D_FP16_1 {3, 2, 1, 2, 1}, {2, 1, 1, 1}, {2, 2, 2, 2, 1}, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_GATHER_5D_FP16_2 {1, 3, 1, 2, 1}, {2, 1, 1, 1}, {1, 2, 1, 2, 1}, cldnn::gather::gather_axis::along_f, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_GATHER_5D_FP16_3 {2, 3, 1, 3, 3}, {1, 2, 1, 1}, {2, 3, 1, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_GATHER_5D_FP16_4 {3, 2, 2, 2, 2}, {2, 1, 1, 1}, {3, 2, 2, 2, 2}, cldnn::gather::gather_axis::along_z, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_GATHER_5D_FP16_5 {1, 1, 2, 1, 1}, {3, 1, 1, 1}, {1, 1, 3, 1, 1}, cldnn::gather::gather_axis::along_x, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+
 class GatherPrimitiveFusingTest : public ::BaseFusingTest<gather_test_params> {
 public:
     void execute(gather_test_params& p) {
@@ -4985,6 +4997,10 @@ public:
                 return p.dictionary_shape.spatial[0];
             case cldnn::gather::gather_axis::along_y:
                 return p.dictionary_shape.spatial[1];
+            case cldnn::gather::gather_axis::along_z:
+                return p.dictionary_shape.spatial[2];
+            case cldnn::gather::gather_axis::along_w:
+                return p.dictionary_shape.spatial[3];
             case cldnn::gather::gather_axis::along_f:
                 return p.dictionary_shape.feature[0];
             case cldnn::gather::gather_axis::along_b:
@@ -5030,6 +5046,18 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_quantize,
                         gather_test_params{ CASE_GATHER_FP16_3, 2, 3 },
                         gather_test_params{ CASE_GATHER_FP16_4, 2, 3 },
                         gather_test_params{ CASE_GATHER_FP16_5, 2, 3 },
+
+                        gather_test_params{ CASE_GATHER_5D_FP32_1, 2, 3 },
+                        gather_test_params{ CASE_GATHER_5D_FP32_2, 2, 3 },
+                        gather_test_params{ CASE_GATHER_5D_FP32_3, 2, 3 },
+                        gather_test_params{ CASE_GATHER_5D_FP32_4, 2, 3 },
+                        gather_test_params{ CASE_GATHER_5D_FP32_5, 2, 3 },
+
+                        gather_test_params{ CASE_GATHER_5D_FP16_1, 2, 3 },
+                        gather_test_params{ CASE_GATHER_5D_FP16_2, 2, 3 },
+                        gather_test_params{ CASE_GATHER_5D_FP16_3, 2, 3 },
+                        gather_test_params{ CASE_GATHER_5D_FP16_4, 2, 3 },
+                        gather_test_params{ CASE_GATHER_5D_FP16_5, 2, 3 },
 }), );
 
 class gather_scale_activation : public GatherPrimitiveFusingTest {};
@@ -5061,6 +5089,18 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_scale_activation,
                         gather_test_params{ CASE_GATHER_FP16_3, 2, 4 },
                         gather_test_params{ CASE_GATHER_FP16_4, 2, 4 },
                         gather_test_params{ CASE_GATHER_FP16_5, 2, 4 },
+
+                        gather_test_params{ CASE_GATHER_5D_FP32_1, 2, 4 },
+                        gather_test_params{ CASE_GATHER_5D_FP32_2, 2, 4 },
+                        gather_test_params{ CASE_GATHER_5D_FP32_3, 2, 4 },
+                        gather_test_params{ CASE_GATHER_5D_FP32_4, 2, 4 },
+                        gather_test_params{ CASE_GATHER_5D_FP32_5, 2, 4 },
+
+                        gather_test_params{ CASE_GATHER_5D_FP16_1, 2, 4 },
+                        gather_test_params{ CASE_GATHER_5D_FP16_2, 2, 4 },
+                        gather_test_params{ CASE_GATHER_5D_FP16_3, 2, 4 },
+                        gather_test_params{ CASE_GATHER_5D_FP16_4, 2, 4 },
+                        gather_test_params{ CASE_GATHER_5D_FP16_5, 2, 4 },
 }), );
 
 /* ------------------------------------------------------------------------------------------------------------ */
index 99c73cb..d308e40 100644 (file)
@@ -125,7 +125,7 @@ TEST(gather_gpu_fp16, d222_axisB) {
     topology.add(input_layout("InputDictionary", input1.get_layout()));
     topology.add(input_layout("InputText", input2.get_layout()));
     topology.add(
-        gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+        gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
     );
 
     network network(engine, topology);
@@ -186,7 +186,7 @@ TEST(gather_gpu_fp16, d22_axisY) {
     topology.add(input_layout("InputDictionary", input1.get_layout()));
     topology.add(input_layout("InputText", input2.get_layout()));
     topology.add(
-        gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+        gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
     );
 
     network network(engine, topology);
@@ -247,7 +247,7 @@ TEST(gather_gpu_fp16, d22_axisF) {
     topology.add(input_layout("InputDictionary", input1.get_layout()));
     topology.add(input_layout("InputText", input2.get_layout()));
     topology.add(
-            gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+            gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
     );
 
     network network(engine, topology);
@@ -366,7 +366,7 @@ TEST(gather_gpu_fp32, d222_axisB) {
     topology.add(input_layout("InputDictionary", input1.get_layout()));
     topology.add(input_layout("InputText", input2.get_layout()));
     topology.add(
-        gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+        gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
     );
 
     network network(engine, topology);
@@ -427,7 +427,7 @@ TEST(gather_gpu_fp32, d22_axisY) {
     topology.add(input_layout("InputDictionary", input1.get_layout()));
     topology.add(input_layout("InputText", input2.get_layout()));
     topology.add(
-        gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+        gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
     );
 
     network network(engine, topology);
@@ -488,7 +488,7 @@ TEST(gather_gpu_fp32, d22_axisF) {
     topology.add(input_layout("InputDictionary", input1.get_layout()));
     topology.add(input_layout("InputText", input2.get_layout()));
     topology.add(
-            gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+            gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
     );
 
     network network(engine, topology);
@@ -549,7 +549,7 @@ TEST(gather_gpu_int32, d22_axisF) {
     topology.add(input_layout("InputDictionary", input1.get_layout()));
     topology.add(input_layout("InputText", input2.get_layout()));
     topology.add(
-            gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+            gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
     );
 
     network network(engine, topology);
@@ -668,7 +668,7 @@ TEST(gather_gpu_int32, d222_axisB) {
     topology.add(input_layout("InputDictionary", input1.get_layout()));
     topology.add(input_layout("InputText", input2.get_layout()));
     topology.add(
-            gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+            gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
     );
 
     network network(engine, topology);
@@ -729,7 +729,7 @@ TEST(gather_gpu_int32, d22_axisY) {
     topology.add(input_layout("InputDictionary", input1.get_layout()));
     topology.add(input_layout("InputText", input2.get_layout()));
     topology.add(
-            gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+            gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
     );
 
     network network(engine, topology);