From 1a572f4509a6fb392e87b7ea0346344bf6b8ac66 Mon Sep 17 00:00:00 2001 From: thomasraoux Date: Thu, 18 Mar 2021 12:59:49 -0700 Subject: [PATCH] [mlir] Add vector op support to cuda-runner including vector.print Differential Revision: https://reviews.llvm.org/D97346 --- mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp | 2 ++ mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp index 44dfd73..0e3bf16 100644 --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -18,6 +18,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" @@ -313,6 +314,7 @@ void GpuToLLVMConversionPass::runOnOperation() { OwningRewritePatternList patterns; LLVMConversionTarget target(getContext()); + populateVectorToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); populateAsyncStructuralTypeConversionsAndLegality(&getContext(), converter, patterns, target); diff --git a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir index ec9720f..c4ad897 100644 --- a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir +++ b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir @@ -21,7 +21,10 @@ func @other_func(%arg0 : f32, %arg1 : memref<?xf32>) { } // CHECK: [1, 1, 1, 1, 1] +// CHECK: ( 1, 1 ) func @main() { + %v0 = constant 0.0 : f32 + %c0 = constant 0: index %arg0 = memref.alloc() : memref<5xf32> %21 = constant 5 : i32 %22 = memref.cast %arg0 : 
memref<5xf32> to memref<?xf32> @@ -31,6 +34,8 @@ func @main() { %24 = constant 1.0 : f32 call @other_func(%24, %22) : (f32, memref<?xf32>) -> () call @print_memref_f32(%23) : (memref<*xf32>) -> () + %val1 = vector.transfer_read %arg0[%c0], %v0: memref<5xf32>, vector<2xf32> + vector.print %val1: vector<2xf32> return } -- 2.7.4