[WIP] Add support for MMA conversion for 1-D vector.transfer followed by a broadcast...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 30 Nov 2022 21:36:13 +0000 (13:36 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 1 Dec 2022 10:49:47 +0000 (02:49 -0800)
Differential Revision: https://reviews.llvm.org/D139040

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

index 2734b5f..1da8dc4 100644 (file)
@@ -150,6 +150,26 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
   return true;
 }
 
+// Return true if the transfer op can be converted to a MMA matrix load.
+static bool transferReadFollowedByBroadcastSupportsMMAMatrixType(
+    vector::TransferReadOp readOp, bool useNvGpu) {
+  bool res = true;
+  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
+      readOp.getVectorType().getRank() != 1)
+    res = false;
+  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
+    res = false;
+  AffineMap map = readOp.getPermutationMap();
+  OpBuilder b(readOp.getContext());
+
+  if (res && !useNvGpu)
+    return map.isMinorIdentity() || isTransposeMatrixLoadMap(b, map);
+
+  llvm::errs() << "RES transferReadFollowedByBroadcastSupportsMMAMatrixType: "
+               << res << "\n";
+  return res;
+}
+
 // Return true if the transfer op can be converted to a MMA matrix store.
 static bool
 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
@@ -179,8 +199,27 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
 
 /// Return true if this is a broadcast from scalar to a 2D vector.
 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
-  return broadcastOp.getVectorType().getRank() == 2 &&
-         broadcastOp.getSource().getType().isa<FloatType>();
+  auto res = broadcastOp.getVectorType().getRank() == 2 &&
+             broadcastOp.getSource().getType().isa<FloatType>();
+  llvm::errs() << "RES broadcastSupportsMMAMatrixType: " << res << "\n";
+  return res;
+}
+
+/// Return true if this is a broadcast from 1-D to a 2-D vector and the 1-D
+/// vector comes from a TransferReadOp.
+static bool
+broadcastFromTransferReadSupportsMMAMatrixType(vector::BroadcastOp broadcastOp,
+                                               bool useNvGpu) {
+  auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
+  auto sourceVectorType =
+      broadcastOp.getSource().getType().dyn_cast<VectorType>();
+  auto res =
+      !broadcastSupportsMMAMatrixType(broadcastOp) && sourceVectorType &&
+      sourceVectorType.getRank() == 1 &&
+      transferReadFollowedByBroadcastSupportsMMAMatrixType(readOp, useNvGpu);
+  llvm::errs() << "RES broadcastFromTransferReadSupportsMMAMatrixType: " << res
+               << "\n";
+  return res;
 }
 
 /// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -219,9 +258,10 @@ extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
   if (failed(contractOp))
     return false;
 
-  // Handle vector.extract_strided_slice on registers containing
-  // matrixB and matrixC operands. vector.extract_strided_slice op
-  // is not supported on registers containing matrixA operands.
+  // Handle vector.extract_strided_slice on registers
+  // containing matrixB and matrixC operands.
+  // vector.extract_strided_slice op is not supported on
+  // registers containing matrixA operands.
   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
     return (op->getResult(0).getType().cast<VectorType>() ==
             (*contractOp).getRhs().getType().cast<VectorType>());
@@ -236,7 +276,9 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
-    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
+    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu) ||
+           transferReadFollowedByBroadcastSupportsMMAMatrixType(transferRead,
+                                                                useNvGpu);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
     return transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
@@ -246,8 +288,10 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
     return contractSupportsMMAMatrixType(contract, useNvGpu);
   if (auto constant = dyn_cast<arith::ConstantOp>(op))
     return constantSupportsMMAMatrixType(constant);
-  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
-    return broadcastSupportsMMAMatrixType(broadcast);
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
+    return broadcastSupportsMMAMatrixType(broadcast) ||
+           broadcastFromTransferReadSupportsMMAMatrixType(broadcast, useNvGpu);
+  }
   return elementwiseSupportsMMAMatrixType(op);
 }
 
@@ -264,17 +308,20 @@ static SetVector<Operation *> getSliceContract(Operation *op,
   SetVector<Operation *> forwardSlice;
   while (currentIndex != slice.size()) {
     auto *currentOp = (slice)[currentIndex];
-    // Compute and insert the backwardSlice starting from currentOp.
+    // Compute and insert the backwardSlice starting from
+    // currentOp.
     backwardSlice.clear();
     getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
     slice.insert(backwardSlice.begin(), backwardSlice.end());
 
-    // Compute and insert the forwardSlice starting from currentOp.
+    // Compute and insert the forwardSlice starting from
+    // currentOp.
     forwardSlice.clear();
-    // Special case for ForOp, we don't want to include the whole region but
-    // only the value using the region arguments.
-    // TODO: We should refine this to only care about the region arguments being
-    // converted to matrix type.
+    // Special case for ForOp, we don't want to include the
+    // whole region but only the value using the region
+    // arguments.
+    // TODO: We should refine this to only care about the
+    // region arguments being converted to matrix type.
     if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
       for (Value forOpResult : forOp.getResults())
         getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
@@ -307,16 +354,20 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
       return;
     SetVector<Operation *> dependentOps =
         getSliceContract(contract, hasVectorDest, hasVectorSrc);
-    // If any instruction cannot use MMA matrix type drop the whole
-    // chain. MMA matrix are stored in an opaque type so they cannot be used
-    // by all operations.
+    // If any instruction cannot use MMA matrix type drop the
+    // whole chain. MMA matrix are stored in an opaque type so
+    // they cannot be used by all operations.
     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
-          return !supportsMMaMatrixType(op, useNvGpu);
+          auto res = !supportsMMaMatrixType(op, useNvGpu);
+          if (res)
+            llvm::errs() << "DOES NOT SUPPORT: " << *op << "\n";
+          return res;
         }))
       return;
     opToConvert.insert(dependentOps.begin(), dependentOps.end());
   });
-  // Sort the operations so that we can convert them in topological order.
+  // Sort the operations so that we can convert them in
+  // topological order.
   return topologicalSort(opToConvert);
 }
 
@@ -443,7 +494,12 @@ static const char *inferFragType(OpTy op) {
 static void convertTransferReadOp(vector::TransferReadOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
-  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
+  if (!transferReadSupportsMMAMatrixType(op,
+                                         /*useNvGpu=*/false))
+    return;
+  // Only transfers that return 2-D vectors are supported.
+  if (op.getVectorType().getRank() != 2)
+    return;
   std::optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
   AffineMap map = op.getPermutationMap();
@@ -535,10 +591,11 @@ creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
       *warpMatrixInfo,
       /*transpose=*/!op.getPermutationMap().isMinorIdentity());
   if (failed(params)) {
-    return op->emitError()
-           << "failed to convert vector.transfer_read to ldmatrix; this op "
-              "likely "
-              "should not be converted to a nvgpu.ldmatrix call.";
+    return op->emitError() << "failed to convert vector.transfer_read to "
+                              "ldmatrix; this op "
+                              "likely "
+                              "should not be converted to a nvgpu.ldmatrix "
+                              "call.";
   }
 
   // Adjust the load offset.
@@ -572,7 +629,8 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(regInfo)) {
     op->emitError() << "Failed to deduce register fragment type during "
-                       "conversion to distributed non-ldmatrix compatible load";
+                       "conversion to distributed non-ldmatrix compatible "
+                       "load";
     return failure();
   }
 
@@ -590,8 +648,8 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
 
   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
 
-  // If we are not transposing, then we can use vectorized loads. Otherwise, we
-  // must load each element individually.
+  // If we are not transposing, then we can use vectorized
+  // loads. Otherwise, we must load each element individually.
   if (!isTransposeLoad) {
     if (!loadedElType.isa<VectorType>()) {
       loadedElType = VectorType::get({1}, loadedElType);
@@ -665,9 +723,9 @@ convertTransferReadToLoads(vector::TransferReadOp op,
   VectorType vecTy = op.getVectorType();
   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
 
-  // When we are transposing the B operand, ldmatrix will only work if we have
-  // at least 8 rows to read and the width to read for the transpose is 128
-  // bits.
+  // When we are transposing the B operand, ldmatrix will only
+  // work if we have at least 8 rows to read and the width to
+  // read for the transpose is 128 bits.
   if (!op.getPermutationMap().isMinorIdentity() &&
       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
        vecTy.getDimSize(0) * bitWidth < 128))
@@ -740,7 +798,8 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
   if (failed(mmaSyncFragmentInfo))
     return failure();
 
-  // Find the vector.transer_read whose result vector is being sliced.
+  // Find the vector.transer_read whose result vector is being
+  // sliced.
   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
   if (!transferReadOp)
     return failure();
@@ -754,12 +813,13 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
   if (failed(ldFragmentInfo))
     return failure();
 
-  assert(
-      (mmaSyncFragmentInfo->elementsPerRegister ==
-       ldFragmentInfo->elementsPerRegister) &&
-      "Number of elements per register should be same for load and mma.sync");
+  assert((mmaSyncFragmentInfo->elementsPerRegister ==
+          ldFragmentInfo->elementsPerRegister) &&
+         "Number of elements per register should be same for "
+         "load and mma.sync");
 
-  // Create vector.extract_strided_slice op for thread-owned fragments.
+  // Create vector.extract_strided_slice op for thread-owned
+  // fragments.
   std::array<int64_t, 2> strides = {1,
                                     1}; // stride for extract slice is always 1.
   std::array<int64_t, 2> sliceShape = {
@@ -775,9 +835,11 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
   populateFromInt64AttrArray(op.getSizes(), sizes);
   ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
 
-  // Compute offset in vector registers. Note that the mma.sync vector registers
-  // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
-  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
+  // Compute offset in vector registers. Note that the mma.sync
+  // vector registers are shaped as numberOfFragments x
+  // numberOfRegistersPerfFragment. The vector registers can
+  // only be sliced along numberOfFragments, i.e.,
+  // sliceOffset[0].
   std::array<int64_t, 2> sliceOffset = {0, 0};
 
   if (offsets[0] && offsets[1])
@@ -842,7 +904,10 @@ static void convertConstantOp(arith::ConstantOp op,
 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
 static void convertBroadcastOp(vector::BroadcastOp op,
                                llvm::DenseMap<Value, Value> &valueMapping) {
-  assert(broadcastSupportsMMAMatrixType(op));
+  // This op only catches the broadcasts that can directly
+  // convert to an MMA op.
+  if (!broadcastSupportsMMAMatrixType(op))
+    return;
   OpBuilder b(op);
   const char *fragType = inferFragType(op);
   auto vecType = op.getVectorType();
@@ -853,11 +918,39 @@ static void convertBroadcastOp(vector::BroadcastOp op,
   valueMapping[op.getResult()] = matrix;
 }
 
+/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
+static void
+convertBroadcastFromTransferReadOp(vector::BroadcastOp broadcastOp,
+                                   llvm::DenseMap<Value, Value> &valueMapping) {
+  // This op catches the broadcasts that cannot directly convert to an MMA
+  // op.
+  if (broadcastSupportsMMAMatrixType(broadcastOp))
+    return;
+  if (!broadcastFromTransferReadSupportsMMAMatrixType(broadcastOp,
+                                                      /*useNvGpu=*/false))
+    return;
+  auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
+  assert(readOp && readOp.getVectorType().getRank() == 1);
+  // Handle broadcast by setting the stride to 0, unconditionally.
+  int64_t stride = 0;
+  const char *fragType = inferFragType(readOp);
+  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+      broadcastOp.getVectorType().getShape(),
+      broadcastOp.getVectorType().getElementType(), fragType);
+  OpBuilder b(readOp);
+  bool isTranspose = false;
+  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
+      readOp.getLoc(), type, readOp.getSource(), readOp.getIndices(),
+      b.getIndexAttr(stride), isTranspose ? b.getUnitAttr() : UnitAttr());
+  valueMapping[broadcastOp.getResult()] = load;
+}
+
 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
 // updated and needs to be updated separatly for the loop to be correct.
 static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                                ValueRange newIterOperands) {
-  // Create a new loop before the existing one, with the extra operands.
+  // Create a new loop before the existing one, with the extra
+  // operands.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(loop);
   auto operands = llvm::to_vector<4>(loop.getIterOperands());
@@ -912,8 +1005,8 @@ static void convertYieldOp(scf::YieldOp op,
     auto it = valueMapping.find(operand.value());
     if (it == valueMapping.end())
       continue;
-    // Replace the yield of old value with the for op argument to make it easier
-    // to remove the dead code.
+    // Replace the yield of old value with the for op argument
+    // to make it easier to remove the dead code.
     yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
     yieldOperands.push_back(it->second);
   }
@@ -959,6 +1052,7 @@ void mlir::convertVectorToMMAOps(Operation *rootOp) {
       convertConstantOp(constantOp, valueMapping);
     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
       convertBroadcastOp(broadcastOp, valueMapping);
+      convertBroadcastFromTransferReadOp(broadcastOp, valueMapping);
     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
       convertForOp(forOp, valueMapping);
     } else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
@@ -1027,6 +1121,8 @@ struct ConvertVectorToGPUPass
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
 
+    getOperation()->dump();
+
     if (useNvGpu.getValue()) {
       if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
         return signalPassFailure();
index fa2a40f..9a0f4c9 100644 (file)
@@ -4,7 +4,6 @@
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-#map4 = affine_map<(d0) -> (d0, 0)>
 #map5 = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: func @matmul
@@ -118,6 +117,21 @@ func.func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x1
 //       CHECK:   %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   %[[F:.+]] = gpu.subgroup_mma_elementwise divf %[[D]], %[[E]] : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+// func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
+//   %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
+//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+//   %c0 = arith.constant 0 : index
+//   %cst = arith.constant 0.000000e+00 : f16
+//   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+//   %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+//   %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+//     {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
+//     : memref<16x16x16x16xf16>, vector<16x16xf16>
+//   %F = arith.divf %D, %E : vector<16x16xf16>
+//   vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+//   return
+// }
 func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
   %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
@@ -126,9 +140,10 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
   %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-  %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
-    {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
-    : memref<16x16x16x16xf16>, vector<16x16xf16>
+  %Eread = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+    {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3)->(d3)>}
+    : memref<16x16x16x16xf16>, vector<16xf16>
+  %E = vector.broadcast %Eread: vector<16xf16> to vector<16x16xf16>
   %F = arith.divf %D, %E : vector<16x16xf16>
   vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
   return
@@ -141,12 +156,24 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
 //   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
+// func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
+//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+//   %c0 = arith.constant 0 : index
+//   %cst = arith.constant 0.000000e+00 : f16
+//   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
+//   %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+//   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
+//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+//   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
+//   return
+// }
 func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f16
   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-  %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+  %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
+  %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
@@ -160,12 +187,24 @@ func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %a
 //   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
+// func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
+//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+//   %c0 = arith.constant 0 : index
+//   %cst = arith.constant 0.000000e+00 : f16
+//   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
+//   %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+//   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
+//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+//   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
+//   return
+// }
 func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f16
   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
-  %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+  %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
+  %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>