return true;
}
+/// Return true if the rank-1 transfer op, followed by a broadcast to 2-D, can
+/// be converted to an MMA matrix load.
+static bool transferReadFollowedByBroadcastSupportsMMAMatrixType(
+    vector::TransferReadOp readOp, bool useNvGpu) {
+  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
+      readOp.getVectorType().getRank() != 1)
+    return false;
+  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
+    return false;
+  AffineMap map = readOp.getPermutationMap();
+  OpBuilder b(readOp.getContext());
+
+  if (!useNvGpu)
+    return map.isMinorIdentity() || isTransposeMatrixLoadMap(b, map);
+
+  return true;
+}
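+// Illustrative sketch (mirroring the tests updated below): the pattern
+// accepted by the helper above is a rank-1 read such as
+//   %r = vector.transfer_read %mem[%c0], %cst {in_bounds = [true],
+//        permutation_map = affine_map<(d0) -> (d0)>}
+//        : memref<16xf16>, vector<16xf16>
+// whose result is subsequently broadcast to a 2-D vector.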
+
// Return true if the transfer op can be converted to a MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getVectorType().getRank() == 2 &&
         broadcastOp.getSource().getType().isa<FloatType>();
+}
+
+/// Return true if this is a broadcast from 1-D to a 2-D vector and the 1-D
+/// vector comes from a TransferReadOp.
+static bool
+broadcastFromTransferReadSupportsMMAMatrixType(vector::BroadcastOp broadcastOp,
+ bool useNvGpu) {
+ auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
+ auto sourceVectorType =
+ broadcastOp.getSource().getType().dyn_cast<VectorType>();
+  return !broadcastSupportsMMAMatrixType(broadcastOp) && readOp &&
+         sourceVectorType && sourceVectorType.getRank() == 1 &&
+         transferReadFollowedByBroadcastSupportsMMAMatrixType(readOp, useNvGpu);
}
/// Return the MMA elementwise enum associated with `op` if it is supported.
if (failed(contractOp))
return false;
  // Handle vector.extract_strided_slice on registers containing
  // matrixB and matrixC operands. vector.extract_strided_slice op
  // is not supported on registers containing matrixA operands.
if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
return (op->getResult(0).getType().cast<VectorType>() ==
(*contractOp).getRhs().getType().cast<VectorType>());
if (isa<scf::ForOp, scf::YieldOp>(op))
return true;
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
- return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
+ return transferReadSupportsMMAMatrixType(transferRead, useNvGpu) ||
+ transferReadFollowedByBroadcastSupportsMMAMatrixType(transferRead,
+ useNvGpu);
if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
return transferWriteSupportsMMAMatrixType(transferWrite);
if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
return contractSupportsMMAMatrixType(contract, useNvGpu);
if (auto constant = dyn_cast<arith::ConstantOp>(op))
return constantSupportsMMAMatrixType(constant);
- if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
- return broadcastSupportsMMAMatrixType(broadcast);
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+    return broadcastSupportsMMAMatrixType(broadcast) ||
+           broadcastFromTransferReadSupportsMMAMatrixType(broadcast, useNvGpu);
return elementwiseSupportsMMAMatrixType(op);
}
SetVector<Operation *> forwardSlice;
while (currentIndex != slice.size()) {
auto *currentOp = (slice)[currentIndex];
  // Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
slice.insert(backwardSlice.begin(), backwardSlice.end());
  // Compute and insert the forwardSlice starting from currentOp.
forwardSlice.clear();
  // Special case for ForOp, we don't want to include the whole region but
  // only the value using the region arguments.
  // TODO: We should refine this to only care about the region arguments being
  // converted to matrix type.
if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
for (Value forOpResult : forOp.getResults())
getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
return;
SetVector<Operation *> dependentOps =
getSliceContract(contract, hasVectorDest, hasVectorSrc);
  // If any instruction cannot use MMA matrix type drop the whole
  // chain. MMA matrix are stored in an opaque type so they cannot be used
  // by all operations.
if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
        return !supportsMMaMatrixType(op, useNvGpu);
}))
return;
opToConvert.insert(dependentOps.begin(), dependentOps.end());
});
  // Sort the operations so that we can convert them in topological order.
return topologicalSort(opToConvert);
}
static void convertTransferReadOp(vector::TransferReadOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
- assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
+  if (!transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false))
+    return;
+  // Rank-1 reads that feed a broadcast are handled separately by
+  // convertBroadcastFromTransferReadOp; only 2-D transfers are converted here.
+  if (op.getVectorType().getRank() != 2)
+    return;
std::optional<int64_t> stride =
getMemrefConstantHorizontalStride(op.getShapedType());
AffineMap map = op.getPermutationMap();
*warpMatrixInfo,
/*transpose=*/!op.getPermutationMap().isMinorIdentity());
if (failed(params)) {
    return op->emitError()
           << "failed to convert vector.transfer_read to ldmatrix; this op "
              "likely "
              "should not be converted to a nvgpu.ldmatrix call.";
}
// Adjust the load offset.
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
op->emitError() << "Failed to deduce register fragment type during "
- "conversion to distributed non-ldmatrix compatible load";
+ "conversion to distributed non-ldmatrix compatible "
+ "load";
return failure();
}
bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
  // If we are not transposing, then we can use vectorized loads. Otherwise, we
  // must load each element individually.
if (!isTransposeLoad) {
if (!loadedElType.isa<VectorType>()) {
loadedElType = VectorType::get({1}, loadedElType);
VectorType vecTy = op.getVectorType();
int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
  // When we are transposing the B operand, ldmatrix will only work if we have
  // at least 8 rows to read and the width to read for the transpose is 128
  // bits.
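+  // (Illustration: for an f16 operand bitWidth is 16, so the check below
+  // requires vecTy.getDimSize(0) >= 8, i.e. 8 * 16 = 128 bits per transposed
+  // read.)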
if (!op.getPermutationMap().isMinorIdentity() &&
(bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
vecTy.getDimSize(0) * bitWidth < 128))
if (failed(mmaSyncFragmentInfo))
return failure();
-  // Find the vector.transer_read whose result vector is being sliced.
+  // Find the vector.transfer_read whose result vector is being sliced.
auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
return failure();
if (failed(ldFragmentInfo))
return failure();
  assert(
      (mmaSyncFragmentInfo->elementsPerRegister ==
       ldFragmentInfo->elementsPerRegister) &&
      "Number of elements per register should be same for load and mma.sync");
  // Create vector.extract_strided_slice op for thread-owned fragments.
std::array<int64_t, 2> strides = {1,
1}; // stride for extract slice is always 1.
std::array<int64_t, 2> sliceShape = {
populateFromInt64AttrArray(op.getSizes(), sizes);
ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
  // Compute offset in vector registers. Note that the mma.sync vector registers
  // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
std::array<int64_t, 2> sliceOffset = {0, 0};
if (offsets[0] && offsets[1])
/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static void convertBroadcastOp(vector::BroadcastOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
- assert(broadcastSupportsMMAMatrixType(op));
+  // Only handle broadcasts that convert directly to an MMA constant matrix;
+  // broadcasts fed by a transfer_read are handled by
+  // convertBroadcastFromTransferReadOp.
+  if (!broadcastSupportsMMAMatrixType(op))
+    return;
OpBuilder b(op);
const char *fragType = inferFragType(op);
auto vecType = op.getVectorType();
valueMapping[op.getResult()] = matrix;
}
+/// Convert a vector.broadcast whose source is a rank-1 vector.transfer_read
+/// into a gpu.subgroup_mma_load_matrix with a leading dimension of 0.
+static void
+convertBroadcastFromTransferReadOp(vector::BroadcastOp broadcastOp,
+                                   llvm::DenseMap<Value, Value> &valueMapping) {
+  // This path handles the broadcasts that cannot directly convert to an MMA
+  // constant matrix op.
+  if (broadcastSupportsMMAMatrixType(broadcastOp))
+    return;
+ if (!broadcastFromTransferReadSupportsMMAMatrixType(broadcastOp,
+ /*useNvGpu=*/false))
+ return;
+ auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
+ assert(readOp && readOp.getVectorType().getRank() == 1);
+  // Implement the broadcast by loading with a leading dimension (stride) of
+  // 0, so that every row of the MMA matrix reads the same data.
+  int64_t stride = 0;
+ const char *fragType = inferFragType(readOp);
+ gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+ broadcastOp.getVectorType().getShape(),
+ broadcastOp.getVectorType().getElementType(), fragType);
+ OpBuilder b(readOp);
+  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
+      readOp.getLoc(), type, readOp.getSource(), readOp.getIndices(),
+      b.getIndexAttr(stride), /*transpose=*/UnitAttr());
+ valueMapping[broadcastOp.getResult()] = load;
+}
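+// For example (cf. the CHECK lines in the tests below), such a broadcast read
+// lowers to something along the lines of:
+//   %m = gpu.subgroup_mma_load_matrix %mem[...] {leadDimension = 0 : index}
+//        : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// where the zero leading dimension re-reads the same row, realizing the
+// broadcast.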
+
// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separatly for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
ValueRange newIterOperands) {
  // Create a new loop before the existing one, with the extra operands.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(loop);
auto operands = llvm::to_vector<4>(loop.getIterOperands());
auto it = valueMapping.find(operand.value());
if (it == valueMapping.end())
continue;
    // Replace the yield of old value with the for op argument to make it easier
    // to remove the dead code.
yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
yieldOperands.push_back(it->second);
}
convertConstantOp(constantOp, valueMapping);
} else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
convertBroadcastOp(broadcastOp, valueMapping);
+ convertBroadcastFromTransferReadOp(broadcastOp, valueMapping);
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
convertForOp(forOp, valueMapping);
} else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
if (useNvGpu.getValue()) {
if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
return signalPassFailure();
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-#map4 = affine_map<(d0) -> (d0, 0)>
#map5 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @matmul
// CHECK: %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[F:.+]] = gpu.subgroup_mma_elementwise divf %[[D]], %[[E]] : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
%arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
%A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
%B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
- %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
- {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
- : memref<16x16x16x16xf16>, vector<16x16xf16>
+ %Eread = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+ {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3)->(d3)>}
+ : memref<16x16x16x16xf16>, vector<16xf16>
+ %E = vector.broadcast %Eread: vector<16xf16> to vector<16x16xf16>
%F = arith.divf %D, %E : vector<16x16xf16>
vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
return
// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
- %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+ %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
+ %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
%C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
- %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+ %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
+ %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
%C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>