return constantOp.value().isa<SplatElementsAttr>();
}
+/// Return true if this is a broadcast from scalar to a 2D vector.
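+/// For example, `vector.broadcast %f : f16 to vector<16x16xf16>` qualifies,
+/// while a broadcast from a vector source or to a 1-D vector does not.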
+static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
+ return broadcastOp.getVectorType().getRank() == 2 &&
+ broadcastOp.source().getType().isa<FloatType>();
+}
+
static bool supportsMMaMatrixType(Operation *op) {
if (isa<scf::ForOp, scf::YieldOp>(op))
return true;
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract);
if (auto constant = dyn_cast<ConstantOp>(op))
return constantSupportsMMAMatrixType(constant);
+ if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+ return broadcastSupportsMMAMatrixType(broadcast);
return false;
}
// Analyze the slice of operations based on the convert op to figure out if the
// whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
auto hasVectorDest = [](Operation *op) {
- return op->getNumResults() == 0 ||
- llvm::any_of(op->getResultTypes(),
+ return llvm::any_of(op->getResultTypes(),
+ [](Type t) { return t.isa<VectorType>(); });
+ };
+ auto hasVectorSrc = [](Operation *op) {
+ return llvm::any_of(op->getOperandTypes(),
[](Type t) { return t.isa<VectorType>(); });
};
SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
return;
SetVector<Operation *> dependentOps =
- getSlice(contract, hasVectorDest, hasVectorDest);
+ getSlice(contract, hasVectorDest, hasVectorSrc);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be used
    // by all operations.
valueMapping[op.getResult()] = matrix;
}
+/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
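+/// For example (shapes taken from the test case added below), the broadcast
+///   %b = vector.broadcast %f : f16 to vector<16x16xf16>
+/// becomes
+///   %b = gpu.subgroup_mma_constant_matrix %f : !gpu.mma_matrix<16x16xf16, "COp">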
+static void convertBroadcastOp(vector::BroadcastOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ assert(broadcastSupportsMMAMatrixType(op));
+ OpBuilder b(op);
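+  // Infer the MMA fragment kind ("AOp", "BOp" or "COp") from how the
+  // broadcast result is consumed by the vector.contract using it.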
+ const char *fragType = inferFragType(op);
+ auto vecType = op.getVectorType();
+ gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+ vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
+ auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
+ op.source());
+ valueMapping[op.getResult()] = matrix;
+}
+
// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
convertContractOp(contractOp, valueMapping);
} else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
convertConstantOp(constantOp, valueMapping);
+ } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
+ convertBroadcastOp(broadcastOp, valueMapping);
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
convertForOp(forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      convertYieldOp(yieldOp, valueMapping);
}
+// CHECK-LABEL: func @matmul_broadcast
+// CHECK-SAME: (%{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %[[F:.*]]: f16)
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[F]] : !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %f: f16) {
+ %C = vector.broadcast %f : f16 to vector<16x16xf16>
+ %c0 = constant 0 : index
+ %cst = constant 0.000000e+00 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+ vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+ return
+}
+
// CHECK-LABEL: func @matmul_loop
// CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[ACC:.+]] = scf.for {{.*}} iter_args(%[[ACC1:.+]] = %[[C]]) -> (!gpu.mma_matrix<16x16xf16, "COp">) {