if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
return false;
- // Check that the size matches what is natively supported.
- VectorType lhsType = contract.lhs().getType().cast<VectorType>();
- VectorType rhsType = contract.rhs().getType().cast<VectorType>();
- VectorType accType = contract.acc().getType().cast<VectorType>();
-
- std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
- lhsType.getDimSize(1));
- if (lhsType.getElementType().isInteger(8) &&
- rhsType.getElementType().isInteger(8) &&
- accType.getElementType().isInteger(32) &&
- (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
- dim == std::make_tuple(16, 8, 32)))
- return true;
-
- if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
- (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
- (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
- dim == std::make_tuple(16, 8, 16)))
- return true;
- return false;
+ return true;
}
// Return the stride for the dimension 0 of |type| if it is a memref and has a
return false;
if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
return false;
+ AffineMap map = readOp.permutation_map();
+ OpBuilder b(readOp.getContext());
+ AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
+ AffineExpr zero = b.getAffineConstantExpr(0);
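+ // Map a broadcast of the innermost dimension would use, e.g. for a 2-D read
+ // this is (d0, d1) -> (0, d1).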
+ auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
+ readOp.getContext());
// TODO: Support transpose once it is added to GPU dialect ops.
- if (!readOp.permutation_map().isMinorIdentity())
+ // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
+ if (!map.isMinorIdentity() && map != broadcastInnerDim)
return false;
return true;
}
return gpu::MMAElementwiseOp::MAXF;
if (isa<MinFOp>(op))
return gpu::MMAElementwiseOp::MINF;
+ if (isa<arith::DivFOp>(op))
+ return gpu::MMAElementwiseOp::DIVF;
return llvm::None;
}
return elementwiseSupportsMMAMatrixType(op);
}
+/// Return an unsorted slice, handling scf.for regions differently than
+/// `getSlice`: in an scf.for we only want to include in the slice the elements
+/// that are part of the use/def chain.
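+/// The returned slice is unsorted; getOpToConvert below topologically sorts it
+/// before the ops are converted.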
+static SetVector<Operation *> getSliceContract(Operation *op,
+ TransitiveFilter backwardFilter,
+ TransitiveFilter forwardFilter) {
+ SetVector<Operation *> slice;
+ slice.insert(op);
+ unsigned currentIndex = 0;
+ SetVector<Operation *> backwardSlice;
+ SetVector<Operation *> forwardSlice;
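+ // Grow the slice to a fixed point: every op pulled into the slice is itself
+ // expanded with its backward and forward slices.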
+ while (currentIndex != slice.size()) {
+ auto *currentOp = slice[currentIndex];
+ // Compute and insert the backwardSlice starting from currentOp.
+ backwardSlice.clear();
+ getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
+ slice.insert(backwardSlice.begin(), backwardSlice.end());
+
+ // Compute and insert the forwardSlice starting from currentOp.
+ forwardSlice.clear();
+ // Special case for ForOp: we don't want to include the whole region, only
+ // the values that use the region arguments.
+ // TODO: We should refine this to only care about the region arguments being
+ // converted to matrix type.
+ if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
+ for (Value forOpResult : forOp.getResults())
+ getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
+ for (BlockArgument &arg : forOp.getRegionIterArgs())
+ getForwardSlice(arg, &forwardSlice, forwardFilter);
+ } else {
+ getForwardSlice(currentOp, &forwardSlice, forwardFilter);
+ }
+ slice.insert(forwardSlice.begin(), forwardSlice.end());
+ ++currentIndex;
+ }
+ return slice;
+}
+
// Analyze the slice of operations around the contract op to figure out if the
// whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
if (opToConvert.contains(contract.getOperation()))
return;
SetVector<Operation *> dependentOps =
- getSlice(contract, hasVectorDest, hasVectorSrc);
+ getSliceContract(contract, hasVectorDest, hasVectorSrc);
// If any instruction cannot use the MMA matrix type, drop the whole
- // chaine. MMA matrix are stored in an opaque type so they cannot be used
+ // chain. MMA matrices are stored in an opaque type so they cannot be used
// by all operations.
if (llvm::any_of(dependentOps,
[](Operation *op) { return !supportsMMaMatrixType(op); }))
return;
opToConvert.insert(dependentOps.begin(), dependentOps.end());
});
- return opToConvert;
+ // Sort the operations so that we can convert them in topological order.
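+ // Topological order ensures an op's operands are converted before the op
+ // itself.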
+ return topologicalSort(opToConvert);
}
namespace {
assert(transferReadSupportsMMAMatrixType(op));
Optional<int64_t> stride =
getMemrefConstantHorizontalStride(op.getShapedType());
+ AffineMap map = op.permutation_map();
+ // Handle broadcast by setting the stride to 0.
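+ // The stride becomes the load's leadDimension (see the lowering test below),
+ // so every row of the fragment reads the same data, implementing the
+ // broadcast.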
+ if (map.getResult(0).isa<AffineConstantExpr>()) {
+ assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
+ stride = 0;
+ }
assert(stride);
const char *fragType = inferFragType(op);
gpu::MMAMatrixType type =
vector.transfer_write %E, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
return
}
+
+// CHECK-LABEL: func @matmul_fused_broadcast
+// CHECK-DAG: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C0:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_0]] : !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C0]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[F:.+]] = gpu.subgroup_mma_elementwise %[[D]], %[[E]] {operation = "DIVF"} : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
+ %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+ %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+ {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
+ : memref<16x16x16x16xf16>, vector<16x16xf16>
+ %F = arith.divf %D, %E : vector<16x16xf16>
+ vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+ return
+}