From 8c80d01a95ba0d75c29191de0ea38cce48c9978f Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 19 Jun 2023 13:33:30 +0000 Subject: [PATCH] [mlir][NVGPU] NFC - Add a more convenient C++ builder for nvgpu::MmaSyncOp --- mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 10 +++++++--- mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 9 +++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td index 5bb02b0..e595e9d 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -158,8 +158,7 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> { AnyVector:$matrixB, AnyVector:$matrixC, I64ArrayAttr:$mmaShape, - OptionalAttr:$tf32Enabled - ); + OptionalAttr:$tf32Enabled); let results = (outs AnyVector:$res); @@ -167,7 +166,12 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> { OpBuilder<(ins "Value":$matrixA, "Value":$matrixB, "Value":$matrixC, - "ArrayAttr":$mmaShape)> + "ArrayAttr":$mmaShape)>, + OpBuilder<(ins "Value":$matrixA, + "Value":$matrixB, + "Value":$matrixC, + "ArrayRef":$mmaShape, + CArg<"bool", "false">:$tf32Enabled)> ]; let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index 77c853a..0472d27 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -96,6 +96,15 @@ void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder, mmaShape, UnitAttr()); } +void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, Value matrixA, + Value matrixB, Value matrixC, ArrayRef mmaShape, + bool tf32Enabled) { + build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC, + odsBuilder.getI64ArrayAttr(mmaShape), + tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr()); +} + /// Performs verification for MmaSyncOp and MmaSparseSyncOp. static LogicalResult verifyMmaSyncOp(Operation *op, TypedValue matrixA, -- 2.7.4