int axis = gatherLayer->GetParamAsInt("axis", 0);
// Be careful, TensorFlow consist negative axis interpretation bug. Here: -3 = b, -2 = f, -1 = y, but must be -3 = f, -2 = y, -1 = x
- auto cldnnAxisFromIE = [](int axis) {
- switch (axis) {
- case 0: return cldnn::gather::gather_axis::along_b;
- case 1: return cldnn::gather::gather_axis::along_f;
- case 2: return cldnn::gather::gather_axis::along_y;
- case 3: return cldnn::gather::gather_axis::along_x;
- case -1: return cldnn::gather::gather_axis::along_y;
- case -2: return cldnn::gather::gather_axis::along_f;
- case -3: return cldnn::gather::gather_axis::along_b;
- default: THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
+ auto cldnnAxisFromIE = [](int axis, cldnn::format inputFormat) {
+ if (axis == 0) {
+ return cldnn::gather::gather_axis::along_b;
+ } else if (axis == 1) {
+ return cldnn::gather::gather_axis::along_f;
+ }
+
+ if (inputFormat == cldnn::format::bfyx) {
+ switch (axis) {
+ case 2: return cldnn::gather::gather_axis::along_y;
+ case 3: return cldnn::gather::gather_axis::along_x;
+ case -1: return cldnn::gather::gather_axis::along_y;
+ case -2: return cldnn::gather::gather_axis::along_f;
+ case -3: return cldnn::gather::gather_axis::along_b;
+ default: THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
+ }
+ } else if (inputFormat == cldnn::format::bfzyx) {
+ switch (axis) {
+ case 2: return cldnn::gather::gather_axis::along_z;
+ case 3: return cldnn::gather::gather_axis::along_y;
+ case 4: return cldnn::gather::gather_axis::along_x;
+ case -1: return cldnn::gather::gather_axis::along_y;
+ case -2: return cldnn::gather::gather_axis::along_z;
+ case -3: return cldnn::gather::gather_axis::along_f;
+ case -4: return cldnn::gather::gather_axis::along_b;
+ default: THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
+ }
+ } else if (inputFormat == cldnn::format::bfwzyx) {
+ switch (axis) {
+ case 2: return cldnn::gather::gather_axis::along_w;
+ case 3: return cldnn::gather::gather_axis::along_z;
+ case 4: return cldnn::gather::gather_axis::along_y;
+ case 5: return cldnn::gather::gather_axis::along_x;
+ case -1: return cldnn::gather::gather_axis::along_y;
+ case -2: return cldnn::gather::gather_axis::along_z;
+ case -3: return cldnn::gather::gather_axis::along_w;
+ case -4: return cldnn::gather::gather_axis::along_f;
+ case -5: return cldnn::gather::gather_axis::along_b;
+ default: THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
+ }
+ } else {
+ THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
}
};
}
}
- auto indicesDims = layer->insData[1].lock()->getTensorDesc().getDims();
- auto indicesLayout = layer->insData[1].lock()->getTensorDesc().getLayout();
- auto indicesFormat = FormatFromLayout(indicesLayout);
-
- auto inputDims = layer->insData[0].lock()->getTensorDesc().getDims();
auto inputLayout = layer->insData[0].lock()->getTensorDesc().getLayout();
- auto inputFormat = FormatFromLayout(inputLayout);
-
- auto outDimsOriginal = layer->outData[0]->getTensorDesc().getDims();
- auto outputLayoutOriginal = layer->outData[0]->getTensorDesc().getLayout();
- auto outputFormatOriginal = FormatFromLayout(outputLayoutOriginal);
-
- auto outDims = outDimsOriginal;
- auto targetDatatype = DataTypeFromPrecision(layer->precision);
-
- auto nonNegativeAxis = (axis >= 0) ? axis : axis + 3;
-
- // following vector is needed just to check if we can apply bfyx WA
- SizeVector originalRequiredDims;
- for (size_t d = 0; d < inputDims.size(); d++) {
- if ((d == nonNegativeAxis) || (inputDims[d] > 1)) {
- originalRequiredDims.push_back(d);
- }
- }
-
- if (originalRequiredDims.size() < 4) {
- // make sure that we will have at least 4 required dimensions
- auto originalAxesIt = originalRequiredDims.begin();
- for (size_t i = 0; i < 4; i++) {
- int dimFoundAtIndex = -1;
- for (size_t j = 0; j < originalRequiredDims.size(); j++) {
- if (originalRequiredDims[j] == i) {
- dimFoundAtIndex = j;
- }
- }
- if (dimFoundAtIndex == -1) {
- originalAxesIt = originalRequiredDims.insert(originalAxesIt, i);
- }
- originalAxesIt++;
- }
- }
-
- // clDNN primitive is missing proper support of 5d/6d inputs
- // but we can still fall back to bfyx format in some cases
- bool bfyx_wa = ((inputFormat == cldnn::format::bfzyx || inputFormat == cldnn::format::bfwzyx) &&
- (originalRequiredDims.size() == 4) &&
- (indicesFormat == cldnn::format::bfyx));
-
- if (bfyx_wa) {
- if (indicesDims.size() > 1) {
- // reshape the indices dims to 1D (along batch axis)
- size_t indDimAcc = std::accumulate(indicesDims.begin(), indicesDims.end(), 1, std::multiplies<size_t>());
- SizeVector targetIndDims{ indDimAcc, 1, 1, 1 };
-
- auto reshapeName = reorderedInputs[1] + "_" + layer->name + "_reshape";
- auto targetTensor = CldnnTensorFromIEDims(targetIndDims);
- auto reshapePrim = cldnn::reshape(reshapeName, reorderedInputs[1], CldnnTensorFromIEDims(targetIndDims));
- topology.add(reshapePrim);
- AddInnerPrimitiveToProfiler(reshapeName, gatherLayerName, layer);
- reorderedInputs[1] = reshapeName;
-
- // adjust expected output dims
- outDims[nonNegativeAxis] = indDimAcc;
- outDims.erase(outDims.begin() + nonNegativeAxis + 1, outDims.begin() + nonNegativeAxis + indicesDims.size());
- }
-
- // reorder input to bfyx
- auto reorderName = reorderedInputs[0] + "_" + layer->name + "_format_reorder";
- auto reorderPrim = cldnn::reorder(reorderName, reorderedInputs[0], cldnn::format::bfyx, targetDatatype);
- topology.add(reorderPrim);
- AddInnerPrimitiveToProfiler(reorderName, gatherLayerName, layer);
- reorderedInputs[0] = reorderName;
-
- // calculate new input/output dims in bfyx format
- SizeVector targetInDims(4);
- SizeVector targetOutDims(4);
- for (size_t d = 0; d < 4; d++) {
- targetInDims[d] = inputDims[originalRequiredDims[d]];
- targetOutDims[d] = outDims[originalRequiredDims[d]];
- }
- outDims = targetOutDims;
-
- // calculate new axis in bfyx format
- for (size_t d = 0; d < originalRequiredDims.size(); d++) {
- if (originalRequiredDims[d] == nonNegativeAxis) {
- axis = d;
- }
- }
-
- // reshape the input dims to the ones expected in bfyx format
- auto reshapeName = reorderedInputs[0] + "_" + layer->name + "_reshape";
- auto targetTensor = CldnnTensorFromIEDims(targetInDims);
- auto reshapePrim = cldnn::reshape(reshapeName, reorderedInputs[0], CldnnTensorFromIEDims(targetInDims));
- topology.add(reshapePrim);
- AddInnerPrimitiveToProfiler(reshapeName, gatherLayerName, layer);
- reorderedInputs[0] = reshapeName;
- }
+ auto outDims = layer->outData[0]->getTensorDesc().getDims();
auto gatherPrim = cldnn::gather(
gatherLayerName,
reorderedInputs[0],
reorderedInputs[1],
- cldnnAxisFromIE(axis),
+ cldnnAxisFromIE(axis, FormatFromLayout(inputLayout)),
CldnnTensorFromIEDims(outDims));
topology.add(gatherPrim);
AddPrimitiveToProfiler(gatherLayerName, layer);
-
- if (bfyx_wa) {
- // reorder output back to original format
- auto reorderName = gatherLayerName + "_" + layer->name + "_format_reorder";
- auto reorderPrim = cldnn::reorder(reorderName, gatherPrim, outputFormatOriginal, targetDatatype);
- topology.add(reorderPrim);
- AddInnerPrimitiveToProfiler(reorderName, gatherLayerName, layer);
-
- // reshape output back to original dims
- auto reshapeName = gatherLayerName + "_" + layer->name + "_reshape";
- auto reshapePrim = cldnn::reshape(reshapeName, reorderName, CldnnTensorFromIEDims(outDimsOriginal));
- topology.add(reshapePrim);
- AddInnerPrimitiveToProfiler(reshapeName, gatherLayerName, layer);
- }
}
void CLDNNPlugin::Program::CreateGatherTreePrimitive(cldnn::topology & topology, InferenceEngine::CNNLayerPtr & layer) {
namespace {
-const std::vector<InferenceEngine::Precision> netPrecisions = {
+const std::vector<InferenceEngine::Precision> netPrecisionsFP32 = {
InferenceEngine::Precision::FP32,
};
-const std::vector<std::vector<size_t>> inputShapes = {
- std::vector<size_t>{10, 20, 30, 40},
+const std::vector<InferenceEngine::Precision> netPrecisionsI32 = {
+ InferenceEngine::Precision::I32,
+};
+
+const std::vector<InferenceEngine::Precision> netPrecisionsFP16 = {
+ InferenceEngine::Precision::FP16,
};
const std::vector<std::vector<int>> indices = {
std::vector<int>{0, 3, 2, 1},
};
-const std::vector<std::vector<size_t>> indicesShapes = {
- std::vector<size_t>{4}
- // 5d output not supported yet
- // std::vector<size_t>{2, 2}
+const std::vector<std::vector<size_t>> indicesShapes12 = {
+ std::vector<size_t>{4},
+ std::vector<size_t>{2, 2}
+};
+
+const std::vector<std::vector<size_t>> indicesShapes1 = {
+ std::vector<size_t>{4},
+};
+
+const std::vector<std::vector<size_t>> inputShapes6DAxes5 = {
+ std::vector<size_t>{5, 6, 7, 8, 9, 10},
+ std::vector<size_t>{1, 1, 7, 8, 9, 10},
+ std::vector<size_t>{5, 1, 1, 8, 9, 10},
+ std::vector<size_t>{5, 6, 1, 1, 9, 10},
+ std::vector<size_t>{5, 6, 7, 1, 1, 10},
+ std::vector<size_t>{1, 6, 1, 8, 9, 10},
+ std::vector<size_t>{5, 1, 7, 1, 9, 10},
+ std::vector<size_t>{5, 6, 1, 8, 1, 10},
+ std::vector<size_t>{1, 6, 7, 1, 9, 10},
+ std::vector<size_t>{5, 1, 7, 8, 1, 10},
+ std::vector<size_t>{1, 6, 7, 8, 1, 10},
+};
+
+const std::vector<int> axes5 = {5};
+
+const auto Gather6dAxes5 = testing::Combine(
+ testing::ValuesIn(indices),
+ testing::ValuesIn(indicesShapes1),
+ testing::ValuesIn(axes5),
+ testing::ValuesIn(inputShapes6DAxes5),
+ testing::ValuesIn(netPrecisionsFP32),
+ testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+const std::vector<std::vector<size_t>> inputShapesAxes4 = {
+ std::vector<size_t>{5, 6, 7, 8, 9},
+ std::vector<size_t>{1, 6, 7, 8, 9},
+ std::vector<size_t>{5, 1, 7, 8, 9},
+ std::vector<size_t>{5, 6, 1, 8, 9},
+ std::vector<size_t>{5, 6, 7, 1, 9},
+};
+
+const std::vector<std::vector<size_t>> inputShapes6DAxes4 = {
+ std::vector<size_t>{5, 6, 7, 8, 9, 10},
+ std::vector<size_t>{1, 1, 7, 8, 9, 10},
+ std::vector<size_t>{5, 1, 1, 8, 9, 10},
+ std::vector<size_t>{5, 6, 1, 1, 9, 10},
+ std::vector<size_t>{5, 6, 7, 1, 9, 1},
+ std::vector<size_t>{1, 6, 1, 8, 9, 10},
+ std::vector<size_t>{5, 1, 7, 1, 9, 10},
+ std::vector<size_t>{5, 6, 1, 8, 9, 1},
+ std::vector<size_t>{1, 6, 7, 1, 9, 10},
+ std::vector<size_t>{5, 1, 7, 8, 9, 1},
+ std::vector<size_t>{1, 6, 7, 8, 9, 1},
+};
+
+const std::vector<int> axes4 = {4};
+
+const auto GatherAxes4 = testing::Combine(
+ testing::ValuesIn(indices),
+ testing::ValuesIn(indicesShapes12),
+ testing::ValuesIn(axes4),
+ testing::ValuesIn(inputShapesAxes4),
+ testing::ValuesIn(netPrecisionsFP16),
+ testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+ GatherAxes4,
+ GatherLayerTest,
+ GatherAxes4,
+ GatherLayerTest::getTestCaseName
+);
+
+const auto Gather6dAxes4 = testing::Combine(
+ testing::ValuesIn(indices),
+ testing::ValuesIn(indicesShapes1),
+ testing::ValuesIn(axes4),
+ testing::ValuesIn(inputShapes6DAxes4),
+ testing::ValuesIn(netPrecisionsFP32),
+ testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+ Gather6dAxes4,
+ GatherLayerTest,
+ Gather6dAxes4,
+ GatherLayerTest::getTestCaseName
+);
+
+const std::vector<std::vector<size_t>> inputShapesAxes3 = {
+ std::vector<size_t>{5, 6, 7, 8},
+ std::vector<size_t>{1, 6, 7, 8},
+ std::vector<size_t>{5, 1, 7, 8},
+ std::vector<size_t>{5, 6, 1, 8},
+ std::vector<size_t>{5, 6, 7, 8, 9},
+ std::vector<size_t>{1, 6, 7, 8, 9},
+ std::vector<size_t>{5, 1, 7, 8, 9},
+ std::vector<size_t>{5, 6, 1, 8, 9},
+ std::vector<size_t>{5, 6, 7, 8, 1},
+};
+
+const std::vector<std::vector<size_t>> inputShapes6DAxes3 = {
+ std::vector<size_t>{5, 6, 7, 8, 9, 10},
+ std::vector<size_t>{1, 1, 7, 8, 9, 10},
+ std::vector<size_t>{5, 1, 1, 8, 9, 10},
+ std::vector<size_t>{5, 6, 1, 8, 1, 10},
+ std::vector<size_t>{5, 6, 7, 8, 1, 1},
+ std::vector<size_t>{1, 6, 1, 8, 9, 10},
+ std::vector<size_t>{5, 1, 7, 8, 1, 10},
+ std::vector<size_t>{5, 6, 1, 8, 9, 1},
+ std::vector<size_t>{1, 6, 7, 8, 1, 10},
+ std::vector<size_t>{5, 1, 7, 8, 9, 1},
+ std::vector<size_t>{1, 6, 7, 8, 9, 1},
};
-const std::vector<int> axes = {0, 1, 2, 3};
+const std::vector<int> axes3 = {3};
+
+const auto GatherAxes3 = testing::Combine(
+ testing::ValuesIn(indices),
+ testing::ValuesIn(indicesShapes12),
+ testing::ValuesIn(axes3),
+ testing::ValuesIn(inputShapesAxes3),
+ testing::ValuesIn(netPrecisionsFP32),
+ testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+ GatherAxes3,
+ GatherLayerTest,
+ GatherAxes3,
+ GatherLayerTest::getTestCaseName
+);
+
+const auto Gather6dAxes3 = testing::Combine(
+ testing::ValuesIn(indices),
+ testing::ValuesIn(indicesShapes1),
+ testing::ValuesIn(axes3),
+ testing::ValuesIn(inputShapes6DAxes3),
+ testing::ValuesIn(netPrecisionsI32),
+ testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+ Gather6dAxes3,
+ GatherLayerTest,
+ Gather6dAxes3,
+ GatherLayerTest::getTestCaseName
+);
+
+const std::vector<std::vector<size_t>> inputShapesAxes2 = {
+ std::vector<size_t>{5, 6, 7, 8},
+ std::vector<size_t>{1, 6, 7, 8},
+ std::vector<size_t>{5, 1, 7, 8},
+ std::vector<size_t>{5, 6, 7, 1},
+ std::vector<size_t>{5, 6, 7, 8, 9},
+ std::vector<size_t>{1, 6, 7, 8, 9},
+ std::vector<size_t>{5, 1, 7, 8, 9},
+ std::vector<size_t>{5, 6, 7, 1, 9},
+ std::vector<size_t>{5, 6, 7, 8, 1},
+};
+const std::vector<std::vector<size_t>> inputShapes6DAxes2 = {
+ std::vector<size_t>{5, 6, 7, 8, 9, 10},
+ std::vector<size_t>{1, 1, 7, 8, 9, 10},
+ std::vector<size_t>{5, 1, 7, 1, 9, 10},
+ std::vector<size_t>{5, 6, 7, 1, 1, 10},
+ std::vector<size_t>{5, 6, 7, 8, 1, 1},
+ std::vector<size_t>{1, 6, 7, 1, 9, 10},
+ std::vector<size_t>{5, 1, 7, 8, 1, 10},
+ std::vector<size_t>{5, 6, 7, 1, 9, 1},
+ std::vector<size_t>{1, 6, 7, 8, 1, 10},
+ std::vector<size_t>{5, 1, 7, 8, 9, 1},
+ std::vector<size_t>{1, 6, 7, 8, 9, 1},
+};
+
+const std::vector<int> axes2 = {2};
+
+const auto GatherAxes2 = testing::Combine(
+ testing::ValuesIn(indices),
+ testing::ValuesIn(indicesShapes12),
+ testing::ValuesIn(axes2),
+ testing::ValuesIn(inputShapesAxes2),
+ testing::ValuesIn(netPrecisionsFP32),
+ testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+ GatherAxes2,
+ GatherLayerTest,
+ GatherAxes2,
+ GatherLayerTest::getTestCaseName
+);
+
+const auto Gather6dAxes2 = testing::Combine(
+ testing::ValuesIn(indices),
+ testing::ValuesIn(indicesShapes1),
+ testing::ValuesIn(axes2),
+ testing::ValuesIn(inputShapes6DAxes2),
+ testing::ValuesIn(netPrecisionsFP16),
+ testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+ Gather6dAxes2,
+ GatherLayerTest,
+ Gather6dAxes2,
+ GatherLayerTest::getTestCaseName
+);
+
+const std::vector<std::vector<size_t>> inputShapesAxes1 = {
+ std::vector<size_t>{5, 6, 7, 8},
+ std::vector<size_t>{1, 6, 7, 8},
+ std::vector<size_t>{5, 6, 1, 8},
+ std::vector<size_t>{5, 6, 7, 1},
+ std::vector<size_t>{5, 6, 7, 8, 9},
+ std::vector<size_t>{1, 6, 7, 8, 9},
+ std::vector<size_t>{5, 6, 1, 8, 9},
+ std::vector<size_t>{5, 6, 7, 1, 9},
+ std::vector<size_t>{5, 6, 7, 8, 1},
+};
+
+const std::vector<std::vector<size_t>> inputShapes6DAxes1 = {
+ std::vector<size_t>{5, 6, 7, 8, 9, 10},
+ std::vector<size_t>{1, 6, 1, 8, 9, 10},
+ std::vector<size_t>{5, 6, 1, 1, 9, 10},
+ std::vector<size_t>{5, 6, 7, 1, 1, 10},
+ std::vector<size_t>{5, 6, 7, 8, 1, 1},
+ std::vector<size_t>{1, 6, 7, 1, 9, 10},
+ std::vector<size_t>{5, 6, 1, 8, 1, 10},
+ std::vector<size_t>{5, 6, 1, 8, 9, 1},
+ std::vector<size_t>{1, 6, 7, 8, 1, 10},
+ std::vector<size_t>{1, 6, 7, 8, 9, 1},
+ std::vector<size_t>{5, 6, 7, 1, 9, 1},
+};
+
+const std::vector<int> axes1 = {1};
+
+const auto GatherAxes1 = testing::Combine(
+ testing::ValuesIn(indices),
+ testing::ValuesIn(indicesShapes12),
+ testing::ValuesIn(axes1),
+ testing::ValuesIn(inputShapesAxes1),
+ testing::ValuesIn(netPrecisionsI32),
+ testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+ GatherAxes1,
+ GatherLayerTest,
+ GatherAxes1,
+ GatherLayerTest::getTestCaseName
+);
+
+const auto Gather6dAxes1 = testing::Combine(
+ testing::ValuesIn(indices),
+ testing::ValuesIn(indicesShapes1),
+ testing::ValuesIn(axes1),
+ testing::ValuesIn(inputShapes6DAxes1),
+ testing::ValuesIn(netPrecisionsFP32),
+ testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+ Gather6dAxes1,
+ GatherLayerTest,
+ Gather6dAxes1,
+ GatherLayerTest::getTestCaseName
+);
+
+const std::vector<std::vector<size_t>> inputShapesAxes0 = {
+ std::vector<size_t>{5, 6, 7, 8},
+ std::vector<size_t>{5, 1, 7, 8},
+ std::vector<size_t>{5, 6, 1, 8},
+ std::vector<size_t>{5, 6, 7, 1},
+ std::vector<size_t>{5, 6, 7, 8, 9},
+ std::vector<size_t>{5, 1, 7, 8, 9},
+ std::vector<size_t>{5, 6, 1, 8, 9},
+ std::vector<size_t>{5, 6, 7, 1, 9},
+ std::vector<size_t>{5, 6, 7, 8, 1},
+};
+
+const std::vector<std::vector<size_t>> inputShapes6DAxes0 = {
+ std::vector<size_t>{5, 6, 7, 8, 9, 10},
+ std::vector<size_t>{5, 1, 1, 8, 9, 10},
+ std::vector<size_t>{5, 6, 1, 1, 9, 10},
+ std::vector<size_t>{5, 6, 7, 1, 1, 10},
+ std::vector<size_t>{5, 6, 7, 8, 1, 1},
+ std::vector<size_t>{5, 1, 7, 1, 9, 10},
+ std::vector<size_t>{5, 6, 1, 8, 1, 10},
+ std::vector<size_t>{5, 6, 1, 8, 9, 1},
+ std::vector<size_t>{5, 1, 7, 8, 1, 10},
+ std::vector<size_t>{5, 1, 7, 8, 9, 1},
+ std::vector<size_t>{5, 6, 7, 1, 9, 1},
+};
+
+const std::vector<int> axes0 = {0};
+
+const auto GatherAxes0 = testing::Combine(
+ testing::ValuesIn(indices),
+ testing::ValuesIn(indicesShapes12),
+ testing::ValuesIn(axes0),
+ testing::ValuesIn(inputShapesAxes0),
+ testing::ValuesIn(netPrecisionsFP32),
+ testing::Values(CommonTestUtils::DEVICE_GPU)
+);
+
+INSTANTIATE_TEST_CASE_P(
+ GatherAxes0,
+ GatherLayerTest,
+ GatherAxes0,
+ GatherLayerTest::getTestCaseName
+);
-const auto params = testing::Combine(
+const auto Gather6dAxes0 = testing::Combine(
testing::ValuesIn(indices),
- testing::ValuesIn(indicesShapes),
- testing::ValuesIn(axes),
- testing::ValuesIn(inputShapes),
- testing::ValuesIn(netPrecisions),
+ testing::ValuesIn(indicesShapes1),
+ testing::ValuesIn(axes0),
+ testing::ValuesIn(inputShapes6DAxes0),
+ testing::ValuesIn(netPrecisionsFP32),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(
- Gather,
+ Gather6dAxes0,
GatherLayerTest,
- params,
+ Gather6dAxes0,
GatherLayerTest::getTestCaseName
);
INSTANTIATE_TEST_CASE_P(
smoke_GPU_TestsGather, GatherTFTests,
::testing::Values(
- gatherTF_test_params{ "GPU", "FP32", { 1, 4 }, in0,{ 2, 2 }, dict2D, 0, { 1, 4, 2 }, ref_in0_a0_d22 },
- gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict, 0, { 2, 2, 2, 3 }, ref_in0_a0_d223 },
- gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict,-3, { 2, 2, 2, 3 }, ref_in0_a0_d223 },
+ gatherTF_test_params{ "GPU", "FP32", { 1, 4 }, in0,{ 2, 2 }, dict2D, 0, { 1, 4, 1, 2 }, ref_in0_a0_d22 },
+ gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict, 0, { 2, 2, 1, 2, 3 }, ref_in0_a0_d223 },
+ gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict,-3, { 2, 2, 1, 2, 3 }, ref_in0_a0_d223 },
- gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict, 0, { 2, 2, 2, 2 }, ref_in1_a0_d322 },
- gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict,-3, { 2, 2, 2, 2 }, ref_in1_a0_d322 },
- gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict, 1, { 2, 2, 2, 2 }, ref_in1_a1_d232 },
- gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict,-2, { 2, 2, 2, 2 }, ref_in1_a1_d232 },
+ gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict, 0, { 2, 2, 1, 2, 2 }, ref_in1_a0_d322 },
+ gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict,-3, { 2, 2, 1, 2, 2 }, ref_in1_a0_d322 },
+ gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict, 1, { 2, 2, 2, 2, 1 }, ref_in1_a1_d232 },
+ gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict,-2, { 2, 2, 2, 2, 1 }, ref_in1_a1_d232 },
- gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict, 2, { 2, 2, 2, 2 }, ref_in1_a2_d223 },
- gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict,-1, { 2, 2, 2, 2 }, ref_in1_a2_d223 },
- gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2 }, ref_in0_a2_d232 },
- gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict,-1, { 2, 3, 2, 2 }, ref_in0_a2_d232 },
- gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2 }, ref_in0_a2_d232 }
+ gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict, 2, { 2, 2, 2, 2, 1 }, ref_in1_a2_d223 },
+ gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict,-1, { 2, 2, 2, 2, 1 }, ref_in1_a2_d223 },
+ gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2, 1 }, ref_in0_a2_d232 },
+ gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict,-1, { 2, 3, 2, 2, 1 }, ref_in0_a2_d232 },
+ gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2, 1 }, ref_in0_a2_d232 }
));
along_b,
along_f,
along_x,
- along_y
+ along_y,
+ along_z,
+ along_w
};
/// @brief Constructs gather primitive.
enum class GatherAxis {
X,
Y,
+ Z,
+ W,
FEATURE,
BATCH,
};
static size_t GetGatherChannelIndex(const gather_params& params) {
Tensor::DataChannelName name = Tensor::DataChannelName::X;
+ size_t inputSize = params.inputs[0].GetDims().size();
+
switch (params.axis) {
case GatherAxis::X:
- return 3;
+ return inputSize - 1;
case GatherAxis::Y:
+ return inputSize - 2;
+ case GatherAxis::Z:
+ return inputSize - 3;
+ case GatherAxis::W:
return 2;
case GatherAxis::FEATURE:
return 1;
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
+ k.EnableInputLayout(DataLayout::bfzyx);
+ k.EnableOutputLayout(DataLayout::bfzyx);
+ k.EnableInputLayout(DataLayout::bfwzyx);
+ k.EnableOutputLayout(DataLayout::bfwzyx);
k.EnableTensorOffset();
k.EnableTensorPitches();
k.EnableBatching();
std::string order_str = order[0];
for (size_t i = 1; i < order.size(); i++)
order_str += ", " + order[i];
-
+
return order_str;
}
+static inline std::vector<std::string> GetOrder(size_t size) {
+ std::vector<std::string> idx_order;
+ if (size <= 4) {
+ idx_order = {"b", "f", "y", "x"};
+ } else if (size == 5) {
+ idx_order = {"b", "f", "z", "y", "x"};
+ } else if (size == 6) {
+ idx_order = {"b", "f", "w", "z", "y", "x"};
+ }
+
+ return idx_order;
+}
+
static std::string GetDictionaryIndexOrder(const gather_params& params, size_t axis) {
- std::vector<std::string> default_order = { "b", "f", "y", "x" };
+ std::vector<std::string> idx_order = GetOrder(params.output.GetDims().size());
+
const std::string input_axis_index_macro = "INPUT_AXIS_INDEX";
const std::string zeroVal = "0";
// Shift indices of Gather dictionary input related to output dims
for (size_t i = axis + 1; i < dictionary_dims_num; i++)
- default_order[i] = default_order[i + indices_dims_num - 1];
+ idx_order[i] = idx_order[i + indices_dims_num - 1];
- for (size_t i = dictionary_dims_num; i < default_order.size(); i++)
- default_order[i] = zeroVal;
+ for (size_t i = dictionary_dims_num; i < idx_order.size(); i++)
+ idx_order[i] = zeroVal;
+
+ // Fix size to inputs[0] dims size
+ for (size_t i = 0; i < params.output.GetDims().size() - params.inputs[0].GetDims().size(); i++)
+ idx_order.pop_back();
- default_order[axis] = input_axis_index_macro;
+ idx_order[axis] = input_axis_index_macro;
- return GetOrderString(default_order);
+ return GetOrderString(idx_order);
}
static std::string GetIndecesIdxOrder(const gather_params& params, size_t axis) {
- std::vector<std::string> default_order = { "b", "f", "y", "x" };
+ std::vector<std::string> idx_order = GetOrder(params.output.GetDims().size());
+
const std::string zero_val = "0";
size_t indices_dims_num = GetNonEmptyDimsNumber(params.inputs[1]);
// Shift indices of Gather indices input related to output dims
for (size_t i = 0; i < indices_dims_num; i++)
- default_order[i] = default_order[axis + i];
+ idx_order[i] = idx_order[axis + i];
- for (size_t i = indices_dims_num; i < default_order.size(); i++)
- default_order[i] = zero_val;
+ for (size_t i = indices_dims_num; i < idx_order.size(); i++)
+ idx_order[i] = zero_val;
- return GetOrderString(default_order);
+ // Fix size to inputs[1] dims size
+ for (size_t i = 0; i < params.output.GetDims().size() - params.inputs[1].GetDims().size(); i++)
+ idx_order.pop_back();
+
+ return GetOrderString(idx_order);
}
CommonDispatchData GatherKernelRef::SetDefault(const gather_params& params, const optional_params&) const {
CommonDispatchData runInfo;
const auto& output = params.output;
- std::vector<size_t> global = {output.Batch().v, output.Feature().v,output.X().v * output.Y().v};
- std::vector<size_t> local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
+ std::vector<size_t> global;
+ std::vector<size_t> local;
+
+ if (output.GetLayout() == DataLayout::bfyx) {
+ global = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v};
+ } else if (output.GetLayout() == DataLayout::bfzyx) {
+ global = {output.X().v, output.Y().v * output.Z().v, output.Feature().v * output.Batch().v};
+ } else {
+ global = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
+ }
+
+ local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
runInfo.gws0 = global[0];
runInfo.gws1 = global[1];
jit.AddConstant(MakeJitConstant("INDICES_INDEX_ORDER", GetIndecesIdxOrder(params, GetGatherChannelIndex(params))));
if (!params.fused_ops.empty()) {
- FusedOpsConfiguration conf = { "", {"b", "f", "y", "x"}, "val", params.inputs[0].GetDType() };
+ std::vector<std::string> idx_order = GetOrder(params.inputs[0].GetDims().size());
+
+ FusedOpsConfiguration conf = { "", idx_order, "val", params.inputs[0].GetDType() };
jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
}
#define INPUT_AXIS_INDEX (uint)indices[indices_idx]
#define GET_DICTIONARY_INDEX(idx_order) INPUT0_GET_INDEX(idx_order)
#define GET_INDICES_INDEX(idx_order) INPUT1_GET_INDEX(idx_order)
+#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX)(idx_order)
KERNEL(gather_ref)(const __global INPUT0_TYPE* dictionary,
const __global INPUT1_TYPE* indices,
#endif
)
{
- const uint b = get_global_id(0);
- const uint f = get_global_id(1);
- const uint yx = get_global_id(2);
- const uint y = yx / OUTPUT_SIZE_X;
- const uint x = yx % OUTPUT_SIZE_X;
+ #if OUTPUT_DIMS == 6
+ #define ORDER b,f,w,z,y,x
+ const uint x = (uint)get_global_id(0) % OUTPUT_SIZE_X;
+ const uint y = (uint)get_global_id(0) / OUTPUT_SIZE_X;
+ const uint z = (uint)get_global_id(1) % OUTPUT_SIZE_Z;
+ const uint w = (uint)get_global_id(1) / OUTPUT_SIZE_Z;
+ const uint f = (uint)get_global_id(2) % OUTPUT_FEATURE_NUM;
+ const uint b = (uint)get_global_id(2) / OUTPUT_FEATURE_NUM;
+ #elif OUTPUT_DIMS == 5
+ #define ORDER b,f,z,y,x
+ const uint x = (uint)get_global_id(0);
+ const uint y = (uint)get_global_id(1) % OUTPUT_SIZE_Y;
+ const uint z = (uint)get_global_id(1) / OUTPUT_SIZE_Y;
+ const uint f = (uint)get_global_id(2) % OUTPUT_FEATURE_NUM;
+ const uint b = (uint)get_global_id(2) / OUTPUT_FEATURE_NUM;
+ #elif OUTPUT_DIMS == 4
+ #define ORDER b,f,y,x
+ const uint x = (uint)get_global_id(0);
+ const uint y = (uint)get_global_id(1);
+ const uint f = (uint)get_global_id(2) % OUTPUT_FEATURE_NUM;
+ const uint b = (uint)get_global_id(2) / OUTPUT_FEATURE_NUM;
+ #endif
const uint indices_idx = GET_INDICES_INDEX(INDICES_INDEX_ORDER);
const uint dictionary_idx = GET_DICTIONARY_INDEX(DICTIONARY_INDEX_ORDER);
- const uint output_idx = OUTPUT_GET_INDEX(b, f, y, x);
+ const uint output_idx = GET_INDEX(OUTPUT,,ORDER);
INPUT0_TYPE val = dictionary[dictionary_idx];
auto desc = node.get_primitive();
auto input_layout = node.input(0).get_output_layout();
- auto input_format = input_layout.format;
-
auto output_shape = desc->output_shape;
+ auto output_format = input_layout.format;
+
+ int spatialNum = 0;
+ for (auto i : node.input(1).get_output_layout().size.raw)
+ spatialNum += (i > 1) ? 1 : 0;
+
+ // change output format if input indeces > 1
+ if (spatialNum == 2 && output_format == cldnn::format::bfzyx) {
+ output_format = cldnn::format::bfwzyx;
+ } else if (spatialNum == 2 && output_format == cldnn::format::bfyx) {
+ output_format = cldnn::format::bfzyx;
+ } else if (spatialNum == 3 && output_format == cldnn::format::bfyx) {
+ output_format = cldnn::format::bfwzyx;
+ }
auto output_type = input_layout.data_type;
if (node.has_fused_primitives()) {
output_type = node.get_fused_output_layout().data_type;
}
- return layout{output_type, input_format, output_shape};
+ return layout{output_type, output_format, output_shape};
}
std::string gather_inst::to_string(gather_node const& node) {
return kernel_selector::gather_axis::X;
case gather::along_y:
return kernel_selector::gather_axis::Y;
+ case gather::along_z:
+ return kernel_selector::gather_axis::Z;
+ case gather::along_w:
+ return kernel_selector::gather_axis::W;
case gather::along_f:
return kernel_selector::gather_axis::FEATURE;
case gather::along_b:
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);
+
+ implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
+ implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
+ implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw);
+
+ implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw);
+ implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw);
+ implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw);
}
} // namespace detail
#define CASE_GATHER_FP16_4 {5, 3, 2, 2}, {3, 1, 1, 1}, {5, 2, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GATHER_FP16_5 {2, 3, 1, 2}, {1, 3, 1, 1}, {2, 3, 3, 1}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_GATHER_5D_FP32_1 {2, 3, 1, 4, 1}, {4, 1, 1, 1}, {4, 3, 1, 4, 1}, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_GATHER_5D_FP32_2 {2, 3, 2, 2, 2}, {2, 1, 1, 1}, {2, 2, 2, 2, 2}, cldnn::gather::gather_axis::along_f, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_GATHER_5D_FP32_3 {5, 3, 2, 2, 2}, {3, 1, 1, 1}, {5, 3, 2, 3, 2}, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_GATHER_5D_FP32_4 {2, 3, 1, 4, 4}, {2, 1, 1, 1}, {2, 3, 1, 4, 2}, cldnn::gather::gather_axis::along_z, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+#define CASE_GATHER_5D_FP32_5 {3, 1, 5, 2, 1}, {2, 1, 1, 1}, {3, 1, 2, 2, 1}, cldnn::gather::gather_axis::along_x, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
+
+#define CASE_GATHER_5D_FP16_1 {3, 2, 1, 2, 1}, {2, 1, 1, 1}, {2, 2, 2, 2, 1}, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_GATHER_5D_FP16_2 {1, 3, 1, 2, 1}, {2, 1, 1, 1}, {1, 2, 1, 2, 1}, cldnn::gather::gather_axis::along_f, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_GATHER_5D_FP16_3 {2, 3, 1, 3, 3}, {1, 2, 1, 1}, {2, 3, 1, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_GATHER_5D_FP16_4 {3, 2, 2, 2, 2}, {2, 1, 1, 1}, {3, 2, 2, 2, 2}, cldnn::gather::gather_axis::along_z, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_GATHER_5D_FP16_5 {1, 1, 2, 1, 1}, {3, 1, 1, 1}, {1, 1, 3, 1, 1}, cldnn::gather::gather_axis::along_x, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+
class GatherPrimitiveFusingTest : public ::BaseFusingTest<gather_test_params> {
public:
void execute(gather_test_params& p) {
return p.dictionary_shape.spatial[0];
case cldnn::gather::gather_axis::along_y:
return p.dictionary_shape.spatial[1];
+ case cldnn::gather::gather_axis::along_z:
+ return p.dictionary_shape.spatial[2];
+ case cldnn::gather::gather_axis::along_w:
+ return p.dictionary_shape.spatial[3];
case cldnn::gather::gather_axis::along_f:
return p.dictionary_shape.feature[0];
case cldnn::gather::gather_axis::along_b:
gather_test_params{ CASE_GATHER_FP16_3, 2, 3 },
gather_test_params{ CASE_GATHER_FP16_4, 2, 3 },
gather_test_params{ CASE_GATHER_FP16_5, 2, 3 },
+
+ gather_test_params{ CASE_GATHER_5D_FP32_1, 2, 3 },
+ gather_test_params{ CASE_GATHER_5D_FP32_2, 2, 3 },
+ gather_test_params{ CASE_GATHER_5D_FP32_3, 2, 3 },
+ gather_test_params{ CASE_GATHER_5D_FP32_4, 2, 3 },
+ gather_test_params{ CASE_GATHER_5D_FP32_5, 2, 3 },
+
+ gather_test_params{ CASE_GATHER_5D_FP16_1, 2, 3 },
+ gather_test_params{ CASE_GATHER_5D_FP16_2, 2, 3 },
+ gather_test_params{ CASE_GATHER_5D_FP16_3, 2, 3 },
+ gather_test_params{ CASE_GATHER_5D_FP16_4, 2, 3 },
+ gather_test_params{ CASE_GATHER_5D_FP16_5, 2, 3 },
}), );
class gather_scale_activation : public GatherPrimitiveFusingTest {};
gather_test_params{ CASE_GATHER_FP16_3, 2, 4 },
gather_test_params{ CASE_GATHER_FP16_4, 2, 4 },
gather_test_params{ CASE_GATHER_FP16_5, 2, 4 },
+
+ gather_test_params{ CASE_GATHER_5D_FP32_1, 2, 4 },
+ gather_test_params{ CASE_GATHER_5D_FP32_2, 2, 4 },
+ gather_test_params{ CASE_GATHER_5D_FP32_3, 2, 4 },
+ gather_test_params{ CASE_GATHER_5D_FP32_4, 2, 4 },
+ gather_test_params{ CASE_GATHER_5D_FP32_5, 2, 4 },
+
+ gather_test_params{ CASE_GATHER_5D_FP16_1, 2, 4 },
+ gather_test_params{ CASE_GATHER_5D_FP16_2, 2, 4 },
+ gather_test_params{ CASE_GATHER_5D_FP16_3, 2, 4 },
+ gather_test_params{ CASE_GATHER_5D_FP16_4, 2, 4 },
+ gather_test_params{ CASE_GATHER_5D_FP16_5, 2, 4 },
}), );
/* ------------------------------------------------------------------------------------------------------------ */
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
- gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+ gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
- gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+ gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
- gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+ gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
);
network network(engine, topology);
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
- gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+ gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
- gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+ gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
- gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+ gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
);
network network(engine, topology);
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
- gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+ gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
);
network network(engine, topology);
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
- gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+ gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
- gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
+ gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);