From: Stanley Winata
Date: Mon, 19 Sep 2022 17:14:58 +0000 (-0400)
Subject: [mlir][spirv] Support OpenCL when lowering memref load/store
X-Git-Tag: upstream/17.0.6~33132
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9c3a73a579ca71529fa11dc0e5acee22500e4d10;p=platform%2Fupstream%2Fllvm.git

[mlir][spirv] Support OpenCL when lowering memref load/store

- Add awareness of the Kernel vs. Shader capability to the memref-to-SPIR-V
  lowering.
- Add a lowering via spv.PtrAccessChain for the Kernel capability.
- Enable lowering to scalar pointee types for the Kernel capability.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D132714
---
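As a rough illustration of the intent (a sketch, not part of the patch): for
the same memref.load, the Shader flavor indexes through the wrapping
struct/array with spv.AccessChain, while the Kernel flavor addresses the
element pointer directly with spv.PtrAccessChain. The names (%base, %kbase,
%idx), the array size, and the storage classes below are assumed for
illustration, not taken from the tests in this patch:

    // Vulkan/Shader flavor: the pointee is a struct wrapping an array.
    %zero = spv.Constant 0 : i32
    %vptr = spv.AccessChain %base[%zero, %idx] : !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4>)>, StorageBuffer>, i32, i32
    %vval = spv.Load "StorageBuffer" %vptr : f32

    // OpenCL/Kernel flavor: the pointee is the scalar element itself.
    %kptr = spv.PtrAccessChain %kbase[%idx] : !spv.ptr<f32, CrossWorkgroup>, i32
    %kval = spv.Load "CrossWorkgroup" %kptr : f32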
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index ff1cdac..9b480f6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H

 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -72,6 +73,9 @@ public:
   /// Returns the options controlling the SPIR-V type converter.
   const SPIRVConversionOptions &getOptions() const { return options; }

+  /// Checks if the given SPIR-V capability is supported.
+  bool allows(spirv::Capability capability);
+
 private:
   spirv::TargetEnv targetEnv;
   SPIRVConversionOptions options;
@@ -151,10 +155,19 @@ Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
 // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
 // that has static strides. Extend to handle dynamic strides.
-spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
-                                   MemRefType baseType, Value basePtr,
-                                   ValueRange indices, Location loc,
-                                   OpBuilder &builder);
+Value getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType,
+                    Value basePtr, ValueRange indices, Location loc,
+                    OpBuilder &builder);
+
+// GetElementPtr implementation for Kernel/OpenCL flavored SPIR-V.
+Value getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
+                          MemRefType baseType, Value basePtr,
+                          ValueRange indices, Location loc, OpBuilder &builder);
+
+// GetElementPtr implementation for Vulkan/Shader flavored SPIR-V.
+Value getVulkanElementPtr(SPIRVTypeConverter &typeConverter,
+                          MemRefType baseType, Value basePtr,
+                          ValueRange indices, Location loc, OpBuilder &builder);

 } // namespace spirv
 } // namespace mlir
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index d72802b..766d42b 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -192,7 +192,7 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };

-/// Converts memref.load to spv.Load.
+/// Converts memref.load to spv.Load + spv.AccessChain on integers.
 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -202,7 +202,7 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };

-/// Converts memref.load to spv.Load.
+/// Converts memref.load to spv.Load + spv.AccessChain.
 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -319,11 +319,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
     return failure();

   auto &typeConverter = *getTypeConverter();
-  spirv::AccessChainOp accessChainOp =
+  Value accessChain =
       spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                            adaptor.getIndices(), loc, rewriter);
-  if (!accessChainOp)
+  if (!accessChain)
     return failure();

   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
@@ -333,27 +333,41 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   Type pointeeType = typeConverter.convertType(memrefType)
                          .cast<spirv::PointerType>()
                          .getPointeeType();
-  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
   Type dstType;
-  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
-    dstType = arrayType.getElementType();
-  else
-    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
-
+  if (typeConverter.allows(spirv::Capability::Kernel)) {
+    // For OpenCL Kernel, the pointer points directly at the element.
+    dstType = pointeeType;
+  } else {
+    // For Vulkan, we need to extract the element from the wrapping struct and array.
+    Type structElemType =
+        pointeeType.cast<spirv::StructType>().getElementType(0);
+    if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+      dstType = arrayType.getElementType();
+    else
+      dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+  }
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);

   // If the rewritten load op has the same bit width, use the loaded value
   // directly.
   if (srcBits == dstBits) {
-    Value loadVal =
-        rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
+    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
     if (isBool)
       loadVal = castIntNToBool(loc, loadVal, rewriter);
     rewriter.replaceOp(loadOp, loadVal);
     return success();
   }

+  // Bitcasting is currently unsupported for the Kernel capability /
+  // spv.PtrAccessChain.
+  if (typeConverter.allows(spirv::Capability::Kernel))
+    return failure();
+
+  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
+  if (!accessChainOp)
+    return failure();
+
   // Assume that getElementPtr() works linearly. If it's a scalar, the method
   // still returns a linearized access. If the access is not linearized,
   // there will be offset issues.
@@ -432,11 +446,11 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   auto loc = storeOp.getLoc();
   auto &typeConverter = *getTypeConverter();
-  spirv::AccessChainOp accessChainOp =
+  Value accessChain =
       spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                            adaptor.getIndices(), loc, rewriter);
-  if (!accessChainOp)
+  if (!accessChain)
     return failure();

   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
@@ -448,12 +462,19 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   Type pointeeType = typeConverter.convertType(memrefType)
                          .cast<spirv::PointerType>()
                          .getPointeeType();
-  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
   Type dstType;
-  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
-    dstType = arrayType.getElementType();
-  else
-    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+  if (typeConverter.allows(spirv::Capability::Kernel)) {
+    // For OpenCL Kernel, the pointer points directly at the element.
+    dstType = pointeeType;
+  } else {
+    // For Vulkan, we need to extract the element from the wrapping struct and array.
+    Type structElemType =
+        pointeeType.cast<spirv::StructType>().getElementType(0);
+    if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+      dstType = arrayType.getElementType();
+    else
+      dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+  }
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);
@@ -462,11 +483,19 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
     Value storeVal = adaptor.getValue();
     if (isBool)
       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
-        storeOp, accessChainOp.getResult(), storeVal);
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
     return success();
   }

+  // Bitcasting is currently unsupported for the Kernel capability /
+  // spv.PtrAccessChain.
+  if (typeConverter.allows(spirv::Capability::Kernel))
+    return failure();
+
+  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
+  if (!accessChainOp)
+    return failure();
+
   // Since multiple threads may be storing concurrently, the emulation is done
   // with atomic operations. E.g., if the stored value is i8, rewrite the
   // StoreOp to
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 8d0bde6..b56b8c0 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -122,6 +122,10 @@ MLIRContext *SPIRVTypeConverter::getContext() const {
   return targetEnv.getAttr().getContext();
 }

+bool SPIRVTypeConverter::allows(spirv::Capability capability) {
+  return targetEnv.allows(capability);
+}
+
 // TODO: This is a utility function that should probably be exposed by the
 // SPIR-V dialect. Keeping it local till the use case arises.
 static Optional<int64_t> getTypeNumBytes(const SPIRVConversionOptions &options,
@@ -334,6 +338,12 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }

+  // For OpenCL Kernel, we can just emit a pointer pointing to the element.
+  if (targetEnv.allows(spirv::Capability::Kernel))
+    return spirv::PointerType::get(arrayElemType, storageClass);
+
+  // For Vulkan, we need an extra wrapping struct and array to satisfy
+  // interface requirements.
   if (!type.hasStaticShape()) {
     int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
@@ -393,6 +403,12 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }

+  // For OpenCL Kernel, we can just emit a pointer pointing to the element.
+  if (targetEnv.allows(spirv::Capability::Kernel))
+    return spirv::PointerType::get(arrayElemType, storageClass);
+
+  // For Vulkan, we need an extra wrapping struct and array to satisfy
+  // interface requirements.
   if (!type.hasStaticShape()) {
     int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
@@ -712,9 +728,10 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
   return linearizedIndex;
 }

-spirv::AccessChainOp mlir::spirv::getElementPtr(
-    SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
-    ValueRange indices, Location loc, OpBuilder &builder) {
+Value mlir::spirv::getVulkanElementPtr(SPIRVTypeConverter &typeConverter,
+                                       MemRefType baseType, Value basePtr,
+                                       ValueRange indices, Location loc,
+                                       OpBuilder &builder) {
   // Get base and offset of the MemRefType and verify they are static.
   int64_t offset;
@@ -742,6 +759,50 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
 }

+Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
+                                       MemRefType baseType, Value basePtr,
+                                       ValueRange indices, Location loc,
+                                       OpBuilder &builder) {
+  // Get base and offset of the MemRefType and verify they are static.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
+      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
+      offset == MemRefType::getDynamicStrideOrOffset()) {
+    return nullptr;
+  }
+
+  auto indexType = typeConverter.getIndexType();
+
+  SmallVector<Value, 2> linearizedIndices;
+  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
+
+  Value linearIndex;
+  if (baseType.getRank() == 0) {
+    linearIndex = zero;
+  } else {
+    linearIndex =
+        linearizeIndex(indices, strides, offset, indexType, loc, builder);
+  }
+  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
+                                                 linearizedIndices);
+}
+
+Value mlir::spirv::getElementPtr(SPIRVTypeConverter &typeConverter,
+                                 MemRefType baseType, Value basePtr,
+                                 ValueRange indices, Location loc,
+                                 OpBuilder &builder) {
+  if (typeConverter.allows(spirv::Capability::Kernel)) {
+    return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
+                               builder);
+  }
+
+  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
+                             builder);
+}
+
 //===----------------------------------------------------------------------===//
 // SPIR-V ConversionTarget
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
index ac22d13..5413157 100644
--- a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
@@ -9,7 +9,7 @@ module attributes {
     // CHECK:       spv.func
     // CHECK-SAME:  {{%.*}}: f32
     // CHECK-NOT:   spv.interface_var_abi
-    // CHECK-SAME:  {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>
+    // CHECK-SAME:  {{%.*}}: !spv.ptr<f32, CrossWorkgroup>
     // CHECK-NOT:   spv.interface_var_abi
     // CHECK-SAME:  spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
     gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class<CrossWorkgroup>>) kernel
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 212363c..98e44b9 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -109,6 +109,111 @@ func.func @store_i1(%dst: memref<4xi1, #spv.storage_class<StorageBuffer>>, %i: index) {

 // -----

+// Check that with the Kernel capability and the proper compute and storage
+// extensions, no special tricks are needed to lower memref load/store.
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Kernel, Addresses, Int8, Int16, Int64, Float16, Float64], []>,
+    #spv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_store_zero_rank_float
+func.func @load_store_zero_rank_float(%arg0: memref<f32, #spv.storage_class<CrossWorkgroup>>, %arg1: memref<f32, #spv.storage_class<CrossWorkgroup>>) {
+  // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<f32, CrossWorkgroup>
+  // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<f32, CrossWorkgroup>
+  // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
+  // CHECK: spv.PtrAccessChain [[ARG0]][
+  // CHECK-SAME: [[ZERO1]]
+  // CHECK-SAME: ] :
+  // CHECK: spv.Load "CrossWorkgroup" %{{.*}} : f32
+  %0 = memref.load %arg0[] : memref<f32, #spv.storage_class<CrossWorkgroup>>
+  // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
+  // CHECK: spv.PtrAccessChain [[ARG1]][
+  // CHECK-SAME: [[ZERO2]]
+  // CHECK-SAME: ] :
+  // CHECK: spv.Store "CrossWorkgroup" %{{.*}} : f32
+  memref.store %0, %arg1[] : memref<f32, #spv.storage_class<CrossWorkgroup>>
+  return
+}
+
+// CHECK-LABEL: @load_store_zero_rank_int
+func.func @load_store_zero_rank_int(%arg0: memref<i32, #spv.storage_class<CrossWorkgroup>>, %arg1: memref<i32, #spv.storage_class<CrossWorkgroup>>) {
+  // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i32, CrossWorkgroup>
+  // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i32, CrossWorkgroup>
+  // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
+  // CHECK: spv.PtrAccessChain [[ARG0]][
+  // CHECK-SAME: [[ZERO1]]
+  // CHECK-SAME: ] :
+  // CHECK: spv.Load "CrossWorkgroup" %{{.*}} : i32
+  %0 = memref.load %arg0[] : memref<i32, #spv.storage_class<CrossWorkgroup>>
+  // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
+  // CHECK: spv.PtrAccessChain [[ARG1]][
+  // CHECK-SAME: [[ZERO2]]
+  // CHECK-SAME: ] :
+  // CHECK: spv.Store "CrossWorkgroup" %{{.*}} : i32
+  memref.store %0, %arg1[] : memref<i32, #spv.storage_class<CrossWorkgroup>>
+  return
+}
+
+// CHECK-LABEL: func @load_store_unknown_dim
+func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spv.storage_class<CrossWorkgroup>>, %dest: memref<?xi32, #spv.storage_class<CrossWorkgroup>>) {
+  // CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i32, CrossWorkgroup>
+  // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i32, CrossWorkgroup>
+  // CHECK: %[[AC0:.+]] = spv.PtrAccessChain %[[SRC]]
+  // CHECK: spv.Load "CrossWorkgroup" %[[AC0]]
+  %0 = memref.load %source[%i] : memref<?xi32, #spv.storage_class<CrossWorkgroup>>
+  // CHECK: %[[AC1:.+]] = spv.PtrAccessChain %[[DST]]
+  // CHECK: spv.Store "CrossWorkgroup" %[[AC1]]
+  memref.store %0, %dest[%i]: memref<?xi32, #spv.storage_class<CrossWorkgroup>>
+  return
+}
+
+// CHECK-LABEL: func @load_i1
+// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spv.storage_class<CrossWorkgroup>>, %[[IDX:.+]]: index)
+func.func @load_i1(%src: memref<4xi1, #spv.storage_class<CrossWorkgroup>>, %i : index) -> i1 {
+  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i8, CrossWorkgroup>
+  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+  // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
+  // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
+  // CHECK: %[[ADDR:.+]] = spv.PtrAccessChain %[[SRC_CAST]][%[[ADD]]]
+  // CHECK: %[[VAL:.+]] = spv.Load "CrossWorkgroup" %[[ADDR]] : i8
+  // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
+  // CHECK: %[[BOOL:.+]] = spv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+  %0 = memref.load %src[%i] : memref<4xi1, #spv.storage_class<CrossWorkgroup>>
+  // CHECK: return %[[BOOL]]
+  return %0: i1
+}
+
+// CHECK-LABEL: func @store_i1
+// CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spv.storage_class<CrossWorkgroup>>,
+// CHECK-SAME: %[[IDX:.+]]: index
+func.func @store_i1(%dst: memref<4xi1, #spv.storage_class<CrossWorkgroup>>, %i: index) {
+  %true = arith.constant true
+  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i8, CrossWorkgroup>
+  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+  // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
+  // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
+  // CHECK: %[[ADDR:.+]] = spv.PtrAccessChain %[[DST_CAST]][%[[ADD]]]
+  // CHECK: %[[ZERO_I8:.+]] = spv.Constant 0 : i8
+  // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
+  // CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
+  // CHECK: spv.Store "CrossWorkgroup" %[[ADDR]], %[[RES]] : i8
+  memref.store %true, %dst[%i]: memref<4xi1, #spv.storage_class<CrossWorkgroup>>
+  return
+}
+
+} // end module
+
 // -----

 // Check that access chain indices are properly adjusted if non-32-bit types are
 // emulated via 32-bit types.
 // TODO: Test i64 types.
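For background on the emulation path referenced just above (the one the Kernel
flavor now rejects with `return failure()` before any bitcasting): under the
Shader flavor, sub-32-bit types such as i8 are emulated via i32, so a load
becomes a word load plus shift-and-mask arithmetic. A rough, hypothetical
sketch of the idea; %idx and %wptr are assumed names and the exact op sequence
is illustrative, not copied from IntLoadOpPattern:

    // Each i32 word packs four i8 values; index the word, then extract the byte.
    %four     = spv.Constant 4 : i32
    %eight    = spv.Constant 8 : i32
    %word_idx = spv.SDiv %idx, %four : i32        // which 32-bit word
    %byte_idx = spv.UMod %idx, %four : i32        // which byte inside that word
    // %wptr = spv.AccessChain computed from %word_idx (elided).
    %word     = spv.Load "StorageBuffer" %wptr : i32
    %bits     = spv.IMul %byte_idx, %eight : i32  // bit offset of the byte
    %shifted  = spv.ShiftRightArithmetic %word, %bits : i32, i32
    %mask     = spv.Constant 255 : i32
    %byte     = spv.BitwiseAnd %shifted, %mask : i32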