From: Wen-Heng (Jack) Chung Date: Mon, 22 Jun 2020 14:52:12 +0000 (-0500) Subject: Fix a corner case in vector.shape_cast when the trailing dimensions are of size 1. X-Git-Tag: llvmorg-12-init~2277 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6bb4fc93c2fd7f63c7ed430928d1b85bfd4b3d79;p=platform%2Fupstream%2Fllvm.git Fix a corner case in vector.shape_cast when the trailing dimensions are of size 1. Differential Revision: https://reviews.llvm.org/D82304 --- diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 019f5fd..5d3a916 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1633,6 +1633,14 @@ static bool isValidShapeCast(ArrayRef a, ArrayRef b) { if (dimA != dimB) break; ++i; + + // Handle the case when trailing dimensions are of size 1. + // Include them into the contiguous sequence. + auto isOne = [](int64_t v) { return v == 1; }; + if (i < rankA && llvm::all_of(a.slice(i), isOne)) + i = rankA; + if (j < rankB && llvm::all_of(b.slice(j), isOne)) + j = rankB; } return i == rankA && j == rankB; diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 02ee4dd..4ea7286 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -266,8 +266,10 @@ func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) { // CHECK-LABEL: @shape_cast func @shape_cast(%arg0 : vector<5x1x3x2xf32>, - %arg1 : tuple, vector<3x4x2xf32>>) - -> (vector<15x2xf32>, tuple, vector<12x2xf32>>) { + %arg1 : tuple, vector<3x4x2xf32>>, + %arg2 : vector<8x1xf32>, + %arg3 : vector<16x1x1xf32>) + -> (vector<15x2xf32>, tuple, vector<12x2xf32>>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>) { // CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<15x2xf32> %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32> @@ -276,7 +278,16 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>, %1 = vector.shape_cast %arg1 : tuple, vector<3x4x2xf32>> to tuple, vector<12x2xf32>> - return %0, %1 : vector<15x2xf32>, tuple, vector<12x2xf32>> + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x1xf32> to vector<8xf32> + %2 = vector.shape_cast %arg2 : vector<8x1xf32> to vector<8xf32> + + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<16x1x1xf32> to vector<16xf32> + %3 = vector.shape_cast %arg3 : vector<16x1x1xf32> to vector<16xf32> + + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<16x1x1xf32> to vector<16x1xf32> + %4 = vector.shape_cast %arg3 : vector<16x1x1xf32> to vector<16x1xf32> + + return %0, %1, %2, %3, %4 : vector<15x2xf32>, tuple, vector<12x2xf32>>, vector<8xf32>, vector<16xf32>, vector<16x1xf32> } // CHECK-LABEL: @vector_fma