Revert "[WIP] Add support for MMA conversion for 1-D vector.transfer followed by...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 1 Dec 2022 10:56:33 +0000 (02:56 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 1 Dec 2022 10:57:03 +0000 (02:57 -0800)
This reverts commit 7db25f78db807da171f23bcbaff258c5677901d1.

This was mistakently stacked below (and committed) along with an NFC change.

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

index 1da8dc4..2734b5f 100644 (file)
@@ -150,26 +150,6 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
   return true;
 }
 
-// Return true if the transfer op can be converted to a MMA matrix load.
-static bool transferReadFollowedByBroadcastSupportsMMAMatrixType(
-    vector::TransferReadOp readOp, bool useNvGpu) {
-  bool res = true;
-  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
-      readOp.getVectorType().getRank() != 1)
-    res = false;
-  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
-    res = false;
-  AffineMap map = readOp.getPermutationMap();
-  OpBuilder b(readOp.getContext());
-
-  if (res && !useNvGpu)
-    return map.isMinorIdentity() || isTransposeMatrixLoadMap(b, map);
-
-  llvm::errs() << "RES transferReadFollowedByBroadcastSupportsMMAMatrixType: "
-               << res << "\n";
-  return res;
-}
-
 // Return true if the transfer op can be converted to a MMA matrix store.
 static bool
 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
@@ -199,27 +179,8 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
 
 /// Return true if this is a broadcast from scalar to a 2D vector.
 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
-  auto res = broadcastOp.getVectorType().getRank() == 2 &&
-             broadcastOp.getSource().getType().isa<FloatType>();
-  llvm::errs() << "RES broadcastSupportsMMAMatrixType: " << res << "\n";
-  return res;
-}
-
-/// Return true if this is a broadcast from 1-D to a 2-D vector and the 1-D
-/// vector comes from a TransferReadOp.
-static bool
-broadcastFromTransferReadSupportsMMAMatrixType(vector::BroadcastOp broadcastOp,
-                                               bool useNvGpu) {
-  auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
-  auto sourceVectorType =
-      broadcastOp.getSource().getType().dyn_cast<VectorType>();
-  auto res =
-      !broadcastSupportsMMAMatrixType(broadcastOp) && sourceVectorType &&
-      sourceVectorType.getRank() == 1 &&
-      transferReadFollowedByBroadcastSupportsMMAMatrixType(readOp, useNvGpu);
-  llvm::errs() << "RES broadcastFromTransferReadSupportsMMAMatrixType: " << res
-               << "\n";
-  return res;
+  return broadcastOp.getVectorType().getRank() == 2 &&
+         broadcastOp.getSource().getType().isa<FloatType>();
 }
 
 /// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -258,10 +219,9 @@ extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
   if (failed(contractOp))
     return false;
 
-  // Handle vector.extract_strided_slice on registers
-  // containing matrixB and matrixC operands.
-  // vector.extract_strided_slice op is not supported on
-  // registers containing matrixA operands.
+  // Handle vector.extract_strided_slice on registers containing
+  // matrixB and matrixC operands. vector.extract_strided_slice op
+  // is not supported on registers containing matrixA operands.
   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
     return (op->getResult(0).getType().cast<VectorType>() ==
             (*contractOp).getRhs().getType().cast<VectorType>());
@@ -276,9 +236,7 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
-    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu) ||
-           transferReadFollowedByBroadcastSupportsMMAMatrixType(transferRead,
-                                                                useNvGpu);
+    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
     return transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
@@ -288,10 +246,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
     return contractSupportsMMAMatrixType(contract, useNvGpu);
   if (auto constant = dyn_cast<arith::ConstantOp>(op))
     return constantSupportsMMAMatrixType(constant);
-  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
-    return broadcastSupportsMMAMatrixType(broadcast) ||
-           broadcastFromTransferReadSupportsMMAMatrixType(broadcast, useNvGpu);
-  }
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+    return broadcastSupportsMMAMatrixType(broadcast);
   return elementwiseSupportsMMAMatrixType(op);
 }
 
@@ -308,20 +264,17 @@ static SetVector<Operation *> getSliceContract(Operation *op,
   SetVector<Operation *> forwardSlice;
   while (currentIndex != slice.size()) {
     auto *currentOp = (slice)[currentIndex];
-    // Compute and insert the backwardSlice starting from
-    // currentOp.
+    // Compute and insert the backwardSlice starting from currentOp.
     backwardSlice.clear();
     getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
     slice.insert(backwardSlice.begin(), backwardSlice.end());
 
-    // Compute and insert the forwardSlice starting from
-    // currentOp.
+    // Compute and insert the forwardSlice starting from currentOp.
     forwardSlice.clear();
-    // Special case for ForOp, we don't want to include the
-    // whole region but only the value using the region
-    // arguments.
-    // TODO: We should refine this to only care about the
-    // region arguments being converted to matrix type.
+    // Special case for ForOp, we don't want to include the whole region but
+    // only the value using the region arguments.
+    // TODO: We should refine this to only care about the region arguments being
+    // converted to matrix type.
     if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
       for (Value forOpResult : forOp.getResults())
         getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
@@ -354,20 +307,16 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
       return;
     SetVector<Operation *> dependentOps =
         getSliceContract(contract, hasVectorDest, hasVectorSrc);
-    // If any instruction cannot use MMA matrix type drop the
-    // whole chain. MMA matrix are stored in an opaque type so
-    // they cannot be used by all operations.
+    // If any instruction cannot use MMA matrix type drop the whole
+    // chain. MMA matrix are stored in an opaque type so they cannot be used
+    // by all operations.
     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
-          auto res = !supportsMMaMatrixType(op, useNvGpu);
-          if (res)
-            llvm::errs() << "DOES NOT SUPPORT: " << *op << "\n";
-          return res;
+          return !supportsMMaMatrixType(op, useNvGpu);
         }))
       return;
     opToConvert.insert(dependentOps.begin(), dependentOps.end());
   });
-  // Sort the operations so that we can convert them in
-  // topological order.
+  // Sort the operations so that we can convert them in topological order.
   return topologicalSort(opToConvert);
 }
 
@@ -494,12 +443,7 @@ static const char *inferFragType(OpTy op) {
 static void convertTransferReadOp(vector::TransferReadOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
-  if (!transferReadSupportsMMAMatrixType(op,
-                                         /*useNvGpu=*/false))
-    return;
-  // Only transfers that return 2-D vectors are supported.
-  if (op.getVectorType().getRank() != 2)
-    return;
+  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
   std::optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
   AffineMap map = op.getPermutationMap();
@@ -591,11 +535,10 @@ creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
       *warpMatrixInfo,
       /*transpose=*/!op.getPermutationMap().isMinorIdentity());
   if (failed(params)) {
-    return op->emitError() << "failed to convert vector.transfer_read to "
-                              "ldmatrix; this op "
-                              "likely "
-                              "should not be converted to a nvgpu.ldmatrix "
-                              "call.";
+    return op->emitError()
+           << "failed to convert vector.transfer_read to ldmatrix; this op "
+              "likely "
+              "should not be converted to a nvgpu.ldmatrix call.";
   }
 
   // Adjust the load offset.
@@ -629,8 +572,7 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(regInfo)) {
     op->emitError() << "Failed to deduce register fragment type during "
-                       "conversion to distributed non-ldmatrix compatible "
-                       "load";
+                       "conversion to distributed non-ldmatrix compatible load";
     return failure();
   }
 
@@ -648,8 +590,8 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
 
   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
 
-  // If we are not transposing, then we can use vectorized
-  // loads. Otherwise, we must load each element individually.
+  // If we are not transposing, then we can use vectorized loads. Otherwise, we
+  // must load each element individually.
   if (!isTransposeLoad) {
     if (!loadedElType.isa<VectorType>()) {
       loadedElType = VectorType::get({1}, loadedElType);
@@ -723,9 +665,9 @@ convertTransferReadToLoads(vector::TransferReadOp op,
   VectorType vecTy = op.getVectorType();
   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
 
-  // When we are transposing the B operand, ldmatrix will only
-  // work if we have at least 8 rows to read and the width to
-  // read for the transpose is 128 bits.
+  // When we are transposing the B operand, ldmatrix will only work if we have
+  // at least 8 rows to read and the width to read for the transpose is 128
+  // bits.
   if (!op.getPermutationMap().isMinorIdentity() &&
       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
        vecTy.getDimSize(0) * bitWidth < 128))
@@ -798,8 +740,7 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
   if (failed(mmaSyncFragmentInfo))
     return failure();
 
-  // Find the vector.transer_read whose result vector is being
-  // sliced.
+  // Find the vector.transer_read whose result vector is being sliced.
   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
   if (!transferReadOp)
     return failure();
@@ -813,13 +754,12 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
   if (failed(ldFragmentInfo))
     return failure();
 
-  assert((mmaSyncFragmentInfo->elementsPerRegister ==
-          ldFragmentInfo->elementsPerRegister) &&
-         "Number of elements per register should be same for "
-         "load and mma.sync");
+  assert(
+      (mmaSyncFragmentInfo->elementsPerRegister ==
+       ldFragmentInfo->elementsPerRegister) &&
+      "Number of elements per register should be same for load and mma.sync");
 
-  // Create vector.extract_strided_slice op for thread-owned
-  // fragments.
+  // Create vector.extract_strided_slice op for thread-owned fragments.
   std::array<int64_t, 2> strides = {1,
                                     1}; // stride for extract slice is always 1.
   std::array<int64_t, 2> sliceShape = {
@@ -835,11 +775,9 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
   populateFromInt64AttrArray(op.getSizes(), sizes);
   ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
 
-  // Compute offset in vector registers. Note that the mma.sync
-  // vector registers are shaped as numberOfFragments x
-  // numberOfRegistersPerfFragment. The vector registers can
-  // only be sliced along numberOfFragments, i.e.,
-  // sliceOffset[0].
+  // Compute offset in vector registers. Note that the mma.sync vector registers
+  // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
+  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
   std::array<int64_t, 2> sliceOffset = {0, 0};
 
   if (offsets[0] && offsets[1])
@@ -904,10 +842,7 @@ static void convertConstantOp(arith::ConstantOp op,
 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
 static void convertBroadcastOp(vector::BroadcastOp op,
                                llvm::DenseMap<Value, Value> &valueMapping) {
-  // This op only catches the broadcasts that can directly
-  // convert to an MMA op.
-  if (!broadcastSupportsMMAMatrixType(op))
-    return;
+  assert(broadcastSupportsMMAMatrixType(op));
   OpBuilder b(op);
   const char *fragType = inferFragType(op);
   auto vecType = op.getVectorType();
@@ -918,39 +853,11 @@ static void convertBroadcastOp(vector::BroadcastOp op,
   valueMapping[op.getResult()] = matrix;
 }
 
-/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
-static void
-convertBroadcastFromTransferReadOp(vector::BroadcastOp broadcastOp,
-                                   llvm::DenseMap<Value, Value> &valueMapping) {
-  // This op catches the broadcasts that cannot directly convert to an MMA
-  // op.
-  if (broadcastSupportsMMAMatrixType(broadcastOp))
-    return;
-  if (!broadcastFromTransferReadSupportsMMAMatrixType(broadcastOp,
-                                                      /*useNvGpu=*/false))
-    return;
-  auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
-  assert(readOp && readOp.getVectorType().getRank() == 1);
-  // Handle broadcast by setting the stride to 0, unconditionally.
-  int64_t stride = 0;
-  const char *fragType = inferFragType(readOp);
-  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
-      broadcastOp.getVectorType().getShape(),
-      broadcastOp.getVectorType().getElementType(), fragType);
-  OpBuilder b(readOp);
-  bool isTranspose = false;
-  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
-      readOp.getLoc(), type, readOp.getSource(), readOp.getIndices(),
-      b.getIndexAttr(stride), isTranspose ? b.getUnitAttr() : UnitAttr());
-  valueMapping[broadcastOp.getResult()] = load;
-}
-
 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
 // updated and needs to be updated separatly for the loop to be correct.
 static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                                ValueRange newIterOperands) {
-  // Create a new loop before the existing one, with the extra
-  // operands.
+  // Create a new loop before the existing one, with the extra operands.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(loop);
   auto operands = llvm::to_vector<4>(loop.getIterOperands());
@@ -1005,8 +912,8 @@ static void convertYieldOp(scf::YieldOp op,
     auto it = valueMapping.find(operand.value());
     if (it == valueMapping.end())
       continue;
-    // Replace the yield of old value with the for op argument
-    // to make it easier to remove the dead code.
+    // Replace the yield of old value with the for op argument to make it easier
+    // to remove the dead code.
     yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
     yieldOperands.push_back(it->second);
   }
@@ -1052,7 +959,6 @@ void mlir::convertVectorToMMAOps(Operation *rootOp) {
       convertConstantOp(constantOp, valueMapping);
     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
       convertBroadcastOp(broadcastOp, valueMapping);
-      convertBroadcastFromTransferReadOp(broadcastOp, valueMapping);
     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
       convertForOp(forOp, valueMapping);
     } else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
@@ -1121,8 +1027,6 @@ struct ConvertVectorToGPUPass
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
 
-    getOperation()->dump();
-
     if (useNvGpu.getValue()) {
       if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
         return signalPassFailure();
index 9a0f4c9..fa2a40f 100644 (file)
@@ -4,6 +4,7 @@
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
 #map5 = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: func @matmul
@@ -117,21 +118,6 @@ func.func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x1
 //       CHECK:   %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   %[[F:.+]] = gpu.subgroup_mma_elementwise divf %[[D]], %[[E]] : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
-// func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
-//   %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
-//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
-//   %c0 = arith.constant 0 : index
-//   %cst = arith.constant 0.000000e+00 : f16
-//   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
-//   %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
-//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-//   %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
-//     {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
-//     : memref<16x16x16x16xf16>, vector<16x16xf16>
-//   %F = arith.divf %D, %E : vector<16x16xf16>
-//   vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
-//   return
-// }
 func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
   %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
@@ -140,10 +126,9 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
   %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-  %Eread = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
-    {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3)->(d3)>}
-    : memref<16x16x16x16xf16>, vector<16xf16>
-  %E = vector.broadcast %Eread: vector<16xf16> to vector<16x16xf16>
+  %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+    {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
+    : memref<16x16x16x16xf16>, vector<16x16xf16>
   %F = arith.divf %D, %E : vector<16x16xf16>
   vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
   return
@@ -156,24 +141,12 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
 //   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
-// func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
-//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
-//   %c0 = arith.constant 0 : index
-//   %cst = arith.constant 0.000000e+00 : f16
-//   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-//   %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
-//   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-//   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
-//   return
-// }
 func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f16
   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-  %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
-  %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
@@ -187,24 +160,12 @@ func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %a
 //   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
-// func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
-//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
-//   %c0 = arith.constant 0 : index
-//   %cst = arith.constant 0.000000e+00 : f16
-//   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
-//   %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
-//   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-//   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
-//   return
-// }
 func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f16
   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
-  %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
-  %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>