[mlir][vector] Remove usage of shapecast to remove unit dim
authorThomas Raoux <thomasraoux@google.com>
Fri, 19 Nov 2021 00:09:49 +0000 (16:09 -0800)
committerThomas Raoux <thomasraoux@google.com>
Fri, 19 Nov 2021 18:25:21 +0000 (10:25 -0800)
Instead of using shape_cast op in the pattern removing leading unit
dimensions we use extract/broadcast ops. This is part of the effort to
restrict ShapeCastOp fuirther in the future and only allow them to
convert to or from 1D vector.

This also adds extra canonicalization to fill the gaps in simplifying
broadcast/extract ops.

Differential Revision: https://reviews.llvm.org/D114205

mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir [deleted file]
mlir/test/Dialect/Vector/vector-transforms.mlir

index 3ca8fa0..4b67b39 100644 (file)
@@ -1125,9 +1125,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
                        b.getI64ArrayAttr(extractPos));
     return extractOp.getResult();
   }
-  // TODO: In case the rank of the broadcast source is greater than the rank of
-  // the extract result this can be combined into a new broadcast op. This needs
-  // to be added a canonicalization pattern if needed.
   return Value();
 }
 
@@ -1208,12 +1205,63 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
 
 namespace {
 
+// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
+class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    Operation *defOp = extractOp.vector().getDefiningOp();
+    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+      return failure();
+    Value source = defOp->getOperand(0);
+    if (extractOp.getType() == source.getType())
+      return failure();
+    auto getRank = [](Type type) {
+      return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
+    };
+    unsigned broadcasrSrcRank = getRank(source.getType());
+    unsigned extractResultRank = getRank(extractOp.getType());
+    // We only consider the case where the rank of the source is smaller than
+    // the rank of the extract dst. The other cases are handled in the folding
+    // patterns.
+    if (extractResultRank <= broadcasrSrcRank)
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        extractOp, extractOp.getType(), source);
+    return success();
+  }
+};
+
+// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
+class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Return if 'extractStridedSliceOp' operand is not defined by a
+    // ConstantOp.
+    auto constantOp = extractOp.vector().getDefiningOp<arith::ConstantOp>();
+    if (!constantOp)
+      return failure();
+    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
+    if (!dense)
+      return failure();
+    Attribute newAttr = dense.getSplatValue<Attribute>();
+    if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
+      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  // ExtractToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
+  results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -1555,10 +1603,31 @@ static LogicalResult verify(InsertOp op) {
   return success();
 }
 
+namespace {
+
+// If insertOp is only inserting unit dimensions it can be transformed to a
+// broadcast.
+class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
+    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
+                           srcVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        insertOp, insertOp.getDestVectorType(), insertOp.source());
+    return success();
+  }
+};
+
+} // namespace
+
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  // InsertToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
+  results.add<InsertToBroadcast, BroadcastFolder>(context);
 }
 
 // Eliminates insert operations that produce values identical to their source
index cf40e4f..8a77f17 100644 (file)
@@ -2943,6 +2943,11 @@ static VectorType trimLeadingOneDims(VectorType oldType) {
   return VectorType::get(newShape, oldType.getElementType());
 }
 
+/// Return a smallVector of size `rank` containing all zeros.
+static SmallVector<int64_t> splatZero(int64_t rank) {
+  return SmallVector<int64_t>(rank, 0);
+}
+
 // Casts away leading one dimensions in vector.extract_strided_slice's vector
 // input by inserting vector.shape_cast.
 struct CastAwayExtractStridedSliceLeadingOneDim
@@ -2969,8 +2974,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
 
     Location loc = extractOp.getLoc();
 
-    Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
-        loc, newSrcType, extractOp.vector());
+    Value newSrcVector = rewriter.create<vector::ExtractOp>(
+        loc, extractOp.vector(), splatZero(dropCount));
 
     // The offsets/sizes/strides attribute can have a less number of elements
     // than the input vector's rank: it is meant for the leading dimensions.
@@ -2984,7 +2989,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
         loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
 
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
                                                      newExtractOp);
 
     return success();
@@ -3004,17 +3009,18 @@ struct CastAwayInsertStridedSliceLeadingOneDim
     VectorType oldDstType = insertOp.getDestVectorType();
     VectorType newDstType = trimLeadingOneDims(oldDstType);
 
-    if (newSrcType.getRank() == oldSrcType.getRank() &&
-        newDstType.getRank() == oldDstType.getRank())
+    int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
+    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
+    if (srcDropCount == 0 && dstDropCount == 0)
       return failure();
 
     // Trim leading one dimensions from both operands.
     Location loc = insertOp.getLoc();
 
-    Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
-        loc, newSrcType, insertOp.source());
-    Value newDstVector =
-        rewriter.create<vector::ShapeCastOp>(loc, newDstType, insertOp.dest());
+    Value newSrcVector = rewriter.create<vector::ExtractOp>(
+        loc, insertOp.source(), splatZero(srcDropCount));
+    Value newDstVector = rewriter.create<vector::ExtractOp>(
+        loc, insertOp.dest(), splatZero(dstDropCount));
 
     auto newOffsets = rewriter.getArrayAttr(
         insertOp.offsets().getValue().take_back(newDstType.getRank()));
@@ -3024,7 +3030,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
         loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
 
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                      newInsertOp);
 
     return success();
@@ -3068,7 +3074,7 @@ struct CastAwayTransferReadLeadingOneDim
     auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), newType, read.source(), read.indices(), newMap,
         read.padding(), inBounds);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
 
     return success();
   }
@@ -3092,9 +3098,9 @@ struct CastAwayTransferWriteLeadingOneDim
 
     VectorType oldType = write.getVectorType();
     VectorType newType = trimLeadingOneDims(oldType);
-
     if (newType == oldType)
       return failure();
+    int64_t dropDim = oldType.getRank() - newType.getRank();
 
     AffineMap oldMap = write.permutation_map();
     ArrayRef<AffineExpr> newResults =
@@ -3108,8 +3114,8 @@ struct CastAwayTransferWriteLeadingOneDim
       inBounds = rewriter.getArrayAttr(
           write.in_boundsAttr().getValue().take_back(newType.getRank()));
 
-    auto newVector = rewriter.create<vector::ShapeCastOp>(
-        write.getLoc(), newType, write.vector());
+    auto newVector = rewriter.create<vector::ExtractOp>(
+        write.getLoc(), write.vector(), splatZero(dropDim));
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         write, newVector, write.source(), write.indices(), newMap, inBounds);
 
@@ -3117,35 +3123,6 @@ struct CastAwayTransferWriteLeadingOneDim
   }
 };
 
-template <typename BroadCastType>
-struct CastAwayBroadcastLeadingOneDim : public OpRewritePattern<BroadCastType> {
-  using OpRewritePattern<BroadCastType>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BroadCastType broadcastOp,
-                                PatternRewriter &rewriter) const override {
-    VectorType dstType =
-        broadcastOp.getResult().getType().template dyn_cast<VectorType>();
-    if (!dstType)
-      return failure();
-    VectorType newDstType = trimLeadingOneDims(dstType);
-    if (newDstType == dstType)
-      return failure();
-    Location loc = broadcastOp.getLoc();
-    Value source = broadcastOp->getOperand(0);
-    VectorType srcVecType = source.getType().template dyn_cast<VectorType>();
-    if (srcVecType)
-      srcVecType = trimLeadingOneDims(srcVecType);
-    if (srcVecType && srcVecType != source.getType()) {
-      source = rewriter.create<vector::ShapeCastOp>(loc, srcVecType, source);
-    }
-    Value newBroadcastOp =
-        rewriter.create<BroadCastType>(loc, newDstType, source);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcastOp, dstType,
-                                                     newBroadcastOp);
-    return success();
-  }
-};
-
 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
 public:
   CastAwayElementwiseLeadingOneDim(MLIRContext *context)
@@ -3161,14 +3138,12 @@ public:
     VectorType newVecType = trimLeadingOneDims(vecType);
     if (newVecType == vecType)
       return failure();
-
+    int64_t dropDim = vecType.getRank() - newVecType.getRank();
     SmallVector<Value, 4> newOperands;
     for (Value operand : op->getOperands()) {
       if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
-        auto newType =
-            VectorType::get(newVecType.getShape(), opVecType.getElementType());
-        newOperands.push_back(rewriter.create<vector::ShapeCastOp>(
-            op->getLoc(), newType, operand));
+        newOperands.push_back(rewriter.create<vector::ExtractOp>(
+            op->getLoc(), operand, splatZero(dropDim)));
       } else {
         newOperands.push_back(operand);
       }
@@ -3178,69 +3153,12 @@ public:
     state.addOperands(newOperands);
     state.addTypes(newVecType);
     Operation *newOp = rewriter.createOperation(state);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, vecType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                      newOp->getResult(0));
     return success();
   }
 };
 
-// If extractOp is only removing unit dimensions it can be transformed to a
-// shapecast.
-class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern<ExtractOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
-    if (!dstVecType || extractOp.getVectorType().getNumElements() !=
-                           dstVecType.getNumElements())
-      return failure();
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
-                                             extractOp.vector());
-    return success();
-  }
-};
-
-// If insertOp is only inserting unit dimensions it can be transformed to a
-// shapecast.
-class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
-public:
-  using OpRewritePattern<InsertOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InsertOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
-    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
-                           srcVecType.getNumElements())
-      return failure();
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(
-        insertOp, insertOp.getDestVectorType(), insertOp.source());
-    return success();
-  }
-};
-
-// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
-// the degenerated case where the broadcast only adds dimensions of size 1 it
-// can be replaced by a ShapeCastOp. This canonicalization checks if the total
-// number of elements is the same before and after the broadcast to detect if
-// the only change in the vector type are new dimensions of size 1.
-class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
-public:
-  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
-    if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
-                           srcVecType.getNumElements())
-      return failure();
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(
-        broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
-    return success();
-  }
-};
-
 // Returns the values in `arrayAttr` as an integer vector.
 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(
@@ -3722,13 +3640,11 @@ void mlir::vector::populateShapeCastFoldingPatterns(
 
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<
-      BroadcastToShapeCast, CastAwayExtractStridedSliceLeadingOneDim,
-      CastAwayInsertStridedSliceLeadingOneDim,
-      CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim,
-      CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
-      CastAwayBroadcastLeadingOneDim<SplatOp>, CastAwayElementwiseLeadingOneDim,
-      ExtractToShapeCast, InsertToShapeCast>(patterns.getContext());
+  patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
+               CastAwayInsertStridedSliceLeadingOneDim,
+               CastAwayTransferReadLeadingOneDim,
+               CastAwayTransferWriteLeadingOneDim,
+               CastAwayElementwiseLeadingOneDim>(patterns.getContext());
   populateShapeCastFoldingPatterns(patterns);
 }
 
index 3d60745..3557ceb 100644 (file)
@@ -496,13 +496,10 @@ func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
 
 // -----
 
-// Negative test for extract_op folding when the type of broadcast source
-// doesn't match the type of vector.extract.
-// CHECK-LABEL: fold_extract_broadcast_negative
-//       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<1x2x4xf32>
-//       CHECK:   %[[R:.*]] = vector.extract %[[B]][0, 1] : vector<1x2x4xf32>
-//       CHECK:   return %[[R]] : vector<4xf32>
-func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
+// CHECK-LABEL: fold_extract_broadcast
+//       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
+//       CHECK:   return %[[B]] : vector<4xf32>
+func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
   %r = vector.extract %b[0, 1] : vector<1x2x4xf32>
   return %r : vector<4xf32>
@@ -1058,3 +1055,31 @@ func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
   vector<16x4xf16> to vector<2x4xf16>
   return %1 : vector<2x4xf16>
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_broadcast
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+//       CHECK:   %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1x4xf32>
+//       CHECK:   %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+  %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+  %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
+  %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+  return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_constant
+//       CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
+//       CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
+//       CHECK: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
+func @extract_constant() -> (vector<7xf32>, i32) {
+  %cst = arith.constant dense<2.000000e+00> : vector<29x7xf32>
+  %cst_1 = arith.constant dense<1> : vector<4x37x9xi32>
+  %0 = vector.extract %cst[2] : vector<29x7xf32>
+  %1 = vector.extract %cst_1[1, 4, 5] : vector<4x37x9xi32>
+  return %0, %1 : vector<7xf32>, i32
+}
diff --git a/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir b/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir
deleted file mode 100644 (file)
index 5dd44d3..0000000
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s
-
-// CHECK-LABEL: broadcast_to_shapecast
-//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
-//  CHECK-NEXT:   return %[[C]] : vector<1x4x4xf16>
-func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
-  %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
-  return %0 : vector<1x4x4xf16>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_extract_to_shapecast
-//  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
-//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
-//       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
-  %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
-  %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
-  %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
-  return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
-}
index 50ed05a..96c3521 100644 (file)
@@ -421,21 +421,21 @@ func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
 
 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
 func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
-  // CHECK:     %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+  // CHECK:     %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
   // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
-  // CHECK:     %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
+  // CHECK:     %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
   // CHECK: return %[[RET]]
   return %0: vector<1x1x8xf16>
 }
 
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
 func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
-  // CHECK:    %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
-  // CHECK:    %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+  // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8xf16>
+  // CHECK:    %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
   // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
-  // CHECK:    %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
+  // CHECK:    %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
   // CHECK: return %[[RET]]
   return %0: vector<1x8x8xf16>
 }
@@ -443,9 +443,10 @@ func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %a
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
 //  CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
 func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] : vector<1x1xf16> to vector<1x1x1xf16>
+  // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1x1xf16>
+  // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16>
   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
-  // CHECK: return %[[CAST]]
+  // CHECK: return %[[B]]
   return %0: vector<1x1x1xf16>
 }
 
@@ -456,7 +457,7 @@ func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> v
   // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
   %f0 = arith.constant 0. : f16
   // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
   %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
   // CHECK: return %[[CAST]]
   return %0: vector<1x4xf16>
@@ -466,7 +467,7 @@ func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> v
 func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
   %c0 = arith.constant 0 : index
   %f0 = arith.constant 0. : f16
-  // CHECK: vector.shape_cast %{{.+}} : vector<1xf16> to vector<1x1xf16>
+  // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16>
   %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16>
   return %0: vector<1x1xf16>
 }
@@ -475,7 +476,7 @@ func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1
 func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<1x4xf16>
   // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
 
   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
@@ -485,54 +486,35 @@ func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %ar
 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
 func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
   %c0 = arith.constant 0 : index
-  // CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
+  // CHECK: vector.extract %{{.+}}[0] : vector<1x1xf16>
   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16>
   return
 }
 
-// CHECK-LABEL: func @cast_away_broadcast_leading_one_dims
-func @cast_away_broadcast_leading_one_dims(
-  %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
-  (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) {
-  // CHECK:  vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
-  %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
-  // CHECK:  vector.broadcast %{{.*}} : f32 to vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32>
-  %1 = vector.broadcast %arg1 : f32 to vector<1x1x4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
-  // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<3x4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<3x4xf32> to vector<1x3x4xf32>
-  %2 = vector.broadcast %arg2 : vector<1x4xf32> to vector<1x3x4xf32>
-  // CHECK:  splat %{{.*}} : vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32>
-  %3 = splat %arg1 : vector<1x1x4xf32>
-  return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>
-}
-
 // CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
 func @cast_away_elementwise_leading_one_dims(
   %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
   %arg3: vector<1x4xf32>, %arg4: i1) ->
   (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
+  // CHECK:  vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
+  // CHECK:  vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
   // CHECK:  arith.addf %{{.*}}, %{{.*}} : vector<8xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
+  // CHECK:  vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
   %0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
   // CHECK:  arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<4xi1> to vector<1x4xi1>
+  // CHECK:  vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1>
   %1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
   // CHECK:  select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+  // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
   %2 = select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK:  vector.extract %{{.*}}[0] : vector<1x4xf32>
   // CHECK:  select %arg4, %12, %{{.*}} : vector<4xf32>
-  // CHECK:  vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+  // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
   %3 = select %arg4, %arg3, %arg2 : vector<1x4xf32>
   return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
 }