[mlir][spirv] Fix crash in convert-gpu-to-spirv pass with memrefs with affine maps
authorButygin <ivan.butygin@intel.com>
Wed, 28 Jul 2021 19:31:26 +0000 (22:31 +0300)
committerButygin <ivan.butygin@intel.com>
Fri, 30 Jul 2021 09:46:13 +0000 (12:46 +0300)
spirv::getElementPtr can return null (for memrefs with affine map) but patterns didn't handle this.

Differential Revision: https://reviews.llvm.org/D106988

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

index 2186107..a112146 100644 (file)
@@ -147,6 +147,7 @@ Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
 
 /// Performs the index computation to get to the element at `indices` of the
 /// memory pointed to by `basePtr`, using the layout map of `baseType`.
+/// Returns null if index computation cannot be performed.
 
 // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
 // that has static strides. Extend to handle dynamic strides.
index c6be486..ddc312e 100644 (file)
@@ -268,6 +268,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
       spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                            loadOperands.indices(), loc, rewriter);
 
+  if (!accessChainOp)
+    return failure();
+
   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
   bool isBool = srcBits == 1;
   if (isBool)
@@ -358,6 +361,10 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
   auto loadPtr = spirv::getElementPtr(
       *getTypeConverter<SPIRVTypeConverter>(), memrefType,
       loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
+
+  if (!loadPtr)
+    return failure();
+
   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
   return success();
 }
@@ -376,6 +383,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
   spirv::AccessChainOp accessChainOp =
       spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
                            storeOperands.indices(), loc, rewriter);
+
+  if (!accessChainOp)
+    return failure();
+
   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
 
   bool isBool = srcBits == 1;
@@ -467,6 +478,10 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
       spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
                            storeOperands.memref(), storeOperands.indices(),
                            storeOp.getLoc(), rewriter);
+
+  if (!storePtr)
+    return failure();
+
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                               storeOperands.value());
   return success();