PatternRewriter &rewrite) const override {
Location location = op->getLoc();
- if (op->hasAttr(op.getTf32EnabledAttrName()))
+ if (op->hasAttr(op.getTf32EnabledAttrName()) ||
+ !op.getMatrixA().getType().cast<VectorType>().getElementType().isF32())
return failure();
if (precision == MmaSyncF32Lowering::Unkown)
return %d : vector<2x2xf32>
}
// -----
+
+// Negative test: an nvgpu.mma.sync on non-f32 (here f16) vectors must not be given the tf32Enabled attribute.
+// CHECK-LABEL: mma_sync_f16
+// CHECK-NOT: tf32Enabled
+// CHECK: return
+func.func @mma_sync_f16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ return %d : vector<2x2xf16>
+}