[mlir][nvpu] Prevent F32ToTF32 pattern to generate illegal IR
authorThomas Raoux <thomasraoux@google.com>
Mon, 15 Aug 2022 16:16:46 +0000 (16:16 +0000)
committerThomas Raoux <thomasraoux@google.com>
Mon, 15 Aug 2022 16:46:18 +0000 (16:46 +0000)
We shouldn't apply this pattern to non F32->F32 mma.sync operations.

Differential Revision: https://reviews.llvm.org/D131902

mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir

index 4ef93b3..d24001c 100644 (file)
@@ -42,7 +42,8 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
                                 PatternRewriter &rewrite) const override {
     Location location = op->getLoc();
 
-    if (op->hasAttr(op.getTf32EnabledAttrName()))
+    if (op->hasAttr(op.getTf32EnabledAttrName()) ||
+        !op.getMatrixA().getType().cast<VectorType>().getElementType().isF32())
       return failure();
 
     if (precision == MmaSyncF32Lowering::Unkown)
index a8c7226..80de11f 100644 (file)
@@ -18,3 +18,12 @@ func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: v
   return %d : vector<2x2xf32>
 }
 // -----
+
+// Negative test for non f32 case.
+// CHECK-LABEL: mma_sync_f16
+//   CHECK-NOT: tf32Enabled
+//       CHECK: return
+func.func @mma_sync_f16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}