From 6dc9725471e05fe12bd72406f97daca49a47a0c0 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Sun, 15 Jan 2023 07:25:00 +0000
Subject: [PATCH] [mlir][vector] Fix lowering of permutation maps for
 transfer_write op

The lowering of transfer_write permutation maps didn't match the op
definition:
https://github.com/llvm/llvm-project/blob/93ccccb00d9717b58ba93f0942a243ba6dac4ef6/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td#L1476

Fix the lowering and add a case to the integration test to enforce the
correct semantics.

Differential Revision: https://reviews.llvm.org/D141801
---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td   | 25 +++++++++++
 .../Dialect/Linalg/Transforms/Vectorization.cpp    |  2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp           | 34 +++++---------
 ...VectorTransferPermutationMapRewritePatterns.cpp | 22 ++++-----
 .../vector-transfer-to-vector-load-store.mlir      |  8 ++--
 .../Dialect/Vector/CPU/test-transfer-write.mlir    | 52 +++++++++++++++++++++-
 6 files changed, 104 insertions(+), 39 deletions(-)
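Note: the semantics being enforced pair the k-th result of a transfer_write's
permutation map with the k-th vector dimension, exactly as for transfer_read.
A minimal sketch of the enforced behavior, distilled from the integration-test
case added below (the %v/%A names are illustrative, not from the patch):

    // With permutation_map (d0, d1, d2) -> (d2, d0, d1), vector dim 0 runs
    // along memref dim d2, vector dim 1 along d0, and vector dim 2 along d1,
    // so element %v[i, j, k] is stored at %A[j, k, i]. For example, the test
    // writes %v[1, 2, 3] = 8.0 and expects it at %A[2, 3, 1].
    vector.transfer_write %v, %A[%c0, %c0, %c0]
      {permutation_map = affine_map<(d0, d1, d2) -> (d2, d0, d1)>,
       in_bounds = [true, true, true]}
      : vector<2x3x4xf32>, memref<4x4x4xf32>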
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index abf4f88..e0711a4 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1330,6 +1330,18 @@ def Vector_TransferReadOp :
         memref<?x?x?xf32>, vector<32x256xf32>
     }}}
 
+    // or equivalently (rewrite with vector.transpose)
+    %f0 = arith.constant 0.0f : f32
+    for %i0 = 0 to %0 {
+      affine.for %i1 = 0 to %1 step 256 {
+        affine.for %i2 = 0 to %2 step 32 {
+          %v0 = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
+            {permutation_map: (d0, d1, d2) -> (d1, d2)} :
+            memref<?x?x?xf32>, vector<256x32xf32>
+          %v = vector.transpose %v0, [1, 0] :
+            vector<256x32xf32> to vector<32x256xf32>
+    }}}
+
     // Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into
     // vector<128xf32>. The underlying implementation will require a 1-D vector
     // broadcast:
@@ -1485,6 +1497,19 @@ def Vector_TransferWriteOp :
         vector<16x32x64xf32>, memref<?x?x?x?xf32>
     }}}}
 
+    // or equivalently (rewrite with vector.transpose)
+    for %i0 = 0 to %0 {
+      affine.for %i1 = 0 to %1 step 32 {
+        affine.for %i2 = 0 to %2 step 64 {
+          affine.for %i3 = 0 to %3 step 16 {
+            %val = `ssa-value` : vector<16x32x64xf32>
+            %valt = vector.transpose %val, [1, 2, 0] :
+              vector<16x32x64xf32> to vector<32x64x16xf32>
+            vector.transfer_write %valt, %A[%i0, %i1, %i2, %i3]
+              {permutation_map: (d0, d1, d2, d3) -> (d1, d2, d3)} :
+              vector<32x64x16xf32>, memref<?x?x?x?xf32>
+    }}}}
+
     // write to a memref with vector element type.
     vector.transfer_write %4, %arg1[%c3, %c3]
       {permutation_map = (d0, d1)->(d0, d1)}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0bb31e9..2bc0c2a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -456,7 +456,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
 
   Operation *write;
   if (vectorType.getRank() > 0) {
-    AffineMap writeMap = reindexIndexingMap(opOperandMap);
+    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                rewriter.create<arith::ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(rewriter, value, vectorType.getShape());
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9339452..db2cd82 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3373,12 +3373,12 @@ void TransferReadOp::print(OpAsmPrinter &p) {
   p << " : " << getShapedType() << ", " << getVectorType();
 }
 
-/// Infers the mask type for a transfer read given its vector type and
-/// permutation map. The mask in a transfer read operation applies to the
-/// tensor/buffer reading part of it and its type should match the shape read
+/// Infers the mask type for a transfer op given its vector type and
+/// permutation map. The mask in a transfer op applies to the
+/// tensor/buffer part of it and its type should match the vector shape
 /// *before* any permutation or broadcasting.
-static VectorType inferTransferReadMaskType(VectorType vecType,
-                                            AffineMap permMap) {
+static VectorType inferTransferOpMaskType(VectorType vecType,
+                                          AffineMap permMap) {
   auto i1Type = IntegerType::get(permMap.getContext(), 1);
   AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
   assert(invPermMap && "Inversed permutation map couldn't be computed");
@@ -3436,7 +3436,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
         maskInfo.location, "does not support masks with vector element type");
     // Instead of adding the mask type as an op type, compute it based on the
     // vector type and the permutation map (to keep the type signature small).
-    auto maskType = inferTransferReadMaskType(vectorType, permMap);
+    auto maskType = inferTransferOpMaskType(vectorType, permMap);
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
   }
@@ -3455,7 +3455,7 @@ LogicalResult TransferReadOp::verify() {
   auto paddingType = getPadding().getType();
   auto permutationMap = getPermutationMap();
   VectorType inferredMaskType =
-      maskType ? inferTransferReadMaskType(vectorType, permutationMap)
+      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
                : VectorType();
   auto sourceElementType = shapedType.getElementType();
 
@@ -3495,7 +3495,7 @@ LogicalResult TransferReadOp::verify() {
 
 /// Returns the mask type expected by this operation. Mostly used for
 /// verification purposes. It requires the operation to be vectorized."
 Type TransferReadOp::getExpectedMaskType() {
-  return inferTransferReadMaskType(getVectorType(), getPermutationMap());
+  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
 }
 
 template <typename TransferOp>
@@ -3836,18 +3836,6 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, vector, dest, indices, permutationMap, inBounds);
 }
 
-/// Infers the mask type for a transfer write given its vector type and
-/// permutation map. The mask in a transfer read operation applies to the
-/// tensor/buffer writing part of it and its type should match the shape written
-/// *after* any permutation.
-static VectorType inferTransferWriteMaskType(VectorType vecType,
-                                             AffineMap permMap) {
-  auto i1Type = IntegerType::get(permMap.getContext(), 1);
-  SmallVector<int64_t, 8> maskShape =
-      compressUnusedDims(permMap).compose(vecType.getShape());
-  return VectorType::get(maskShape, i1Type);
-}
-
 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
   auto &builder = parser.getBuilder();
@@ -3892,7 +3880,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
     if (shapedType.getElementType().dyn_cast<VectorType>())
       return parser.emitError(
           maskInfo.location, "does not support masks with vector element type");
-    auto maskType = inferTransferWriteMaskType(vectorType, permMap);
+    auto maskType = inferTransferOpMaskType(vectorType, permMap);
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
   }
@@ -3919,7 +3907,7 @@ LogicalResult TransferWriteOp::verify() {
   VectorType maskType = getMaskType();
   auto permutationMap = getPermutationMap();
   VectorType inferredMaskType =
-      maskType ? inferTransferWriteMaskType(vectorType, permutationMap)
+      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
                : VectorType();
 
   if (llvm::size(getIndices()) != shapedType.getRank())
@@ -3945,7 +3933,7 @@ LogicalResult TransferWriteOp::verify() {
 /// Returns the mask type expected by this operation. Mostly used for
 /// verification purposes.
 Type TransferWriteOp::getExpectedMaskType() {
-  return inferTransferWriteMaskType(getVectorType(), getPermutationMap());
+  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
 }
 
 /// Fold:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
index d7ec87e..df8ba7b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -20,15 +20,16 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-/// Transpose a vector transfer op's `in_bounds` attribute according to given
-/// indices.
+/// Transpose a vector transfer op's `in_bounds` attribute by applying the
+/// inverse permutation specified by the given indices.
 static ArrayAttr
-transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
-                      const SmallVector<unsigned> &permutation) {
-  SmallVector<bool> newInBoundsValues;
+inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
+                             const SmallVector<unsigned> &permutation) {
+  SmallVector<bool> newInBoundsValues(permutation.size());
+  size_t index = 0;
   for (unsigned pos : permutation)
-    newInBoundsValues.push_back(
-        attr.getValue()[pos].cast<BoolAttr>().getValue());
+    newInBoundsValues[pos] =
+        attr.getValue()[index++].cast<BoolAttr>().getValue();
   return builder.getBoolArrayAttr(newInBoundsValues);
 }
 
@@ -85,7 +86,7 @@ struct TransferReadPermutationLowering
 
     // Transpose in_bounds attribute.
     ArrayAttr newInBoundsAttr =
-        op.getInBounds() ? transposeInBoundsAttr(
+        op.getInBounds() ? inverseTransposeInBoundsAttr(
                                rewriter, op.getInBounds().value(), permutation)
                          : ArrayAttr();
 
@@ -142,16 +143,17 @@ struct TransferWritePermutationLowering
     // E.g.:    (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
     //   comp = (d0, d1, d2) -> (d2, d0, d1)
     auto comp = compressUnusedDims(map);
+    AffineMap permutationMap = inversePermutation(comp);
     // Get positions of remaining result dims.
     SmallVector<unsigned> indices;
-    llvm::transform(comp.getResults(), std::back_inserter(indices),
+    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                     [](AffineExpr expr) {
                       return expr.dyn_cast<AffineDimExpr>().getPosition();
                     });
 
     // Transpose in_bounds attribute.
     ArrayAttr newInBoundsAttr =
-        op.getInBounds() ? transposeInBoundsAttr(
+        op.getInBounds() ? inverseTransposeInBoundsAttr(
                                rewriter, op.getInBounds().value(), permutation)
                          : ArrayAttr();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 0da64de..779b84f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -337,11 +337,11 @@ func.func @transfer_write_permutations(
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
 
-  // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<8x14x16x7xi1>
-  %mask0 = vector.splat %m : vector<8x14x16x7xi1>
+  // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<16x14x7x8xi1>
+  %mask0 = vector.splat %m : vector<16x14x7x8xi1>
   %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
-  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32>
+  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [true, false, true, false]} : vector<16x14x7x8xf32>, tensor<?x?x?x?xf32>
 
   vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
   // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
index 91b0758..cee90c7 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-opt %s -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
@@ -39,6 +39,34 @@ func.func @transfer_read_1d(%A : memref<?xf32>) -> vector<32xf32> {
   return %r : vector<32xf32>
 }
 
+func.func @transfer_write_inbounds_3d(%A : memref<4x4x4xf32>) {
+  %c0 = arith.constant 0: index
+  %f = arith.constant 0.0 : f32
+  %v0 = vector.splat %f : vector<2x3x4xf32>
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %f3 = arith.constant 3.0 : f32
+  %f4 = arith.constant 4.0 : f32
+  %f5 = arith.constant 5.0 : f32
+  %f6 = arith.constant 6.0 : f32
+  %f7 = arith.constant 7.0 : f32
+  %f8 = arith.constant 8.0 : f32
+
+  %v1 = vector.insert %f1, %v0[0, 0, 0] : f32 into vector<2x3x4xf32>
+  %v2 = vector.insert %f2, %v1[0, 0, 3] : f32 into vector<2x3x4xf32>
+  %v3 = vector.insert %f3, %v2[0, 2, 0] : f32 into vector<2x3x4xf32>
+  %v4 = vector.insert %f4, %v3[0, 2, 3] : f32 into vector<2x3x4xf32>
+  %v5 = vector.insert %f5, %v4[1, 0, 0] : f32 into vector<2x3x4xf32>
+  %v6 = vector.insert %f6, %v5[1, 0, 3] : f32 into vector<2x3x4xf32>
+  %v7 = vector.insert %f7, %v6[1, 2, 0] : f32 into vector<2x3x4xf32>
+  %v8 = vector.insert %f8, %v7[1, 2, 3] : f32 into vector<2x3x4xf32>
+
+  vector.transfer_write %v8, %A[%c0, %c0, %c0]
+    {permutation_map = affine_map<(d0, d1, d2) -> (d2, d0, d1)>,
+     in_bounds = [true, true, true]}
+    : vector<2x3x4xf32>, memref<4x4x4xf32>
+  return
+}
+
 func.func @entry() {
   %c0 = arith.constant 0: index
   %c1 = arith.constant 1: index
@@ -90,6 +118,24 @@ func.func @entry() {
   vector.print %6 : vector<32xf32>
 
   memref.dealloc %A : memref<?xf32>
+
+  // 3D case
+  %c4 = arith.constant 4: index
+  %A1 = memref.alloc() {alignment=64} : memref<4x4x4xf32>
+  scf.for %i = %c0 to %c4 step %c1 {
+    scf.for %j = %c0 to %c4 step %c1 {
+      scf.for %k = %c0 to %c4 step %c1 {
+        %f = arith.constant 0.0: f32
+        memref.store %f, %A1[%i, %j, %k] : memref<4x4x4xf32>
+      }
+    }
+  }
+  call @transfer_write_inbounds_3d(%A1) : (memref<4x4x4xf32>) -> ()
+  %f = arith.constant 0.0: f32
+  %r = vector.transfer_read %A1[%c0, %c0, %c0], %f
+    : memref<4x4x4xf32>, vector<4x4x4xf32>
+  vector.print %r : vector<4x4x4xf32>
+
   return
 }
 
@@ -100,3 +146,7 @@ func.func @entry() {
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13 )
+
+// 3D case.
+// CHECK: ( ( ( 1, 5, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 2, 6, 0, 0 ) ), ( ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ) ),
+// CHECK-SAME: ( ( 3, 7, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 4, 8, 0, 0 ) ), ( ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ) ) )
-- 
2.7.4
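Note: after this fix, TransferWritePermutationLowering rewrites a permuted
write into a vector.transpose by the inverse permutation followed by an
identity-map write. A minimal sketch of the rewrite, distilled from the
updated CHECK lines in vector-transfer-to-vector-load-store.mlir above (the
%t/%r/%mask names are illustrative, not from the patch):

    // Input: masked write of vector<7x14x8x16xf32> with
    //   permutation_map = (d0, d1, d2, d3) -> (d2, d1, d3, d0)
    //   and in_bounds = [true, false, false, true].
    // Output: transpose by the inverse permutation [3, 1, 0, 2], then an
    // identity-map write; in_bounds is permuted the same way.
    %t = vector.transpose %v1, [3, 1, 0, 2]
        : vector<7x14x8x16xf32> to vector<16x14x7x8xf32>
    %r = vector.transfer_write %t, %arg1[%c0, %c0, %c0, %c0], %mask
        {in_bounds = [true, false, true, false]}
        : vector<16x14x7x8xf32>, tensor<?x?x?x?xf32>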