From f44c76d6e919641655615d62ea8b432175571a0b Mon Sep 17 00:00:00 2001
From: thomasraoux <thomasraoux@google.com>
Date: Mon, 3 May 2021 10:47:02 -0700
Subject: [PATCH] [mlir][vector] Extend vector transfer unrolling to support
 permutations and broadcast

Differential Revision: https://reviews.llvm.org/D101637
---
 mlir/lib/Dialect/Vector/VectorTransforms.cpp       | 82 +++++++------
 .../Dialect/Vector/vector-transfer-unroll.mlir     | 93 +++++++++++++++++++++-
 2 files changed, 119 insertions(+), 56 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 7501797..2c8b3379 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -516,10 +516,12 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
 
 /// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
 /// calls 'fn' with linear index and indices for each slice.
-static void generateTransferOpSlices(
-    Type shapedElementType, VectorType vectorType, TupleType tupleType,
-    ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
-    OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
+static void
+generateTransferOpSlices(Type shapedElementType, VectorType vectorType,
+                         TupleType tupleType, ArrayRef<int64_t> sizes,
+                         ArrayRef<int64_t> strides, ArrayRef<Value> indices,
+                         AffineMap permutationMap, OpBuilder &builder,
+                         function_ref<void(unsigned, ArrayRef<Value>)> fn) {
   // Compute strides w.r.t. slice counts in each dimension.
   auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
   assert(maybeDimSliceCounts.hasValue());
@@ -527,7 +529,6 @@ static void generateTransferOpSlices(
   auto sliceStrides = computeStrides(sliceDimCounts);
 
   int64_t numSlices = tupleType.size();
-  unsigned numSliceIndices = indices.size();
   // Compute 'indexOffset' at which to update 'indices', which is equal
   // to the memref rank (indices.size) minus the effective 'vectorRank'.
   // The effective 'vectorRank' is equal to the rank of the vector type
@@ -545,57 +546,38 @@ static void generateTransferOpSlices(
     assert(vectorRank >= sourceVectorElementType.getRank());
     vectorRank -= sourceVectorElementType.getRank();
   }
-  unsigned indexOffset = numSliceIndices - vectorRank;
-
+  auto isBroadcast = [](AffineExpr expr) {
+    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+      return constExpr.getValue() == 0;
+    return false;
+  };
   auto *ctx = builder.getContext();
   for (unsigned i = 0; i < numSlices; ++i) {
     auto vectorOffsets = delinearize(sliceStrides, i);
     auto elementOffsets =
         computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
     // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
-    SmallVector<Value, 4> sliceIndices(numSliceIndices);
-    for (unsigned j = 0; j < numSliceIndices; ++j) {
-      if (j < indexOffset) {
-        sliceIndices[j] = indices[j];
-      } else {
-        auto expr = getAffineDimExpr(0, ctx) +
-                    getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
-        auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
-        sliceIndices[j] = builder.create<AffineApplyOp>(
-            indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
-      }
+    SmallVector<Value, 4> sliceIndices(indices.begin(), indices.end());
+    for (auto dim : llvm::enumerate(permutationMap.getResults())) {
+      if (isBroadcast(dim.value()))
+        continue;
+      unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
+      auto expr = getAffineDimExpr(0, ctx) +
+                  getAffineConstantExpr(elementOffsets[dim.index()], ctx);
+      auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+      sliceIndices[pos] = builder.create<AffineApplyOp>(
+          indices[pos].getLoc(), map, ArrayRef<Value>(indices[pos]));
     }
     // Call 'fn' to generate slice 'i' at 'sliceIndices'.
     fn(i, sliceIndices);
   }
 }
 
-/// Returns true if 'map' is a suffix of an identity affine map, false
-/// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-static bool isIdentitySuffix(AffineMap map) {
-  if (map.getNumDims() < map.getNumResults())
-    return false;
-  ArrayRef<AffineExpr> results = map.getResults();
-  Optional<int> lastPos;
-  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
-    auto expr = results[i].dyn_cast<AffineDimExpr>();
-    if (!expr)
-      return false;
-    int currPos = static_cast<int>(expr.getPosition());
-    if (lastPos.hasValue() && currPos != lastPos.getValue() + 1)
-      return false;
-    lastPos = currPos;
-  }
-  return true;
-}
-
 /// Unroll transfer_read ops to the given shape and create an aggregate with all
 /// the chunks.
 static Value unrollTransferReadOp(vector::TransferReadOp readOp,
                                   ArrayRef<int64_t> targetShape,
                                   OpBuilder &builder) {
-  if (!isIdentitySuffix(readOp.permutation_map()))
-    return nullptr;
   if (readOp.mask())
     return nullptr;
   auto sourceVectorType = readOp.getVectorType();
@@ -623,7 +605,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
         readOp.in_bounds() ? *readOp.in_bounds() : ArrayAttr());
   };
   generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
-                           targetShape, strides, indices, builder, createSlice);
+                           targetShape, strides, indices,
+                           readOp.permutation_map(), builder, createSlice);
 
   // Create tuple of slice transfer read operations.
   Value tupleOp =
@@ -641,8 +624,6 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
                                     ArrayRef<int64_t> targetShape,
                                     SmallVector<Value, 1> &result) {
   auto writeOp = cast<vector::TransferWriteOp>(op);
-  if (!isIdentitySuffix(writeOp.permutation_map()))
-    return failure();
   if (writeOp.mask())
     return failure();
   VectorType sourceVectorType = writeOp.getVectorType();
@@ -671,7 +652,8 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
     resultTensor = write->getResult(0);
   };
   generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
-                           targetShape, strides, indices, builder, createSlice);
+                           targetShape, strides, indices,
+                           writeOp.permutation_map(), builder, createSlice);
   if (resultTensor)
     result.push_back(resultTensor);
   return success();
 }
@@ -729,11 +711,6 @@ public:
     if (readOp.mask())
       return failure();
 
-    // TODO: Support splitting TransferReadOp with non-identity permutation
-    // maps. Repurpose code from MaterializeVectors transformation.
-    if (!isIdentitySuffix(readOp.permutation_map()))
-      return failure();
-
     // Return unless there is only one user, and it is an ExtractSlicesOp.
     Value readResult = readOp.getResult();
     if (!readResult.hasOneUse())
@@ -778,11 +755,6 @@ public:
     if (writeOp.mask())
       return failure();
 
-    // TODO: Support splitting TransferWriteOp with non-identity permutation
-    // maps. Repurpose code from MaterializeVectors transformation.
-    if (!isIdentitySuffix(writeOp.permutation_map()))
-      return failure();
-
     // Fail to match unless this is writing a vector resulting from an
     // InsertSlicesOp.
     auto insertSlicesOp =
@@ -821,8 +793,8 @@ public:
       resultTensor = write->getResult(0);
     };
     generateTransferOpSlices(shapedElementType, resultVectorType,
-                             sourceTupleType, sizes, strides, indices, rewriter,
-                             createSlice);
+                             sourceTupleType, sizes, strides, indices,
+                             writeOp.permutation_map(), rewriter, createSlice);
     if (resultTensor)
       rewriter.replaceOp(writeOp, ArrayRef<Value>(resultTensor));
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index d63809c..0929031 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @transfer_read_unroll
 // CHECK-DAG: %[[C2:.*]] = constant 2 : index
@@ -120,3 +120,94 @@ func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4
   %r = vector.transfer_write %0, %arg1[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
   return %r: tensor<4x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_unroll_permutation
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x6xf32>
+// CHECK-NEXT: return %[[VEC]] : vector<4x6xf32>
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+func @transfer_read_unroll_permutation(%arg0 : memref<6x4xf32>) -> vector<4x6xf32> {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
+  return %0 : vector<4x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_unroll_broadcast
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<6x4xf32>
+// CHECK-NEXT: return %[[VEC]] : vector<6x4xf32>
+#map0 = affine_map<(d0, d1) -> (0, d1)>
+func @transfer_read_unroll_broadcast(%arg0 : memref<6x4xf32>) -> vector<6x4xf32> {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<6x4xf32>
+  return %0 : vector<6x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_unroll_broadcast_permutation
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x6xf32>
+// CHECK-NEXT: return %[[VEC]] : vector<4x6xf32>
+#map0 = affine_map<(d0, d1) -> (0, d0)>
+func @transfer_read_unroll_broadcast_permutation(%arg0 : memref<6x4xf32>) -> vector<4x6xf32> {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
+  return %0 : vector<4x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_unroll_different_rank
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<6x4xf32>
+// CHECK-NEXT: return %[[VEC]] : vector<6x4xf32>
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+func @transfer_read_unroll_different_rank(%arg0 : memref<?x?x?xf32>) -> vector<6x4xf32> {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cf0 {permutation_map = #map0} : memref<?x?x?xf32>, vector<6x4xf32>
+  return %0 : vector<6x4xf32>
+}
-- 
2.7.4
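
Illustration (appended after the patch, not part of it): a minimal sketch of the
kind of op the unrolling patterns now accept, mirroring the
@transfer_read_unroll_permutation test above. The names #transpose, @sketch, and
%mem are hypothetical, and the rewrite assumes the 2x2 target shape used by
-test-vector-transfer-unrolling-patterns.

// A transposing transfer_read: result dimension 0 is indexed by memref
// dimension 1 and vice versa.
#transpose = affine_map<(d0, d1) -> (d1, d0)>
func @sketch(%mem : memref<6x4xf32>) -> vector<4x6xf32> {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32
  %v = vector.transfer_read %mem[%c0, %c0], %f0 {permutation_map = #transpose}
      : memref<6x4xf32>, vector<4x6xf32>
  return %v : vector<4x6xf32>
}

After unrolling into 2x2 slices, the slice at result offsets [r0, r1] reads from
%mem[r1, r0]: generateTransferOpSlices routes each element offset through the
position of the corresponding AffineDimExpr in the permutation map before
emitting the affine.apply that shifts the base index. A broadcast dimension,
i.e. a constant-0 map result as in (d0, d1) -> (0, d1), is skipped by
isBroadcast, so its base index is reused unchanged for every slice.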