Add a test for lowering GPU ops that cover cases where the symbol table isn't held...
authorMehdi Amini <aminim@google.com>
Thu, 31 Oct 2019 17:34:39 +0000 (10:34 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 31 Oct 2019 17:35:15 +0000 (10:35 -0700)
PiperOrigin-RevId: 277752004

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

index 205f1e6..6ad6bec 100644 (file)
@@ -100,3 +100,25 @@ module attributes {gpu.kernel_module} {
     std.return
   }
 }
+
+// -----
+
+// Test that we handled properly operation with SymbolTable other than module op
+module attributes {gpu.kernel_module} {
+  "test.symbol_scope"() ({
+  // CHECK: test.symbol_scope
+  // CHECK: llvm.func @__nv_expf(!llvm.float) -> !llvm.float
+  // CHECK: llvm.func @__nv_exp(!llvm.double) -> !llvm.double
+  // CHECK-LABEL: func @gpu_exp
+    func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) {
+      %exp_f32 = std.exp %arg_f32 : f32
+      // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
+      %result_f32 = std.exp %exp_f32 : f32
+      // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
+      %result64 = std.exp %arg_f64 : f64
+      // CHECK: llvm.call @__nv_exp(%{{.*}}) : (!llvm.double) -> !llvm.double
+      std.return
+    }
+    "test.finish" () : () -> ()
+  }) : () -> ()
+}
index daa17ff..ccb931c 100644 (file)
@@ -52,3 +52,26 @@ module attributes {gpu.kernel_module} {
     std.return
   }
 }
+
+
+// -----
+
+// Test that we handled properly operation with SymbolTable other than module op
+module attributes {gpu.kernel_module} {
+  "test.symbol_scope"() ({
+    // CHECK: test.symbol_scope
+    // CHECK: llvm.func @_ocml_exp_f32(!llvm.float) -> !llvm.float
+    // CHECK: llvm.func @_ocml_exp_f64(!llvm.double) -> !llvm.double
+    // CHECK-LABEL: func @gpu_exp
+    func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) {
+      %exp_f32 = std.exp %arg_f32 : f32
+      // CHECK: llvm.call @_ocml_exp_f32(%{{.*}}) : (!llvm.float) -> !llvm.float
+      %result_f32 = std.exp %exp_f32 : f32
+      // CHECK: llvm.call @_ocml_exp_f32(%{{.*}}) : (!llvm.float) -> !llvm.float
+      %result64 = std.exp %arg_f64 : f64
+      // CHECK: llvm.call @_ocml_exp_f64(%{{.*}}) : (!llvm.double) -> !llvm.double
+      std.return
+    }
+    "test.finish" () : () -> ()
+  }) : () -> ()
+}