[mlir][VectorToGPU] Support more cases in conversion to MMA ops
authorThomas Raoux <thomasraoux@google.com>
Wed, 10 Nov 2021 22:32:15 +0000 (14:32 -0800)
committerthomasraoux <thomasraoux@google.com>
Thu, 11 Nov 2021 21:10:38 +0000 (13:10 -0800)
Support load with broadcast, elementwise divf op and remove the
hardcoded restriction on the vector size. Picking the right size should
be enfored by user and will fail conversion to llvm/spirv if it is not
supported.

Differential Revision: https://reviews.llvm.org/D113618

mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

index 5e4d122..adefba3 100644 (file)
@@ -1130,11 +1130,12 @@ def GPU_ELEMENTWISE_OP_ADD : StrEnumAttrCase<"ADDF">;
 def GPU_ELEMENTWISE_OP_MUL : StrEnumAttrCase<"MULF">;
 def GPU_ELEMENTWISE_OP_MAXF : StrEnumAttrCase<"MAXF">;
 def GPU_ELEMENTWISE_OP_MINF : StrEnumAttrCase<"MINF">;
+def GPU_ELEMENTWISE_OP_DIVF : StrEnumAttrCase<"DIVF">;
 
 def MMAElementWiseAttr : StrEnumAttr<"MMAElementwiseOp",
   "elementwise operation to apply to mma matrix",
   [GPU_ELEMENTWISE_OP_ADD, GPU_ELEMENTWISE_OP_MUL,
-   GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF]> {
+   GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF, GPU_ELEMENTWISE_OP_DIVF]> {
   let cppNamespace = "::mlir::gpu";
   let storageType = "::mlir::StringAttr";
   let returnType = "::mlir::gpu::MMAElementwiseOp";
index 6de7390..55935e7 100644 (file)
@@ -304,6 +304,8 @@ static Value createScalarOp(OpBuilder &builder, Location loc,
     return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
   case gpu::MMAElementwiseOp::MULF:
     return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
+  case gpu::MMAElementwiseOp::DIVF:
+    return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
   case gpu::MMAElementwiseOp::MAXF:
     return createMinMaxF(builder, loc, operands[0], operands[1],
                          /*isMin=*/false);
index a9f3c7d..18f4726 100644 (file)
@@ -50,26 +50,7 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
   if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
     return false;
 
-  // Check that the size matches what is natively supported.
-  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
-  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
-  VectorType accType = contract.acc().getType().cast<VectorType>();
-
-  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
-                                lhsType.getDimSize(1));
-  if (lhsType.getElementType().isInteger(8) &&
-      rhsType.getElementType().isInteger(8) &&
-      accType.getElementType().isInteger(32) &&
-      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
-       dim == std::make_tuple(16, 8, 32)))
-    return true;
-
-  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
-      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
-      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
-       dim == std::make_tuple(16, 8, 16)))
-    return true;
-  return false;
+  return true;
 }
 
 // Return the stide for the dimension 0 of |type| if it is a memref and has a
@@ -95,8 +76,15 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
     return false;
   if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
     return false;
+  AffineMap map = readOp.permutation_map();
+  OpBuilder b(readOp.getContext());
+  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
+  AffineExpr zero = b.getAffineConstantExpr(0);
+  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
+                                          readOp.getContext());
   // TODO: Support transpose once it is added to GPU dialect ops.
-  if (!readOp.permutation_map().isMinorIdentity())
+  // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
+  if (!map.isMinorIdentity() && map != broadcastInnerDim)
     return false;
   return true;
 }
@@ -142,6 +130,8 @@ convertElementwiseOpToMMA(Operation *op) {
     return gpu::MMAElementwiseOp::MAXF;
   if (isa<MinFOp>(op))
     return gpu::MMAElementwiseOp::MINF;
+  if (isa<arith::DivFOp>(op))
+    return gpu::MMAElementwiseOp::DIVF;
   return llvm::None;
 }
 
@@ -166,6 +156,44 @@ static bool supportsMMaMatrixType(Operation *op) {
   return elementwiseSupportsMMAMatrixType(op);
 }
 
+/// Return an unsorted slice handling scf.for region differently than
+/// `getSlice`. In scf.for we only want to include as part of the slice elements
+/// that are part of the use/def chain.
+static SetVector<Operation *> getSliceContract(Operation *op,
+                                               TransitiveFilter backwardFilter,
+                                               TransitiveFilter forwardFilter) {
+  SetVector<Operation *> slice;
+  slice.insert(op);
+  unsigned currentIndex = 0;
+  SetVector<Operation *> backwardSlice;
+  SetVector<Operation *> forwardSlice;
+  while (currentIndex != slice.size()) {
+    auto *currentOp = (slice)[currentIndex];
+    // Compute and insert the backwardSlice starting from currentOp.
+    backwardSlice.clear();
+    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
+    slice.insert(backwardSlice.begin(), backwardSlice.end());
+
+    // Compute and insert the forwardSlice starting from currentOp.
+    forwardSlice.clear();
+    // Special case for ForOp, we don't want to include the whole region but
+    // only the value using the region arguments.
+    // TODO: We should refine this to only care about the region arguments being
+    // converted to matrix type.
+    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
+      for (Value forOpResult : forOp.getResults())
+        getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
+      for (BlockArgument &arg : forOp.getRegionIterArgs())
+        getForwardSlice(arg, &forwardSlice, forwardFilter);
+    } else {
+      getForwardSlice(currentOp, &forwardSlice, forwardFilter);
+    }
+    slice.insert(forwardSlice.begin(), forwardSlice.end());
+    ++currentIndex;
+  }
+  return slice;
+}
+
 // Analyze slice of operations based on convert op to figure out if the whole
 // slice can be converted to MMA operations.
 static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
@@ -182,16 +210,17 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
     if (opToConvert.contains(contract.getOperation()))
       return;
     SetVector<Operation *> dependentOps =
-        getSlice(contract, hasVectorDest, hasVectorSrc);
+        getSliceContract(contract, hasVectorDest, hasVectorSrc);
     // If any instruction cannot use MMA matrix type drop the whole
-    // chaine. MMA matrix are stored in an opaque type so they cannot be used
+    // chain. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
     if (llvm::any_of(dependentOps,
                      [](Operation *op) { return !supportsMMaMatrixType(op); }))
       return;
     opToConvert.insert(dependentOps.begin(), dependentOps.end());
   });
-  return opToConvert;
+  // Sort the operations so that we can convert them in topological order.
+  return topologicalSort(opToConvert);
 }
 
 namespace {
@@ -309,6 +338,12 @@ static void convertTransferReadOp(vector::TransferReadOp op,
   assert(transferReadSupportsMMAMatrixType(op));
   Optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
+  AffineMap map = op.permutation_map();
+  // Handle broadcast by setting the stride to 0.
+  if (map.getResult(0).isa<AffineConstantExpr>()) {
+    assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
+    stride = 0;
+  }
   assert(stride);
   const char *fragType = inferFragType(op);
   gpu::MMAMatrixType type =
index 2ca899f..5e8f40f 100644 (file)
@@ -106,3 +106,28 @@ func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16
   vector.transfer_write %E, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
   return
 }
+
+// CHECK-LABEL: func @matmul_fused_broadcast
+//   CHECK-DAG:   %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//   CHECK-DAG:   %[[C0:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_0]] : !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C0]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[F:.+]] = gpu.subgroup_mma_elementwise %[[D]], %[[E]] {operation = "DIVF"} : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
+  %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+    {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
+    : memref<16x16x16x16xf16>, vector<16x16xf16>
+  %F = arith.divf %D, %E : vector<16x16xf16>
+  vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}