[mlir][VectorToGPU] Support converting vetor.broadcast to MMA op
authorthomasraoux <thomasraoux@google.com>
Wed, 30 Jun 2021 07:02:47 +0000 (00:02 -0700)
committerthomasraoux <thomasraoux@google.com>
Wed, 30 Jun 2021 16:08:55 +0000 (09:08 -0700)
Differential Revision: https://reviews.llvm.org/D105175

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

index 869301f..7298b93 100644 (file)
@@ -123,6 +123,12 @@ static bool constantSupportsMMAMatrixType(ConstantOp constantOp) {
   return constantOp.value().isa<SplatElementsAttr>();
 }
 
+/// Return true if this is a broadcast from scalar to a 2D vector.
+static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
+  return broadcastOp.getVectorType().getRank() == 2 &&
+         broadcastOp.source().getType().isa<FloatType>();
+}
+
 static bool supportsMMaMatrixType(Operation *op) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
@@ -134,6 +140,8 @@ static bool supportsMMaMatrixType(Operation *op) {
     return contractSupportsMMAMatrixType(contract);
   if (auto constant = dyn_cast<ConstantOp>(op))
     return constantSupportsMMAMatrixType(constant);
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+    return broadcastSupportsMMAMatrixType(broadcast);
   return false;
 }
 
@@ -141,8 +149,11 @@ static bool supportsMMaMatrixType(Operation *op) {
 // slice can be converted to MMA operations.
 static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
   auto hasVectorDest = [](Operation *op) {
-    return op->getNumResults() == 0 ||
-           llvm::any_of(op->getResultTypes(),
+    return llvm::any_of(op->getResultTypes(),
+                        [](Type t) { return t.isa<VectorType>(); });
+  };
+  auto hasVectorSrc = [](Operation *op) {
+    return llvm::any_of(op->getOperandTypes(),
                         [](Type t) { return t.isa<VectorType>(); });
   };
   SetVector<Operation *> opToConvert;
@@ -150,7 +161,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
     if (opToConvert.contains(contract.getOperation()))
       return;
     SetVector<Operation *> dependentOps =
-        getSlice(contract, hasVectorDest, hasVectorDest);
+        getSlice(contract, hasVectorDest, hasVectorSrc);
     // If any instruction cannot use MMA matrix type drop the whole
     // chaine. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
@@ -329,6 +340,20 @@ static void convertConstantOp(ConstantOp op,
   valueMapping[op.getResult()] = matrix;
 }
 
+/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
+static void convertBroadcastOp(vector::BroadcastOp op,
+                               llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(broadcastSupportsMMAMatrixType(op));
+  OpBuilder b(op);
+  const char *fragType = inferFragType(op);
+  auto vecType = op.getVectorType();
+  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
+  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
+                                                           op.source());
+  valueMapping[op.getResult()] = matrix;
+}
+
 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
 // updated and needs to be updated separatly for the loop to be correct.
 static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
@@ -416,6 +441,8 @@ void convertVectorToMMAOps(FuncOp funcOp) {
       convertContractOp(contractOp, valueMapping);
     } else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
       convertConstantOp(constantOp, valueMapping);
+    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
+      convertBroadcastOp(broadcastOp, valueMapping);
     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
       convertForOp(forOp, valueMapping);
     } else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
index a7fa579..db7087f 100644 (file)
@@ -41,6 +41,24 @@ func @matmul_cst(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memr
   return
 }
 
+// CHECK-LABEL: func @matmul_broadcast
+//  CHECK-SAME:   (%{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %[[F:.*]]: f16)
+//   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[F]] : !gpu.mma_matrix<16x16xf16, "COp">
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %f: f16) {
+  %C = vector.broadcast %f : f16 to vector<16x16xf16>
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
+
 // CHECK-LABEL: func @matmul_loop
 //       CHECK:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   %[[ACC:.+]] = scf.for {{.*}} iter_args(%[[ACC1:.+]] = %[[C]]) -> (!gpu.mma_matrix<16x16xf16, "COp">) {