Use LLVM_Type instead of AnyType in the definition of LLVM_CallOp
authorAlex Zinenko <zinenko@google.com>
Mon, 21 Oct 2019 21:11:50 +0000 (14:11 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 21 Oct 2019 21:12:19 +0000 (14:12 -0700)
The type constraint had to be relaxed due to the order of lowering passes in
the examples, that since has been fixed. The relaxed version was still used by
the CUDA lowering for launch sizes of `index` type. This is not necessary since
the GPU dialect does not restrict the type of the launch size operands. Use an
LLVM type instead and restore the check in the LLVM_CallOp definition.

PiperOrigin-RevId: 275920109

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir

index c51747d..41fcae6 100644 (file)
@@ -323,9 +323,7 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">;
 // Call-related operations.
 def LLVM_CallOp : LLVM_Op<"call">,
                   Arguments<(ins OptionalAttr<SymbolRefAttr>:$callee,
-                             // TODO(b/133216756): fix test failure and
-                             // change to LLVM_Type
-                             Variadic<AnyType>)>,
+                             Variadic<LLVM_Type>)>,
                   Results<(outs Variadic<LLVM_Type>)>,
                   LLVM_TwoBuilders<LLVM_OneResultOpBuilder,
                                    LLVM_ZeroResultOpBuilder> {
index 7647d29..2088606 100644 (file)
@@ -13,7 +13,7 @@ module attributes {gpu.container_module} {
   llvm.func @foo() {
     %0 = "op"() : () -> (!llvm.float)
     %1 = "op"() : () -> (!llvm<"float*">)
-    %cst = constant 8 : index
+    %cst = llvm.mlir.constant(8 : index) : !llvm.i64
 
     // CHECK: %[[addressof:.*]] = llvm.mlir.addressof @[[global]]
     // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index)
@@ -27,7 +27,7 @@ module attributes {gpu.container_module} {
     // CHECK: llvm.call @mcuLaunchKernel
     // CHECK: llvm.call @mcuStreamSynchronize
     "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = "kernel", kernel_module = @kernel_module }
-        : (index, index, index, index, index, index, !llvm.float, !llvm<"float*">) -> ()
+        : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.float, !llvm<"float*">) -> ()
 
     llvm.return
   }