[mlir][NVGPU] nvgpu.mmasync on F32 through TF32
authorManish Gupta <manigupta@google.com>
Mon, 1 Aug 2022 23:06:23 +0000 (23:06 +0000)
committerThomas Raoux <thomasraoux@google.com>
Mon, 1 Aug 2022 23:23:27 +0000 (23:23 +0000)
Adds optional attribute to support tensor cores on F32 datatype by lowering to `mma.sync` with TF32 operands. Since, TF32 is not a native datatype in LLVM we are adding `tf32Enabled` as an attribute to allow the IR to be aware of `MmaSyncOp` datatype. Additionally, this patch adds placeholders for nvgpu-to-nvgpu transformation targeting higher precision tf32x3.

For mma.sync on f32 input using tensor cores there are two possibilites:
(a) tf32   (1 `mma.sync` per warp-level matrix-multiply-accumulate)
(b) tf32x3 (3 `mma.sync` per warp-level matrix-multiply-accumulate)

Typically, tf32 tensor core acceleration comes at a cost of accuracy from missing precision bits. While f32 has 23 precision bits, tf32 has only 10 precision bits. tf32x3 aims to recover the precision bits by splitting each operand into two tf32 values and issue three `mma.sync` tensor core operations.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D130294

16 files changed:
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt
mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp [new file with mode: 0644]
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
mlir/test/Dialect/NVGPU/invalid.mlir
mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir [new file with mode: 0644]
mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/CMakeLists.txt
mlir/test/lib/Dialect/NVGPU/CMakeLists.txt [new file with mode: 0644]
mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp

index 737752b..d0dd5a6 100644 (file)
@@ -110,11 +110,22 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
     (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   ```
   }];
-  let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB,
-                       AnyVector:$matrixC, I64ArrayAttr:$mmaShape);
+  let arguments = (ins AnyVector:$matrixA, 
+                       AnyVector:$matrixB,
+                       AnyVector:$matrixC, 
+                       I64ArrayAttr:$mmaShape,
+                       OptionalAttr<UnitAttr>:$tf32Enabled
+                       );
 
   let results = (outs AnyVector:$res);
 
+  let builders = [
+    OpBuilder<(ins "Value":$matrixA, 
+                   "Value":$matrixB, 
+                   "Value":$matrixC, 
+                   "ArrayAttr":$mmaShape)>
+  ];
+
   let assemblyFormat = [{
     `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
index abc910d..32888f1 100644 (file)
 namespace mlir {
 namespace nvgpu {
 
+///
+/// Passes
+///
+
 /// Optimizes vectorized accesses to a shared memory buffer specified by
 /// memrefValue. This transformation assumes the following:
 /// 1) All relevant accesses to `memrefValue` are contained with `parentOp`.
@@ -41,6 +45,29 @@ namespace nvgpu {
 mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                        Value memrefValue);
 
+///
+/// Rewrites patterns
+///
+
+//===----------------------------------------------------------------------===//
+// NVGPU transformation options exposed as auxiliary structs.
+//===----------------------------------------------------------------------===//
+/// Enum to control the lowering of `nvgpu.mmasync`.
+enum class MmaSyncF32Lowering { TF32 = 0, TF32x3 = 1, Unkown = 2 };
+
+/// Collect patterns to convert mma.sync on f32 input and rewrite
+/// to use tensor cores with user provided level of accuracy:
+/// (a) tf32   (1 mma.sync per warp-level matrix-multiply-accumulate)
+/// (b) tf32x3 (3 mma.sync per warp-level matrix-multiply-accumulate)
+/// Typically, tf32 tensor core acceleration comes at a cost
+/// of accuracy from missing precision bits. While f32 has 23 precision
+/// bits, tf32 has only 10 precision bits. tf32x3 aims to recover the
+/// precision bits by spliting each operand into two tf32 values
+/// and issue three mma.sync tensor core operations.
+void populateMmaSyncF32ToTF32Patterns(
+    RewritePatternSet &patterns,
+    nvgpu::MmaSyncF32Lowering precision = nvgpu::MmaSyncF32Lowering::TF32);
+
 } // namespace nvgpu
 } // namespace mlir
 
index 682a0d4..41f0877 100644 (file)
@@ -275,10 +275,14 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
     NVVM::MMATypes ptxTypeB;
     Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
         cType.getElementType(), /*isAccumulator=*/true);
-    if (!ptxTypeC) {
+    if (!ptxTypeC)
       return op->emitError(
           "could not infer the PTX type for the accumulator/result");
-    }
+
+    // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
+    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
+    if (aType.getElementType().isF32() && !tf32Enabled)
+      return failure();
 
     Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
     if (aType.getElementType().isInteger(8)) {
index 6691193..2675882 100644 (file)
@@ -687,8 +687,8 @@ convertContractOpToMmaSync(vector::ContractionOp op,
   int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
   int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
   int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
-  Value matmul = b.create<nvgpu::MmaSyncOp>(
-      op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
+  Value matmul = b.create<nvgpu::MmaSyncOp>(op.getLoc(), opA, opB, opC,
+                                            b.getI64ArrayAttr({m, n, k}));
   valueMapping[op.getResult()] = matmul;
   return success();
 }
index 1ced011..8580a84 100644 (file)
@@ -91,6 +91,12 @@ LogicalResult DeviceAsyncCopyOp::verify() {
 //===----------------------------------------------------------------------===//
 // NVGPU_MmaSyncOp
 //===----------------------------------------------------------------------===//
+void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
+                      ::mlir::OperationState &odsState, Value matrixA,
+                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
+  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
+        mmaShape, UnitAttr());
+}
 
 LogicalResult MmaSyncOp::verify() {
 
@@ -122,6 +128,9 @@ LogicalResult MmaSyncOp::verify() {
   // vector element type
   Type aType = aVector.getElementType();
 
+  // tensor float32 (TF32) enabled
+  bool tf32Enabled = getOperation()->hasAttr(getTf32EnabledAttrName());
+
   // nvgpu.mma.sync shape (per 32 threads or per warp)
   int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
   int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
@@ -163,6 +172,10 @@ LogicalResult MmaSyncOp::verify() {
     return emitOpError() << "expected " << m * n
                          << " warp-wide matrix C elements";
 
+  // verify tf32 tensor cores are enabled for only F32 datatype
+  if (tf32Enabled && !(aType.isF32()))
+    return emitOpError() << "expected tf32 tensor cores only for F32 operands";
+
   //
   // Extended verification
   //
index 831f396..afe7d16 100644 (file)
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRNVGPUTransforms
-  OptimizeSharedMemory.cpp  
+  OptimizeSharedMemory.cpp
+  MmaSyncTF32Transform.cpp  
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVGPU
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
new file mode 100644 (file)
index 0000000..4ef93b3
--- /dev/null
@@ -0,0 +1,73 @@
+//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to enable 1xtf32 and 3xtf32 nvgpu.mma sync
+// operations on f32 input datatype
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/NVGPU/Passes.h"
+#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::nvgpu;
+
+namespace {
+
+struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
+
+  using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;
+
+  MmaSyncF32ToTF32Pattern(MLIRContext *context,
+                          nvgpu::MmaSyncF32Lowering precision)
+      : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benifit*/ 1),
+        precision(precision) {}
+
+  LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
+                                PatternRewriter &rewrite) const override {
+    Location location = op->getLoc();
+
+    if (op->hasAttr(op.getTf32EnabledAttrName()))
+      return failure();
+
+    if (precision == MmaSyncF32Lowering::Unkown)
+      return emitError(location, "MmaSync F32-to-TF32 cannot be lowered with "
+                                 "unknown precision level");
+
+    if (precision == MmaSyncF32Lowering::TF32x3)
+      return emitError(location, "TF32x3 is not supported at the moment "
+                                 "for nvgpu.mma.sync on f32 datatype");
+
+    if (precision == MmaSyncF32Lowering::TF32)
+      op.setTf32EnabledAttr(rewrite.getUnitAttr());
+
+    return success();
+  }
+
+private:
+  /// Precision for F32 Tensor Cores (TF32 or TF32x3)
+  nvgpu::MmaSyncF32Lowering precision;
+};
+
+} // namespace
+
+void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
+    RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
+
+  patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
+}
index 55b8df6..aa71a26 100644 (file)
@@ -219,7 +219,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
   // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>  
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>  
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>  
   // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
index 5f1894f..7fc8410 100644 (file)
@@ -76,6 +76,13 @@ func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1:
 }
 // -----
 
+func.func @m16n8k16_fp16_tf32Enabled(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected tf32 tensor cores only for F32 operands}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16], tf32Enabled} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+  return %d : vector<2x2xf16>
+}
+// -----
+
 func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
   // expected-error @+1 {{expected 128 warp-wide matrix A elements}}
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>    
diff --git a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
new file mode 100644 (file)
index 0000000..a8c7226
--- /dev/null
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -test-nvgpu-mmasync-f32-to-tf32-patterns="precision=tf32" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: m16n8k4_tf32
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {  
+  // CHECK: nvgpu.mma.sync
+  // CHECK-SAME: tf32Enabled
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>  
+  return %d : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: m16n8k8_tf32
+func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // CHECK: nvgpu.mma.sync
+  // CHECK-SAME: tf32Enabled
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>    
+  return %d : vector<2x2xf32>
+}
+// -----
diff --git a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
new file mode 100644 (file)
index 0000000..523ba24
--- /dev/null
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -test-nvgpu-mmasync-f32-to-tf32-patterns="precision=tf32x3" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: m16n8k4_tf32
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {  
+  // expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>  
+  return %d : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: m16n8k8_tf32
+func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>    
+  return %d : vector<2x2xf32>
+}
+// -----
index 7c8d1a7..6bc8635 100644 (file)
@@ -5,6 +5,7 @@ add_subdirectory(GPU)
 add_subdirectory(Linalg)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
+add_subdirectory(NVGPU)
 add_subdirectory(SCF)
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)
diff --git a/mlir/test/lib/Dialect/NVGPU/CMakeLists.txt b/mlir/test/lib/Dialect/NVGPU/CMakeLists.txt
new file mode 100644 (file)
index 0000000..2fe031d
--- /dev/null
@@ -0,0 +1,21 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRNVGPUTestPasses
+  TestNVGPUTransforms.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRAffineDialect
+  MLIRAnalysis
+  MLIRFuncDialect
+  MLIRGPUOps
+  MLIRLLVMDialect
+  MLIRMemRefDialect
+  MLIRNVGPUDialect
+  MLIRNVGPUTransforms
+  MLIRPass
+  MLIRSCFDialect
+  MLIRTransformUtils
+  )
+  
diff --git a/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp b/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp
new file mode 100644 (file)
index 0000000..74a15ba
--- /dev/null
@@ -0,0 +1,76 @@
+//===- TestNVGPUTransforms.cpp - Test NVGPU transforms and lowerings ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::nvgpu;
+
+namespace {
+
+struct TestMmaSyncF32ToTF32Patterns
+    : public PassWrapper<TestMmaSyncF32ToTF32Patterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMmaSyncF32ToTF32Patterns)
+
+  StringRef getArgument() const final {
+    return "test-nvgpu-mmasync-f32-to-tf32-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to convert mma.sync on f32 with tf32 precision";
+  }
+  TestMmaSyncF32ToTF32Patterns() = default;
+  TestMmaSyncF32ToTF32Patterns(const TestMmaSyncF32ToTF32Patterns &pass)
+      : PassWrapper(pass) {}
+
+  Option<std::string> precision{
+      *this, "precision",
+      llvm::cl::desc(
+          "Target nvgpu.mma.sync on f32 input with tf32 or tf32x3 precision"),
+      llvm::cl::init("tf32")};
+
+  MmaSyncF32Lowering tf32Precision =
+      llvm::StringSwitch<MmaSyncF32Lowering>(precision)
+          .Case("tf32", MmaSyncF32Lowering::TF32)
+          .Case("tf32x3", MmaSyncF32Lowering::TF32x3)
+          .Default(MmaSyncF32Lowering::Unkown);
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+
+    populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestNvgpuLowerings() {
+  PassRegistration<TestMmaSyncF32ToTF32Patterns>();
+}
+
+} // namespace test
+} // namespace mlir
\ No newline at end of file
index 87acd73..59036b1 100644 (file)
@@ -20,6 +20,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRLinalgTestPasses
     MLIRMathTestPasses
     MLIRMemRefTestPasses
+    MLIRNVGPUTestPasses
     MLIRSCFTestPasses
     MLIRShapeTestPasses
     MLIRSPIRVTestPasses
index 3d48ec2..63c028c 100644 (file)
@@ -113,6 +113,7 @@ void registerTestTensorTransforms();
 void registerTestTilingInterface();
 void registerTestTransformDialectInterpreterPass();
 void registerTestVectorLowerings();
+void registerTestNvgpuLowerings();
 } // namespace test
 } // namespace mlir
 
@@ -208,6 +209,7 @@ void registerTestPasses() {
   mlir::test::registerTestTilingInterface();
   mlir::test::registerTestTransformDialectInterpreterPass();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestNvgpuLowerings();
 }
 #endif