[mlir][VectorToGPU] Support more cases in conversion to MMA ops
authorThomas Raoux <thomasraoux@google.com>
Wed, 10 Nov 2021 22:32:15 +0000 (14:32 -0800)
committerthomasraoux <thomasraoux@google.com>
Thu, 11 Nov 2021 21:10:38 +0000 (13:10 -0800)
Support load with broadcast, elementwise divf op and remove the
hardcoded restriction on the vector size. Picking the right size should
be enfored by user and will fail conversion to llvm/spirv if it is not
supported.

Differential Revision: https://reviews.llvm.org/D113618

mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

index 5e4d122c69eaa41f54afb804656177968924a868..adefba30d9a8e2cd29af24473ba8f7eaef2a76b7 100644 (file)
@@ -1130,11 +1130,12 @@ def GPU_ELEMENTWISE_OP_ADD : StrEnumAttrCase<"ADDF">;
 def GPU_ELEMENTWISE_OP_MUL : StrEnumAttrCase<"MULF">;
 def GPU_ELEMENTWISE_OP_MAXF : StrEnumAttrCase<"MAXF">;
 def GPU_ELEMENTWISE_OP_MINF : StrEnumAttrCase<"MINF">;
+def GPU_ELEMENTWISE_OP_DIVF : StrEnumAttrCase<"DIVF">;
 
 def MMAElementWiseAttr : StrEnumAttr<"MMAElementwiseOp",
   "elementwise operation to apply to mma matrix",
   [GPU_ELEMENTWISE_OP_ADD, GPU_ELEMENTWISE_OP_MUL,
-   GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF]> {
+   GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF, GPU_ELEMENTWISE_OP_DIVF]> {
   let cppNamespace = "::mlir::gpu";
   let storageType = "::mlir::StringAttr";
   let returnType = "::mlir::gpu::MMAElementwiseOp";
index 6de739088b896b13b5e8379b2cad89fdc975f8fb..55935e7f971f756fe5876060f54ab389b4a836ec 100644 (file)
@@ -304,6 +304,8 @@ static Value createScalarOp(OpBuilder &builder, Location loc,
     return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
   case gpu::MMAElementwiseOp::MULF:
     return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
+  case gpu::MMAElementwiseOp::DIVF:
+    return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
   case gpu::MMAElementwiseOp::MAXF:
     return createMinMaxF(builder, loc, operands[0], operands[1],
                          /*isMin=*/false);
index a9f3c7da9c842a8528a2aa2ecc1f9d88dde7b0a5..18f472634480dc80c167cc2900cf994ba0acd8eb 100644 (file)
@@ -50,26 +50,7 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
   if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
     return false;
 
-  // Check that the size matches what is natively supported.
-  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
-  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
-  VectorType accType = contract.acc().getType().cast<VectorType>();
-
-  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
-                                lhsType.getDimSize(1));
-  if (lhsType.getElementType().isInteger(8) &&
-      rhsType.getElementType().isInteger(8) &&
-      accType.getElementType().isInteger(32) &&
-      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
-       dim == std::make_tuple(16, 8, 32)))
-    return true;
-
-  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
-      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
-      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
-       dim == std::make_tuple(16, 8, 16)))
-    return true;
-  return false;
+  return true;
 }
 
 // Return the stide for the dimension 0 of |type| if it is a memref and has a
@@ -95,8 +76,15 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
     return false;
   if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
     return false;
+  AffineMap map = readOp.permutation_map();
+  OpBuilder b(readOp.getContext());
+  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
+  AffineExpr zero = b.getAffineConstantExpr(0);
+  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
+                                          readOp.getContext());
   // TODO: Support transpose once it is added to GPU dialect ops.
-  if (!readOp.permutation_map().isMinorIdentity())
+  // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
+  if (!map.isMinorIdentity() && map != broadcastInnerDim)
     return false;
   return true;
 }
@@ -142,6 +130,8 @@ convertElementwiseOpToMMA(Operation *op) {
     return gpu::MMAElementwiseOp::MAXF;
   if (isa<MinFOp>(op))
     return gpu::MMAElementwiseOp::MINF;
+  if (isa<arith::DivFOp>(op))
+    return gpu::MMAElementwiseOp::DIVF;
   return llvm::None;
 }
 
@@ -166,6 +156,44 @@ static bool supportsMMaMatrixType(Operation *op) {
   return elementwiseSupportsMMAMatrixType(op);
 }
 
+/// Return an unsorted slice handling scf.for region differently than
+/// `getSlice`. In scf.for we only want to include as part of the slice elements
+/// that are part of the use/def chain.
+static SetVector<Operation *> getSliceContract(Operation *op,
+                                               TransitiveFilter backwardFilter,
+                                               TransitiveFilter forwardFilter) {
+  SetVector<Operation *> slice;
+  slice.insert(op);
+  unsigned currentIndex = 0;
+  SetVector<Operation *> backwardSlice;
+  SetVector<Operation *> forwardSlice;
+  while (currentIndex != slice.size()) {
+    auto *currentOp = (slice)[currentIndex];
+    // Compute and insert the backwardSlice starting from currentOp.
+    backwardSlice.clear();
+    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
+    slice.insert(backwardSlice.begin(), backwardSlice.end());
+
+    // Compute and insert the forwardSlice starting from currentOp.
+    forwardSlice.clear();
+    // Special case for ForOp, we don't want to include the whole region but
+    // only the value using the region arguments.
+    // TODO: We should refine this to only care about the region arguments being
+    // converted to matrix type.
+    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
+      for (Value forOpResult : forOp.getResults())
+        getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
+      for (BlockArgument &arg : forOp.getRegionIterArgs())
+        getForwardSlice(arg, &forwardSlice, forwardFilter);
+    } else {
+      getForwardSlice(currentOp, &forwardSlice, forwardFilter);
+    }
+    slice.insert(forwardSlice.begin(), forwardSlice.end());
+    ++currentIndex;
+  }
+  return slice;
+}
+
 // Analyze slice of operations based on convert op to figure out if the whole
 // slice can be converted to MMA operations.
 static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
@@ -182,16 +210,17 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
     if (opToConvert.contains(contract.getOperation()))
       return;
     SetVector<Operation *> dependentOps =
-        getSlice(contract, hasVectorDest, hasVectorSrc);
+        getSliceContract(contract, hasVectorDest, hasVectorSrc);
     // If any instruction cannot use MMA matrix type drop the whole
-    // chaine. MMA matrix are stored in an opaque type so they cannot be used
+    // chain. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
     if (llvm::any_of(dependentOps,
                      [](Operation *op) { return !supportsMMaMatrixType(op); }))
       return;
     opToConvert.insert(dependentOps.begin(), dependentOps.end());
   });
-  return opToConvert;
+  // Sort the operations so that we can convert them in topological order.
+  return topologicalSort(opToConvert);
 }
 
 namespace {
@@ -309,6 +338,12 @@ static void convertTransferReadOp(vector::TransferReadOp op,
   assert(transferReadSupportsMMAMatrixType(op));
   Optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
+  AffineMap map = op.permutation_map();
+  // Handle broadcast by setting the stride to 0.
+  if (map.getResult(0).isa<AffineConstantExpr>()) {
+    assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
+    stride = 0;
+  }
   assert(stride);
   const char *fragType = inferFragType(op);
   gpu::MMAMatrixType type =
index 2ca899fa5bac40588ec2353c397a53ac0419c10d..5e8f40f52b86ba9b10d71bd66f14c1964ada54fb 100644 (file)
@@ -106,3 +106,28 @@ func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16
   vector.transfer_write %E, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
   return
 }
+
+// CHECK-LABEL: func @matmul_fused_broadcast
+//   CHECK-DAG:   %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//   CHECK-DAG:   %[[C0:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_0]] : !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C0]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[F:.+]] = gpu.subgroup_mma_elementwise %[[D]], %[[E]] {operation = "DIVF"} : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
+  %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+    {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
+    : memref<16x16x16x16xf16>, vector<16x16xf16>
+  %F = arith.divf %D, %E : vector<16x16xf16>
+  vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}