[mlir][vector] Add pattern to shuffle bitcast ops
authorLei Zhang <antiagainst@google.com>
Fri, 5 Feb 2021 22:48:09 +0000 (17:48 -0500)
committerLei Zhang <antiagainst@google.com>
Fri, 5 Feb 2021 22:52:49 +0000 (17:52 -0500)
These patterns move vector.bitcast ops to be before
insert ops or after extract ops where suitable.
With them, bitcast will happen on smaller vectors
and there are more chances to share extract/insert
ops.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D96040

mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transforms.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp

index b01aa11..afc55c1 100644 (file)
@@ -44,6 +44,14 @@ void populateVectorToVectorTransformationPatterns(
 void populateCastAwayVectorLeadingOneDimPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
 
+/// Collect a set of patterns that bubble up/down bitcast ops.
+///
+/// These patterns move vector.bitcast ops to be before insert ops or after
+/// extract ops where suitable. With them, bitcast will happen on smaller
+/// vectors and there are more chances to share extract/insert ops.
+void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
+                                           MLIRContext *context);
+
 /// Collect a set of vector slices transformation patterns:
 ///    ExtractSlicesOpLowering, InsertSlicesOpLowering
 /// Useful for clients that want to express all vector "slices"
index 6a8ee49..765eb08 100644 (file)
@@ -2787,6 +2787,244 @@ struct CastAwayTransferWriteLeadingOneDim
   }
 };
 
+// Returns the values in `arrayAttr` as an integer vector.
+static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
+                      [](IntegerAttr attr) { return attr.getInt(); }));
+};
+
+// Shuffles vector.bitcast op after vector.extract op.
+//
+// This transforms IR like:
+//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
+//   %1 = vector.extract %0[3] : vector<8xf16>
+// Into:
+//   %0 = vector.extract %src[1] : vector<4xf32>
+//   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
+//   %2 = vector.extract %1[1] : vector<2xf16>
+struct BubbleDownVectorBitCastForExtract
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Only support extracting scalars for now.
+    if (extractOp.getVectorType().getRank() != 1)
+      return failure();
+
+    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
+    if (!castOp)
+      return failure();
+
+    VectorType castSrcType = castOp.getSourceVectorType();
+    VectorType castDstType = castOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    // Fail to match if we only have one element in the cast op source.
+    // This is to avoid infinite loop given that this pattern can generate
+    // such cases.
+    if (castSrcType.getNumElements() == 1)
+      return failure();
+
+    // Only support casting to a larger number of elements or now.
+    // E.g., vector<4xf32> -> vector<8xf16>.
+    if (castSrcType.getNumElements() > castDstType.getNumElements())
+      return failure();
+
+    unsigned expandRatio =
+        castDstType.getNumElements() / castSrcType.getNumElements();
+
+    auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
+      return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
+    };
+
+    uint64_t index = getFirstIntValue(extractOp.position());
+
+    // Get the single scalar (as a vector) in the source value that packs the
+    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
+    VectorType oneScalarType =
+        VectorType::get({1}, castSrcType.getElementType());
+    Value packedValue = rewriter.create<vector::ExtractOp>(
+        extractOp.getLoc(), oneScalarType, castOp.source(),
+        rewriter.getI64ArrayAttr(index / expandRatio));
+
+    // Cast it to a vector with the desired scalar's type.
+    // E.g. f32 -> vector<2xf16>
+    VectorType packedType =
+        VectorType::get({expandRatio}, castDstType.getElementType());
+    Value castedValue = rewriter.create<vector::BitCastOp>(
+        extractOp.getLoc(), packedType, packedValue);
+
+    // Finally extract the desired scalar.
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+        extractOp, extractOp.getType(), castedValue,
+        rewriter.getI64ArrayAttr(index % expandRatio));
+
+    return success();
+  }
+};
+
+// Shuffles vector.bitcast op after vector.extract_strided_slice op.
+//
+// This transforms IR like:
+//    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
+//     %0 = vector.extract_strided_slice %cast {
+//            offsets = [4], sizes = [4], strides = [1]
+//          } : vector<8xf16> to vector<4xf16>
+// Into:
+//   %0 = vector.extract_strided_slice %src {
+//          offsets = [2], sizes = [2], strides = [1]
+//        } : vector<4xf32> to vector<2xf32>
+//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
+struct BubbleDownBitCastForStridedSliceExtract
+    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
+    if (!castOp)
+      return failure();
+
+    VectorType castSrcType = castOp.getSourceVectorType();
+    VectorType castDstType = castOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    // Require casting to more elements for now; other cases to be implemented.
+    if (castSrcLastDim > castDstLastDim)
+      return failure();
+
+    // Only accept all one strides for now.
+    if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
+                     [](const APInt &val) { return !val.isOneValue(); }))
+      return failure();
+
+    unsigned rank = extractOp.getVectorType().getRank();
+    assert(castDstLastDim % castSrcLastDim == 0);
+    int64_t expandRatio = castDstLastDim / castSrcLastDim;
+
+    // If we have a less number of offsets than the rank, then implicitly we
+    // are selecting the full range for the last bitcasted dimension; other
+    // dimensions aren't affected. Otherwise, we need to scale down the last
+    // dimension's offset given we are extracting from less elements now.
+    ArrayAttr newOffsets = extractOp.offsets();
+    if (newOffsets.size() == rank) {
+      SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
+      if (offsets.back() % expandRatio != 0)
+        return failure();
+      offsets.back() = offsets.back() / expandRatio;
+      newOffsets = rewriter.getI64ArrayAttr(offsets);
+    }
+
+    // Similarly for sizes.
+    ArrayAttr newSizes = extractOp.sizes();
+    if (newSizes.size() == rank) {
+      SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
+      if (sizes.back() % expandRatio != 0)
+        return failure();
+      sizes.back() = sizes.back() / expandRatio;
+      newSizes = rewriter.getI64ArrayAttr(sizes);
+    }
+
+    SmallVector<int64_t, 4> dims =
+        llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
+    dims.back() = dims.back() / expandRatio;
+    VectorType newExtractType =
+        VectorType::get(dims, castSrcType.getElementType());
+
+    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
+        newSizes, extractOp.strides());
+
+    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
+        extractOp, extractOp.getType(), newExtractOp);
+
+    return success();
+  }
+};
+
+// Shuffles vector.bitcast op before vector.insert_strided_slice op.
+//
+// This transforms IR like:
+//   %0 = vector.insert_strided_slice %src, %dst {
+//          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
+//   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
+// Into:
+//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
+//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
+//   %2 = vector.insert_strided_slice %src, %dst {
+//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+struct BubbleUpBitCastForStridedSliceInsert
+    : public OpRewritePattern<vector::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType castSrcType = bitcastOp.getSourceVectorType();
+    VectorType castDstType = bitcastOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    // Require casting to less elements for now; other cases to be implemented.
+    if (castSrcLastDim < castDstLastDim)
+      return failure();
+
+    assert(castSrcLastDim % castDstLastDim == 0);
+    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
+
+    auto insertOp =
+        bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
+    if (!insertOp)
+      return failure();
+
+    // Only accept all one strides for now.
+    if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
+                     [](const APInt &val) { return !val.isOneValue(); }))
+      return failure();
+
+    unsigned rank = insertOp.getSourceVectorType().getRank();
+    // Require insert op to have the same rank for the source and destination
+    // vector; other cases to be implemented.
+    if (rank != insertOp.getDestVectorType().getRank())
+      return failure();
+
+    ArrayAttr newOffsets = insertOp.offsets();
+    assert(newOffsets.size() == rank);
+    SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
+    if (offsets.back() % shrinkRatio != 0)
+      return failure();
+    offsets.back() = offsets.back() / shrinkRatio;
+    newOffsets = rewriter.getI64ArrayAttr(offsets);
+
+    SmallVector<int64_t, 4> srcDims =
+        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
+    srcDims.back() = srcDims.back() / shrinkRatio;
+    VectorType newCastSrcType =
+        VectorType::get(srcDims, castDstType.getElementType());
+
+    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
+        bitcastOp.getLoc(), newCastSrcType, insertOp.source());
+
+    SmallVector<int64_t, 4> dstDims =
+        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
+    dstDims.back() = dstDims.back() / shrinkRatio;
+    VectorType newCastDstType =
+        VectorType::get(dstDims, castDstType.getElementType());
+
+    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
+        bitcastOp.getLoc(), newCastDstType, insertOp.dest());
+
+    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
+        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
+        insertOp.strides());
+
+    return success();
+  }
+};
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
@@ -2811,6 +3049,13 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
       context);
 }
 
+void mlir::vector::populateBubbleVectorBitCastOpPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<BubbleDownVectorBitCastForExtract,
+                  BubbleDownBitCastForStridedSliceExtract,
+                  BubbleUpBitCastForStridedSliceInsert>(context);
+}
+
 void mlir::vector::populateVectorSlicesLoweringPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
   patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
index 831d2eb..20c9188 100644 (file)
@@ -671,3 +671,92 @@ func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x
   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x1xf16>, memref<1x1x1x1xf16>
   return
 }
+
+// CHECK-LABEL: func @bubble_down_bitcast_in_extract
+//  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
+func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
+  %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
+  // CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : vector<4xf32>
+  // CHECK:    %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<1xf32> to vector<2xf16>
+  // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : vector<2xf16>
+  %1 = vector.extract %0[3] : vector<8xf16>
+  // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : vector<4xf32>
+  // CHECK:    %[[CAST2:.+]] = vector.bitcast %[[EXTRACT3]] : vector<1xf32> to vector<2xf16>
+  // CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : vector<2xf16>
+  %2 = vector.extract %0[4] : vector<8xf16>
+  // CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]
+  return %1, %2: f16, f16
+}
+
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract
+//  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
+func @bubble_down_bitcast_in_strided_slice_extract(%arg0: vector<4xf32>) -> vector<4xf16> {
+  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2xf32> to vector<4xf16>
+  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<4xf16>
+}
+
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim
+//  CHECK-SAME: %[[SRC:.+]]: vector<4x2xf32>
+func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim(%arg0: vector<4x2xf32>) -> vector<2x4xf16> {
+  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [1], sizes = [2], strides = [1]} : vector<4x2xf32> to vector<2x2xf32>
+  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2x2xf32> to vector<2x4xf16>
+  %cast = vector.bitcast %arg0: vector<4x2xf32> to vector<4x4xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [1], sizes = [2], strides = [1]} : vector<4x4xf16> to vector<2x4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<2x4xf16>
+}
+
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_offset
+func @bubble_down_bitcast_in_strided_slice_extract_odd_offset(%arg0: vector<4xf32>) -> vector<4xf16> {
+  // CHECK: vector.bitcast
+  // CHECK-NEXT: vector.extract_strided_slice
+  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [3], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+  return %0: vector<4xf16>
+}
+
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_size
+func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4xf32>) -> vector<3xf16> {
+  // CHECK: vector.bitcast
+  // CHECK-NEXT: vector.extract_strided_slice
+  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [0], sizes = [3], strides = [1]} : vector<8xf16> to vector<3xf16>
+  return %0: vector<3xf16>
+}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert
+//  CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>)
+func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> {
+  // CHECK-DAG: %[[CAST_SRC1:.+]] = vector.bitcast %[[SRC1]] : vector<4xf16> to vector<2xf32>
+  // CHECK-DAG: %[[CAST_SRC2:.+]] = vector.bitcast %[[SRC2]] : vector<4xf16> to vector<2xf32>
+  // CHECK-DAG: %[[CAST_DST:.+]] = vector.bitcast %[[DST]] : vector<8xf16> to vector<4xf32>
+  // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST_SRC1]], %[[CAST_DST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+  // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[CAST_SRC2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+  %0 = vector.insert_strided_slice %src1, %dst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
+  %1 = vector.insert_strided_slice %src2, %0   {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
+  %cast = vector.bitcast %1: vector<8xf16> to vector<4xf32>
+  // CHECK: return %[[INSERT2]]
+  return %cast: vector<4xf32>
+}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_offset
+func @bubble_up_bitcast_in_strided_slice_insert_odd_offset(%dst: vector<8xf16>, %src: vector<4xf16>) -> vector<4xf32> {
+  // CHECK: vector.insert_strided_slice
+  // CHECK-NEXT: vector.bitcast
+  %0 = vector.insert_strided_slice %src, %dst {offsets = [3], strides = [1]} : vector<4xf16> into vector<8xf16>
+  %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
+  return %cast: vector<4xf32>
+}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_different_rank
+func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector<16x4x8xf16>, %src: vector<2x4xf16>) -> vector<16x4x4xf32> {
+  // CHECK: vector.insert_strided_slice
+  // CHECK-NEXT: vector.bitcast
+  %0 = vector.insert_strided_slice %src, %dst {offsets = [0, 0, 2], strides = [1, 1]} : vector<2x4xf16> into vector<16x4x8xf16>
+  %cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32>
+  return %cast: vector<16x4x4xf32>
+}
index 109a9fc..61b1717 100644 (file)
@@ -45,6 +45,7 @@ struct TestVectorToVectorConversion
     }
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
+    populateBubbleVectorBitCastOpPatterns(patterns, ctx);
     populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }