From ba49096817b77932ed1534ab1fb323b46944293c Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 15 Apr 2020 08:42:28 -0400 Subject: [PATCH] [mlir][spirv] Lower memref with dynamic dimensions to runtime arrays memref types with dynamic dimensions do not have a compile-time known size. They should be mapped to SPIR-V runtime array types. Differential Revision: https://reviews.llvm.org/D78197 --- mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 35 +++++++++++++--------- .../{std-to-spirv.mlir => std-ops-to-spirv.mlir} | 0 .../StandardToSPIRV/std-types-to-spirv.mlir | 10 +++++-- 3 files changed, 28 insertions(+), 17 deletions(-) rename mlir/test/Conversion/StandardToSPIRV/{std-to-spirv.mlir => std-ops-to-spirv.mlir} (100%) diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index 0f83412..c6e15d5 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -331,10 +331,11 @@ static Optional convertTensorType(const spirv::TargetEnv &targetEnv, static Optional convertMemrefType(const spirv::TargetEnv &targetEnv, MemRefType type) { - // TODO(ravishankarm) : Handle dynamic shapes. - if (!type.hasStaticShape()) { + Optional storageClass = + SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace()); + if (!storageClass) { LLVM_DEBUG(llvm::dbgs() - << type << " illegal: dynamic shape unimplemented\n"); + << type << " illegal: cannot convert memory space\n"); return llvm::None; } @@ -345,27 +346,33 @@ static Optional convertMemrefType(const spirv::TargetEnv &targetEnv, return llvm::None; } + auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass); + if (!arrayElemType) + return llvm::None; + Optional scalarSize = getTypeNumBytes(scalarType); - Optional memrefSize = getTypeNumBytes(type); - if (!scalarSize || !memrefSize) { + if (!scalarSize) { LLVM_DEBUG(llvm::dbgs() - << type << " illegal: cannot deduce element count\n"); + << type << " illegal: cannot deduce element size\n"); return llvm::None; } - auto arrayElemCount = *memrefSize / *scalarSize; + if (!type.hasStaticShape()) { + auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *scalarSize); + // Wrap in a struct to satisfy Vulkan interface requirements. + auto structType = spirv::StructType::get(arrayType, 0); + return spirv::PointerType::get(structType, *storageClass); + } - auto storageClass = - SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace()); - if (!storageClass) { + Optional memrefSize = getTypeNumBytes(type); + if (!memrefSize) { LLVM_DEBUG(llvm::dbgs() - << type << " illegal: cannot convert memory space\n"); + << type << " illegal: cannot deduce element count\n"); return llvm::None; } - auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass); - if (!arrayElemType) - return llvm::None; + auto arrayElemCount = *memrefSize / *scalarSize; + Optional arrayElemSize = getTypeNumBytes(*arrayElemType); if (!arrayElemSize) { LLVM_DEBUG(llvm::dbgs() diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir similarity index 100% rename from mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir rename to mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir index 89d1fe2..b98a20a 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -486,7 +486,7 @@ func @memref_offset_strides( // ----- -// Check that dynamic shapes are not supported. +// Dynamic shapes module attributes { spv.target_env = #spv.target_env< #spv.vce, @@ -494,13 +494,17 @@ module attributes { max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> } { +// Check that unranked shapes are not supported. // CHECK-LABEL: func @unranked_memref // CHECK-SAME: memref<*xi32> func @unranked_memref(%arg0: memref<*xi32>) { return } // CHECK-LABEL: func @dynamic_dim_memref -// CHECK-SAME: memref<8x?xi32> -func @dynamic_dim_memref(%arg0: memref<8x?xi32>) { return } +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +func @dynamic_dim_memref(%arg0: memref<8x?xi32>, + %arg1: memref) +{ return } } // end module -- 2.7.4