Fix Transpose Check in MMA.SYNC Path
authorManish Gupta <manigupta@google.com>
Thu, 6 Apr 2023 23:58:17 +0000 (23:58 +0000)
committerManish Gupta <manigupta@google.com>
Tue, 11 Apr 2023 00:02:20 +0000 (00:02 +0000)
Differential Revision: https://reviews.llvm.org/D147749

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

index 19860a1..10a6ee4 100644 (file)
@@ -650,6 +650,34 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
   return success();
 }
 
+/// Check if the loaded matrix operand requires transposed.
+/// Transposed Map Example:
+/// Example 1   : (..., d0, d1) -> (d1 * 1, d0 * 2)
+/// Example 2   : (d0, d1, d2, d3) -> (d3, d2)
+///
+/// The code below checks if the output 2D is transposed using a generalized
+/// version     : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
+/// Returns     : true; if m > n, false o.w.
+
+static bool isTransposed(vector::TransferReadOp op) {
+  mlir::AffineMap map = op.getPermutationMap();
+  if (map.getNumResults() != 2) {
+    op->emitError("Expected 2D transfer read");
+  }
+
+  // Output 2D matrix dimensions in the order of d0, d1.
+  auto dM = map.getResult(0);
+  auto dN = map.getResult(1);
+
+  //  Find the position of these expressions in the input.
+  auto exprM = dM.dyn_cast<AffineDimExpr>();
+  auto exprN = dN.dyn_cast<AffineDimExpr>();
+  if (!exprM || !exprN) {
+    op->emitError("Expected to find AffineDimExpr in vector::TransferReadOp");
+  }
+  return exprM.getPosition() > exprN.getPosition();
+}
+
 static LogicalResult
 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
@@ -671,9 +699,10 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
   }
 
-  FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
-      *warpMatrixInfo,
-      /*transpose=*/!op.getPermutationMap().isMinorIdentity());
+  FailureOr<nvgpu::LdMatrixParams> params =
+      nvgpu::getLdMatrixParams(*warpMatrixInfo,
+                               /*transpose=*/isTransposed(op));
+
   if (failed(params)) {
     LLVM_DEBUG(
         DBGS()
@@ -700,7 +729,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                                          indices);
   nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
       loc, vectorType, op.getSource(), indices,
-      !op.getPermutationMap().isMinorIdentity(), params->numTiles);
+      /*transpose=*/isTransposed(op), params->numTiles);
   valueMapping[op] = newOp->getResult(0);
   return success();
 }