From: Nicolas Vasilache
Date: Thu, 1 Dec 2022 10:56:33 +0000 (-0800)
Subject: Revert "[WIP] Add support for MMA conversion for 1-D vector.transfer followed by...
X-Git-Tag: upstream/17.0.6~25750
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3af6438372ad28c3c2c632a67b15fb68f9c3d52b;p=platform%2Fupstream%2Fllvm.git

Revert "[WIP] Add support for MMA conversion for 1-D vector.transfer followed by a broadcast to 2-D"

This reverts commit 7db25f78db807da171f23bcbaff258c5677901d1.

This was mistakenly stacked below (and committed) along with an NFC change.
---

diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1da8dc4..2734b5f 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -150,26 +150,6 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
   return true;
 }
 
-// Return true if the transfer op can be converted to a MMA matrix load.
-static bool transferReadFollowedByBroadcastSupportsMMAMatrixType(
-    vector::TransferReadOp readOp, bool useNvGpu) {
-  bool res = true;
-  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
-      readOp.getVectorType().getRank() != 1)
-    res = false;
-  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
-    res = false;
-  AffineMap map = readOp.getPermutationMap();
-  OpBuilder b(readOp.getContext());
-
-  if (res && !useNvGpu)
-    return map.isMinorIdentity() || isTransposeMatrixLoadMap(b, map);
-
-  llvm::errs() << "RES transferReadFollowedByBroadcastSupportsMMAMatrixType: "
-               << res << "\n";
-  return res;
-}
-
 // Return true if the transfer op can be converted to a MMA matrix store.
 static bool
 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
@@ -199,27 +179,8 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
 
 /// Return true if this is a broadcast from scalar to a 2D vector.
 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
-  auto res = broadcastOp.getVectorType().getRank() == 2 &&
-             broadcastOp.getSource().getType().isa<FloatType>();
-  llvm::errs() << "RES broadcastSupportsMMAMatrixType: " << res << "\n";
-  return res;
-}
-
-/// Return true if this is a broadcast from 1-D to a 2-D vector and the 1-D
-/// vector comes from a TransferReadOp.
-static bool
-broadcastFromTransferReadSupportsMMAMatrixType(vector::BroadcastOp broadcastOp,
-                                               bool useNvGpu) {
-  auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
-  auto sourceVectorType =
-      broadcastOp.getSource().getType().dyn_cast<VectorType>();
-  auto res =
-      !broadcastSupportsMMAMatrixType(broadcastOp) && sourceVectorType &&
-      sourceVectorType.getRank() == 1 &&
-      transferReadFollowedByBroadcastSupportsMMAMatrixType(readOp, useNvGpu);
-  llvm::errs() << "RES broadcastFromTransferReadSupportsMMAMatrixType: " << res
-               << "\n";
-  return res;
+  return broadcastOp.getVectorType().getRank() == 2 &&
+         broadcastOp.getSource().getType().isa<FloatType>();
 }
 
 /// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -258,10 +219,9 @@ extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
   if (failed(contractOp))
     return false;
 
-  // Handle vector.extract_strided_slice on registers
-  // containing matrixB and matrixC operands.
-  // vector.extract_strided_slice op is not supported on
-  // registers containing matrixA operands.
+  // Handle vector.extract_strided_slice on registers containing
+  // matrixB and matrixC operands. vector.extract_strided_slice op
+  // is not supported on registers containing matrixA operands.
   if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
     return (op->getResult(0).getType().cast<VectorType>() ==
             (*contractOp).getRhs().getType().cast<VectorType>());
@@ -276,9 +236,7 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
-    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu) ||
-           transferReadFollowedByBroadcastSupportsMMAMatrixType(transferRead,
-                                                                useNvGpu);
+    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
     return transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
@@ -288,10 +246,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
     return contractSupportsMMAMatrixType(contract, useNvGpu);
   if (auto constant = dyn_cast<arith::ConstantOp>(op))
     return constantSupportsMMAMatrixType(constant);
-  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
-    return broadcastSupportsMMAMatrixType(broadcast) ||
-           broadcastFromTransferReadSupportsMMAMatrixType(broadcast, useNvGpu);
-  }
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+    return broadcastSupportsMMAMatrixType(broadcast);
   return elementwiseSupportsMMAMatrixType(op);
 }
 
@@ -308,20 +264,17 @@ static SetVector<Operation *> getSliceContract(Operation *op,
   SetVector<Operation *> forwardSlice;
   while (currentIndex != slice.size()) {
     auto *currentOp = (slice)[currentIndex];
-    // Compute and insert the backwardSlice starting from
-    // currentOp.
+    // Compute and insert the backwardSlice starting from currentOp.
     backwardSlice.clear();
     getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
     slice.insert(backwardSlice.begin(), backwardSlice.end());
 
-    // Compute and insert the forwardSlice starting from
-    // currentOp.
+    // Compute and insert the forwardSlice starting from currentOp.
     forwardSlice.clear();
-    // Special case for ForOp, we don't want to include the
-    // whole region but only the value using the region
-    // arguments.
-    // TODO: We should refine this to only care about the
-    // region arguments being converted to matrix type.
+    // Special case for ForOp, we don't want to include the whole region but
+    // only the value using the region arguments.
+    // TODO: We should refine this to only care about the region arguments being
+    // converted to matrix type.
     if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
       for (Value forOpResult : forOp.getResults())
         getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
@@ -354,20 +307,16 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
       return;
     SetVector<Operation *> dependentOps =
         getSliceContract(contract, hasVectorDest, hasVectorSrc);
-    // If any instruction cannot use MMA matrix type drop the
-    // whole chain. MMA matrix are stored in an opaque type so
-    // they cannot be used by all operations.
+    // If any instruction cannot use MMA matrix type drop the whole
+    // chain. MMA matrix are stored in an opaque type so they cannot be used
+    // by all operations.
     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
-          auto res = !supportsMMaMatrixType(op, useNvGpu);
-          if (res)
-            llvm::errs() << "DOES NOT SUPPORT: " << *op << "\n";
-          return res;
+          return !supportsMMaMatrixType(op, useNvGpu);
        }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
-  // Sort the operations so that we can convert them in
-  // topological order.
+  // Sort the operations so that we can convert them in topological order.
   return topologicalSort(opToConvert);
 }
 
@@ -494,12 +443,7 @@ static const char *inferFragType(OpTy op) {
 static void convertTransferReadOp(vector::TransferReadOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
-  if (!transferReadSupportsMMAMatrixType(op,
-                                         /*useNvGpu=*/false))
-    return;
-  // Only transfers that return 2-D vectors are supported.
-  if (op.getVectorType().getRank() != 2)
-    return;
+  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
   std::optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
   AffineMap map = op.getPermutationMap();
@@ -591,11 +535,10 @@ creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
       *warpMatrixInfo,
       /*transpose=*/!op.getPermutationMap().isMinorIdentity());
   if (failed(params)) {
-    return op->emitError() << "failed to convert vector.transfer_read to "
-                              "ldmatrix; this op "
-                              "likely "
-                              "should not be converted to a nvgpu.ldmatrix "
-                              "call.";
+    return op->emitError()
+           << "failed to convert vector.transfer_read to ldmatrix; this op "
+              "likely "
+              "should not be converted to a nvgpu.ldmatrix call.";
   }
 
   // Adjust the load offset.
@@ -629,8 +572,7 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(regInfo)) {
     op->emitError() << "Failed to deduce register fragment type during "
-                       "conversion to distributed non-ldmatrix compatible "
-                       "load";
+                       "conversion to distributed non-ldmatrix compatible load";
     return failure();
   }
 
@@ -648,8 +590,8 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
 
   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
 
-  // If we are not transposing, then we can use vectorized
-  // loads. Otherwise, we must load each element individually.
+  // If we are not transposing, then we can use vectorized loads. Otherwise, we
+  // must load each element individually.
   if (!isTransposeLoad) {
     if (!loadedElType.isa<VectorType>()) {
       loadedElType = VectorType::get({1}, loadedElType);
@@ -723,9 +665,9 @@ convertTransferReadToLoads(vector::TransferReadOp op,
 
   VectorType vecTy = op.getVectorType();
   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
-  // When we are transposing the B operand, ldmatrix will only
-  // work if we have at least 8 rows to read and the width to
-  // read for the transpose is 128 bits.
+  // When we are transposing the B operand, ldmatrix will only work if we have
+  // at least 8 rows to read and the width to read for the transpose is 128
+  // bits.
   if (!op.getPermutationMap().isMinorIdentity() &&
       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
        vecTy.getDimSize(0) * bitWidth < 128))
@@ -798,8 +740,7 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
   if (failed(mmaSyncFragmentInfo))
     return failure();
 
-  // Find the vector.transer_read whose result vector is being
-  // sliced.
+  // Find the vector.transer_read whose result vector is being sliced.
   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
   if (!transferReadOp)
     return failure();
@@ -813,13 +754,12 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
   if (failed(ldFragmentInfo))
     return failure();
 
-  assert((mmaSyncFragmentInfo->elementsPerRegister ==
-          ldFragmentInfo->elementsPerRegister) &&
-         "Number of elements per register should be same for "
-         "load and mma.sync");
+  assert(
+      (mmaSyncFragmentInfo->elementsPerRegister ==
+       ldFragmentInfo->elementsPerRegister) &&
+      "Number of elements per register should be same for load and mma.sync");
 
-  // Create vector.extract_strided_slice op for thread-owned
-  // fragments.
+  // Create vector.extract_strided_slice op for thread-owned fragments.
   std::array<int64_t, 2> strides = {1, 1}; // stride for extract slice is always 1.
   std::array<int64_t, 2> sliceShape = {
@@ -835,11 +775,9 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
   populateFromInt64AttrArray(op.getSizes(), sizes);
   ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
 
-  // Compute offset in vector registers. Note that the mma.sync
-  // vector registers are shaped as numberOfFragments x
-  // numberOfRegistersPerfFragment. The vector registers can
-  // only be sliced along numberOfFragments, i.e.,
-  // sliceOffset[0].
+  // Compute offset in vector registers. Note that the mma.sync vector registers
+  // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
+  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
   std::array<int64_t, 2> sliceOffset = {0, 0};
 
   if (offsets[0] && offsets[1])
@@ -904,10 +842,7 @@ static void convertConstantOp(arith::ConstantOp op,
 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
 static void convertBroadcastOp(vector::BroadcastOp op,
                                llvm::DenseMap<Value, Value> &valueMapping) {
-  // This op only catches the broadcasts that can directly
-  // convert to an MMA op.
-  if (!broadcastSupportsMMAMatrixType(op))
-    return;
+  assert(broadcastSupportsMMAMatrixType(op));
   OpBuilder b(op);
   const char *fragType = inferFragType(op);
   auto vecType = op.getVectorType();
@@ -918,39 +853,11 @@ static void convertBroadcastOp(vector::BroadcastOp op,
   valueMapping[op.getResult()] = matrix;
 }
 
-/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
-static void
-convertBroadcastFromTransferReadOp(vector::BroadcastOp broadcastOp,
-                                   llvm::DenseMap<Value, Value> &valueMapping) {
-  // This op catches the broadcasts that cannot directly convert to an MMA
-  // op.
-  if (broadcastSupportsMMAMatrixType(broadcastOp))
-    return;
-  if (!broadcastFromTransferReadSupportsMMAMatrixType(broadcastOp,
-                                                      /*useNvGpu=*/false))
-    return;
-  auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
-  assert(readOp && readOp.getVectorType().getRank() == 1);
-  // Handle broadcast by setting the stride to 0, unconditionally.
-  int64_t stride = 0;
-  const char *fragType = inferFragType(readOp);
-  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
-      broadcastOp.getVectorType().getShape(),
-      broadcastOp.getVectorType().getElementType(), fragType);
-  OpBuilder b(readOp);
-  bool isTranspose = false;
-  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
-      readOp.getLoc(), type, readOp.getSource(), readOp.getIndices(),
-      b.getIndexAttr(stride), isTranspose ? b.getUnitAttr() : UnitAttr());
-  valueMapping[broadcastOp.getResult()] = load;
-}
-
 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
 // updated and needs to be updated separatly for the loop to be correct.
 static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                                ValueRange newIterOperands) {
-  // Create a new loop before the existing one, with the extra
-  // operands.
+  // Create a new loop before the existing one, with the extra operands.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(loop);
   auto operands = llvm::to_vector<4>(loop.getIterOperands());
@@ -1005,8 +912,8 @@ static void convertYieldOp(scf::YieldOp op,
     auto it = valueMapping.find(operand.value());
     if (it == valueMapping.end())
       continue;
-    // Replace the yield of old value with the for op argument
-    // to make it easier to remove the dead code.
+    // Replace the yield of old value with the for op argument to make it easier
+    // to remove the dead code.
     yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
     yieldOperands.push_back(it->second);
   }
@@ -1052,7 +959,6 @@ void mlir::convertVectorToMMAOps(Operation *rootOp) {
       convertConstantOp(constantOp, valueMapping);
     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
       convertBroadcastOp(broadcastOp, valueMapping);
-      convertBroadcastFromTransferReadOp(broadcastOp, valueMapping);
     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
       convertForOp(forOp, valueMapping);
    } else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
@@ -1121,8 +1027,6 @@ struct ConvertVectorToGPUPass
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
 
-    getOperation()->dump();
-
     if (useNvGpu.getValue()) {
       if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
         return signalPassFailure();
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 9a0f4c9..fa2a40f 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -4,6 +4,7 @@
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
 #map5 = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: func @matmul
@@ -117,21 +118,6 @@ func.func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x1
 // CHECK: %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: %[[F:.+]] = gpu.subgroup_mma_elementwise divf %[[D]], %[[E]] : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
-// func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
-//                        %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
-//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
-//   %c0 = arith.constant 0 : index
-//   %cst = arith.constant 0.000000e+00 : f16
-//   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
-//   %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
-//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-//   %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
-//     {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
-//     : memref<16x16x16x16xf16>, vector<16x16xf16>
-//   %F = arith.divf %D, %E : vector<16x16xf16>
-//   vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
-//   return
-// }
 func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
                        %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
@@ -140,10 +126,9 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
   %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-  %Eread = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
-    {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3)->(d3)>}
-    : memref<16x16x16x16xf16>, vector<16xf16>
-  %E = vector.broadcast %Eread: vector<16xf16> to vector<16x16xf16>
+  %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+    {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
+    : memref<16x16x16x16xf16>, vector<16x16xf16>
   %F = arith.divf %D, %E : vector<16x16xf16>
   vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
   return
@@ -156,24 +141,12 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
 // CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
-// func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
-//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
-//   %c0 = arith.constant 0 : index
-//   %cst = arith.constant 0.000000e+00 : f16
-//   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-//   %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
-//   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-//   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
-//   return
-// }
 func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f16
   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-  %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
-  %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
@@ -187,24 +160,12 @@ func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %a
 // CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
-// func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
-//   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
-//   %c0 = arith.constant 0 : index
-//   %cst = arith.constant 0.000000e+00 : f16
-//   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
-//   %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
-//   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-//   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-//   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
-//   return
-// }
 func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f16
   %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
-  %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
-  %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
   %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
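For context, the reverted WIP change taught the conversion to recognize a 1-D vector.transfer_read whose result is then broadcast to the 2-D shape expected by the MMA lowering. A minimal sketch of that input pattern, adapted from the test lines removed above (values and shapes are illustrative only, not part of this patch):

  %Bread = vector.transfer_read %arg1[%c0], %cst
      {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]}
      : memref<16xf16>, vector<16xf16>
  %B = vector.broadcast %Bread : vector<16xf16> to vector<16x16xf16>

After the revert, the restored tests express the same operand as a single 2-D transfer_read with a broadcasting permutation map (#map4 = affine_map<(d0) -> (d0, 0)>), which is the form the existing lowering already handles.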