From: Butygin Date: Wed, 28 Jul 2021 19:31:26 +0000 (+0300) Subject: [mlir][spirv] Fix crash in convert-gpu-to-spirv pass with memrefs with affine maps X-Git-Tag: upstream/15.0.7~35114 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1e9799e204ff9eaa2160304e6a139c2faa850d33;p=platform%2Fupstream%2Fllvm.git [mlir][spirv] Fix crash in convert-gpu-to-spirv pass with memrefs with affine maps spirv::getElementPtr can return null (for memrefs with affine map) but patterns didn't handle this. Differential Revision: https://reviews.llvm.org/D106988 --- diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 2186107..a112146 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -147,6 +147,7 @@ Value linearizeIndex(ValueRange indices, ArrayRef strides, /// Performs the index computation to get to the element at `indices` of the /// memory pointed to by `basePtr`, using the layout map of `baseType`. +/// Returns null if index computation cannot be performed. // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap // that has static strides. Extend to handle dynamic strides. diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index c6be486..ddc312e 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -268,6 +268,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), loadOperands.indices(), loc, rewriter); + if (!accessChainOp) + return failure(); + int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); bool isBool = srcBits == 1; if (isBool) @@ -358,6 +361,10 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, auto loadPtr = spirv::getElementPtr( *getTypeConverter(), memrefType, loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); + + if (!loadPtr) + return failure(); + rewriter.replaceOpWithNewOp(loadOp, loadPtr); return success(); } @@ -376,6 +383,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, spirv::AccessChainOp accessChainOp = spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), loc, rewriter); + + if (!accessChainOp) + return failure(); + int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); bool isBool = srcBits == 1; @@ -467,6 +478,10 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, spirv::getElementPtr(*getTypeConverter(), memrefType, storeOperands.memref(), storeOperands.indices(), storeOp.getLoc(), rewriter); + + if (!storePtr) + return failure(); + rewriter.replaceOpWithNewOp(storeOp, storePtr, storeOperands.value()); return success();