From 7630520ae3c5af3f3536a81740cf316d3a21304e Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 5 Feb 2021 17:48:09 -0500 Subject: [PATCH] [mlir][vector] Add pattern to shuffle bitcast ops These patterns move vector.bitcast ops to be before insert ops or after extract ops where suitable. With them, bitcast will happen on smaller vectors and there are more chances to share extract/insert ops. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D96040 --- mlir/include/mlir/Dialect/Vector/VectorOps.h | 8 + mlir/lib/Dialect/Vector/VectorTransforms.cpp | 245 ++++++++++++++++++++++ mlir/test/Dialect/Vector/vector-transforms.mlir | 89 ++++++++ mlir/test/lib/Transforms/TestVectorTransforms.cpp | 1 + 4 files changed, 343 insertions(+) diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h index b01aa11..afc55c1 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -44,6 +44,14 @@ void populateVectorToVectorTransformationPatterns( void populateCastAwayVectorLeadingOneDimPatterns( OwningRewritePatternList &patterns, MLIRContext *context); +/// Collect a set of patterns that bubble up/down bitcast ops. +/// +/// These patterns move vector.bitcast ops to be before insert ops or after +/// extract ops where suitable. With them, bitcast will happen on smaller +/// vectors and there are more chances to share extract/insert ops. +void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns, + MLIRContext *context); + /// Collect a set of vector slices transformation patterns: /// ExtractSlicesOpLowering, InsertSlicesOpLowering /// Useful for clients that want to express all vector "slices" diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 6a8ee49b..765eb08 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2787,6 +2787,244 @@ struct CastAwayTransferWriteLeadingOneDim } }; +// Returns the values in `arrayAttr` as an integer vector. +static SmallVector getIntValueVector(ArrayAttr arrayAttr) { + return llvm::to_vector<4>( + llvm::map_range(arrayAttr.getAsRange(), + [](IntegerAttr attr) { return attr.getInt(); })); +}; + +// Shuffles vector.bitcast op after vector.extract op. +// +// This transforms IR like: +// %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16> +// %1 = vector.extract %0[3] : vector<8xf16> +// Into: +// %0 = vector.extract %src[1] : vector<4xf32> +// %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16> +// %2 = vector.extract %1[1] : vector<2xf16> +struct BubbleDownVectorBitCastForExtract + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + // Only support extracting scalars for now. + if (extractOp.getVectorType().getRank() != 1) + return failure(); + + auto castOp = extractOp.vector().getDefiningOp(); + if (!castOp) + return failure(); + + VectorType castSrcType = castOp.getSourceVectorType(); + VectorType castDstType = castOp.getResultVectorType(); + assert(castSrcType.getRank() == castDstType.getRank()); + + // Fail to match if we only have one element in the cast op source. + // This is to avoid infinite loop given that this pattern can generate + // such cases. + if (castSrcType.getNumElements() == 1) + return failure(); + + // Only support casting to a larger number of elements or now. + // E.g., vector<4xf32> -> vector<8xf16>. + if (castSrcType.getNumElements() > castDstType.getNumElements()) + return failure(); + + unsigned expandRatio = + castDstType.getNumElements() / castSrcType.getNumElements(); + + auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t { + return (*attr.getAsValueRange().begin()).getZExtValue(); + }; + + uint64_t index = getFirstIntValue(extractOp.position()); + + // Get the single scalar (as a vector) in the source value that packs the + // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> + VectorType oneScalarType = + VectorType::get({1}, castSrcType.getElementType()); + Value packedValue = rewriter.create( + extractOp.getLoc(), oneScalarType, castOp.source(), + rewriter.getI64ArrayAttr(index / expandRatio)); + + // Cast it to a vector with the desired scalar's type. + // E.g. f32 -> vector<2xf16> + VectorType packedType = + VectorType::get({expandRatio}, castDstType.getElementType()); + Value castedValue = rewriter.create( + extractOp.getLoc(), packedType, packedValue); + + // Finally extract the desired scalar. + rewriter.replaceOpWithNewOp( + extractOp, extractOp.getType(), castedValue, + rewriter.getI64ArrayAttr(index % expandRatio)); + + return success(); + } +}; + +// Shuffles vector.bitcast op after vector.extract_strided_slice op. +// +// This transforms IR like: +// %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> +// %0 = vector.extract_strided_slice %cast { +// offsets = [4], sizes = [4], strides = [1] +// } : vector<8xf16> to vector<4xf16> +// Into: +// %0 = vector.extract_strided_slice %src { +// offsets = [2], sizes = [2], strides = [1] +// } : vector<4xf32> to vector<2xf32> +// %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> +struct BubbleDownBitCastForStridedSliceExtract + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, + PatternRewriter &rewriter) const override { + auto castOp = extractOp.vector().getDefiningOp(); + if (!castOp) + return failure(); + + VectorType castSrcType = castOp.getSourceVectorType(); + VectorType castDstType = castOp.getResultVectorType(); + assert(castSrcType.getRank() == castDstType.getRank()); + + int64_t castSrcLastDim = castSrcType.getShape().back(); + int64_t castDstLastDim = castDstType.getShape().back(); + // Require casting to more elements for now; other cases to be implemented. + if (castSrcLastDim > castDstLastDim) + return failure(); + + // Only accept all one strides for now. + if (llvm::any_of(extractOp.strides().getAsValueRange(), + [](const APInt &val) { return !val.isOneValue(); })) + return failure(); + + unsigned rank = extractOp.getVectorType().getRank(); + assert(castDstLastDim % castSrcLastDim == 0); + int64_t expandRatio = castDstLastDim / castSrcLastDim; + + // If we have a less number of offsets than the rank, then implicitly we + // are selecting the full range for the last bitcasted dimension; other + // dimensions aren't affected. Otherwise, we need to scale down the last + // dimension's offset given we are extracting from less elements now. + ArrayAttr newOffsets = extractOp.offsets(); + if (newOffsets.size() == rank) { + SmallVector offsets = getIntValueVector(newOffsets); + if (offsets.back() % expandRatio != 0) + return failure(); + offsets.back() = offsets.back() / expandRatio; + newOffsets = rewriter.getI64ArrayAttr(offsets); + } + + // Similarly for sizes. + ArrayAttr newSizes = extractOp.sizes(); + if (newSizes.size() == rank) { + SmallVector sizes = getIntValueVector(newSizes); + if (sizes.back() % expandRatio != 0) + return failure(); + sizes.back() = sizes.back() / expandRatio; + newSizes = rewriter.getI64ArrayAttr(sizes); + } + + SmallVector dims = + llvm::to_vector<4>(extractOp.getType().cast().getShape()); + dims.back() = dims.back() / expandRatio; + VectorType newExtractType = + VectorType::get(dims, castSrcType.getElementType()); + + auto newExtractOp = rewriter.create( + extractOp.getLoc(), newExtractType, castOp.source(), newOffsets, + newSizes, extractOp.strides()); + + rewriter.replaceOpWithNewOp( + extractOp, extractOp.getType(), newExtractOp); + + return success(); + } +}; + +// Shuffles vector.bitcast op before vector.insert_strided_slice op. +// +// This transforms IR like: +// %0 = vector.insert_strided_slice %src, %dst { +// offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> +// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> +// Into: +// %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32> +// %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32> +// %2 = vector.insert_strided_slice %src, %dst { +// offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> +struct BubbleUpBitCastForStridedSliceInsert + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, + PatternRewriter &rewriter) const override { + VectorType castSrcType = bitcastOp.getSourceVectorType(); + VectorType castDstType = bitcastOp.getResultVectorType(); + assert(castSrcType.getRank() == castDstType.getRank()); + + int64_t castSrcLastDim = castSrcType.getShape().back(); + int64_t castDstLastDim = castDstType.getShape().back(); + // Require casting to less elements for now; other cases to be implemented. + if (castSrcLastDim < castDstLastDim) + return failure(); + + assert(castSrcLastDim % castDstLastDim == 0); + int64_t shrinkRatio = castSrcLastDim / castDstLastDim; + + auto insertOp = + bitcastOp.source().getDefiningOp(); + if (!insertOp) + return failure(); + + // Only accept all one strides for now. + if (llvm::any_of(insertOp.strides().getAsValueRange(), + [](const APInt &val) { return !val.isOneValue(); })) + return failure(); + + unsigned rank = insertOp.getSourceVectorType().getRank(); + // Require insert op to have the same rank for the source and destination + // vector; other cases to be implemented. + if (rank != insertOp.getDestVectorType().getRank()) + return failure(); + + ArrayAttr newOffsets = insertOp.offsets(); + assert(newOffsets.size() == rank); + SmallVector offsets = getIntValueVector(newOffsets); + if (offsets.back() % shrinkRatio != 0) + return failure(); + offsets.back() = offsets.back() / shrinkRatio; + newOffsets = rewriter.getI64ArrayAttr(offsets); + + SmallVector srcDims = + llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); + srcDims.back() = srcDims.back() / shrinkRatio; + VectorType newCastSrcType = + VectorType::get(srcDims, castDstType.getElementType()); + + auto newCastSrcOp = rewriter.create( + bitcastOp.getLoc(), newCastSrcType, insertOp.source()); + + SmallVector dstDims = + llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); + dstDims.back() = dstDims.back() / shrinkRatio; + VectorType newCastDstType = + VectorType::get(dstDims, castDstType.getElementType()); + + auto newCastDstOp = rewriter.create( + bitcastOp.getLoc(), newCastDstType, insertOp.dest()); + + rewriter.replaceOpWithNewOp( + bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, + insertOp.strides()); + + return success(); + } +}; + // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp). // TODO: Add this as DRR pattern. void mlir::vector::populateVectorToVectorTransformationPatterns( @@ -2811,6 +3049,13 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( context); } +void mlir::vector::populateBubbleVectorBitCastOpPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { + patterns.insert(context); +} + void mlir::vector::populateVectorSlicesLoweringPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir index 831d2eb..20c9188 100644 --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -671,3 +671,92 @@ func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x1xf16>, memref<1x1x1x1xf16> return } + +// CHECK-LABEL: func @bubble_down_bitcast_in_extract +// CHECK-SAME: %[[SRC:.+]]: vector<4xf32> +func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) { + %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16> + // CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : vector<4xf32> + // CHECK: %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<1xf32> to vector<2xf16> + // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : vector<2xf16> + %1 = vector.extract %0[3] : vector<8xf16> + // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : vector<4xf32> + // CHECK: %[[CAST2:.+]] = vector.bitcast %[[EXTRACT3]] : vector<1xf32> to vector<2xf16> + // CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : vector<2xf16> + %2 = vector.extract %0[4] : vector<8xf16> + // CHECK: return %[[EXTRACT2]], %[[EXTRACT4]] + return %1, %2: f16, f16 +} + +// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract +// CHECK-SAME: %[[SRC:.+]]: vector<4xf32> +func @bubble_down_bitcast_in_strided_slice_extract(%arg0: vector<4xf32>) -> vector<4xf16> { + // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> + // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2xf32> to vector<4xf16> + %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> + %0 = vector.extract_strided_slice %cast {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16> + // CHECK: return %[[CAST]] + return %0: vector<4xf16> +} + +// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim +// CHECK-SAME: %[[SRC:.+]]: vector<4x2xf32> +func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim(%arg0: vector<4x2xf32>) -> vector<2x4xf16> { + // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [1], sizes = [2], strides = [1]} : vector<4x2xf32> to vector<2x2xf32> + // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2x2xf32> to vector<2x4xf16> + %cast = vector.bitcast %arg0: vector<4x2xf32> to vector<4x4xf16> + %0 = vector.extract_strided_slice %cast {offsets = [1], sizes = [2], strides = [1]} : vector<4x4xf16> to vector<2x4xf16> + // CHECK: return %[[CAST]] + return %0: vector<2x4xf16> +} + +// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_offset +func @bubble_down_bitcast_in_strided_slice_extract_odd_offset(%arg0: vector<4xf32>) -> vector<4xf16> { + // CHECK: vector.bitcast + // CHECK-NEXT: vector.extract_strided_slice + %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> + %0 = vector.extract_strided_slice %cast {offsets = [3], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16> + return %0: vector<4xf16> +} + +// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_size +func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4xf32>) -> vector<3xf16> { + // CHECK: vector.bitcast + // CHECK-NEXT: vector.extract_strided_slice + %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> + %0 = vector.extract_strided_slice %cast {offsets = [0], sizes = [3], strides = [1]} : vector<8xf16> to vector<3xf16> + return %0: vector<3xf16> +} + +// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert +// CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>) +func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> { + // CHECK-DAG: %[[CAST_SRC1:.+]] = vector.bitcast %[[SRC1]] : vector<4xf16> to vector<2xf32> + // CHECK-DAG: %[[CAST_SRC2:.+]] = vector.bitcast %[[SRC2]] : vector<4xf16> to vector<2xf32> + // CHECK-DAG: %[[CAST_DST:.+]] = vector.bitcast %[[DST]] : vector<8xf16> to vector<4xf32> + // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST_SRC1]], %[[CAST_DST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> + // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[CAST_SRC2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> + %0 = vector.insert_strided_slice %src1, %dst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> + %1 = vector.insert_strided_slice %src2, %0 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16> + %cast = vector.bitcast %1: vector<8xf16> to vector<4xf32> + // CHECK: return %[[INSERT2]] + return %cast: vector<4xf32> +} + +// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_offset +func @bubble_up_bitcast_in_strided_slice_insert_odd_offset(%dst: vector<8xf16>, %src: vector<4xf16>) -> vector<4xf32> { + // CHECK: vector.insert_strided_slice + // CHECK-NEXT: vector.bitcast + %0 = vector.insert_strided_slice %src, %dst {offsets = [3], strides = [1]} : vector<4xf16> into vector<8xf16> + %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32> + return %cast: vector<4xf32> +} + +// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_different_rank +func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector<16x4x8xf16>, %src: vector<2x4xf16>) -> vector<16x4x4xf32> { + // CHECK: vector.insert_strided_slice + // CHECK-NEXT: vector.bitcast + %0 = vector.insert_strided_slice %src, %dst {offsets = [0, 0, 2], strides = [1, 1]} : vector<2x4xf16> into vector<16x4x8xf16> + %cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32> + return %cast: vector<16x4x4xf32> +} diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 109a9fc..61b1717 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -45,6 +45,7 @@ struct TestVectorToVectorConversion } populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); + populateBubbleVectorBitCastOpPatterns(patterns, ctx); populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } -- 2.7.4