[MLIR] Fix standard -> LLVM conversion to fail for unsupported memref element type.
authorRahul Joshi <jurahul@google.com>
Thu, 12 Nov 2020 20:05:35 +0000 (12:05 -0800)
committerRahul Joshi <jurahul@google.com>
Fri, 13 Nov 2020 01:06:05 +0000 (17:06 -0800)
- Move isSupportedMemRefType() to ConvertToLLVMPatterns and check if the
  memref element type is supported there.

Differential Revision: https://reviews.llvm.org/D91374

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

index e7aa9d5..04f8849 100644 (file)
@@ -522,6 +522,9 @@ protected:
                              ArrayRef<int64_t> strides, int64_t offset,
                              ConversionPatternRewriter &rewriter) const;
 
+  /// Returns if the givem memref type is supported.
+  bool isSupportedMemRefType(MemRefType type) const;
+
   Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
                    ValueRange indices,
                    ConversionPatternRewriter &rewriter) const;
index 57c26c4..6807f83 100644 (file)
@@ -1094,11 +1094,20 @@ Value ConvertToLLVMPattern::getDataPtr(
                               offset, rewriter);
 }
 
+// Check if the MemRefType `type` is supported by the lowering. We currently
+// only support memrefs with identity maps.
+bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const {
+  if (!typeConverter.convertType(type.getElementType()))
+    return false;
+  return type.getAffineMaps().empty() ||
+         llvm::all_of(type.getAffineMaps(),
+                      [](AffineMap map) { return map.isIdentity(); });
+}
+
 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
   auto elementType = type.getElementType();
-  auto structElementType = typeConverter.convertType(elementType);
-  return structElementType.cast<LLVM::LLVMType>().getPointerTo(
-      type.getMemorySpace());
+  auto structElementType = unwrap(typeConverter.convertType(elementType));
+  return structElementType.getPointerTo(type.getMemorySpace());
 }
 
 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
@@ -1912,14 +1921,6 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
   }
 };
 
-// Check if the MemRefType `type` is supported by the lowering. We currently
-// only support memrefs with identity maps.
-static bool isSupportedMemRefType(MemRefType type) {
-  return type.getAffineMaps().empty() ||
-         llvm::all_of(type.getAffineMaps(),
-                      [](AffineMap map) { return map.isIdentity(); });
-}
-
 /// Lowering for AllocOp and AllocaOp.
 struct AllocLikeOpLowering : public ConvertToLLVMPattern {
   using ConvertToLLVMPattern::createIndexConstant;
@@ -3070,6 +3071,7 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
 template <typename Derived>
 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<Derived>::isSupportedMemRefType;
   using Base = LoadStoreOpLowering<Derived>;
 
   LogicalResult match(Operation *op) const override {