From a938f3a9b58d1f7ccd2fc17c0935c12f94d41695 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Wed, 12 Oct 2022 17:12:21 +0300 Subject: [PATCH] [mlir] Fix bitwidth of memref-to-llvm constant. One constant generated in MemRefToLLVM had a hardcoded bitwidth of 64 bits. The fix uses the typeConverter to create a constant that matches the bitwidth of the provided by the data layout. The issue was detected in an attempt to add a verifier to the LLVM ICmp operation that checks that the types of the compared arguments match. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D135775 --- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 5 +++-- mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 6ff0af9..0f1f644 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1518,6 +1518,7 @@ static void fillInStridesForCollapsedMemDescriptor( ConversionPatternRewriter &rewriter, Location loc, Operation *op, TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc, ArrayRef reassociation) { + auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); // See comments for computeCollapsedLayoutMap for details on how the strides // are calculated. auto srcShape = srcType.getShape(); @@ -1579,8 +1580,8 @@ static void fillInStridesForCollapsedMemDescriptor( rewriter.create(loc, srcStride, continueBlock); break; } - Value one = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI32IntegerAttr(1)); + Value one = rewriter.create(loc, llvmIndexType, + rewriter.getIndexAttr(1)); Value predNeOne = rewriter.create( loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex), one); diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 07979b1..c66cd58 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -776,7 +776,7 @@ func.func @collapse_shape_dynamic_with_non_identity_layout( // CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: llvm.mlir.constant(1 : i32) : i64 +// CHECK: llvm.mlir.constant(1 : index) : i64 // CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i64 // CHECK: llvm.cond_br %{{.*}}, ^bb2(%{{.*}} : i64), ^bb1 @@ -785,6 +785,10 @@ func.func @collapse_shape_dynamic_with_non_identity_layout( // CHECK: llvm.br ^bb2(%{{.*}} : i64) // CHECK: ^bb2(%[[STRIDE:.*]]: i64): // CHECK: llvm.insertvalue %[[STRIDE]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK32-LABEL: func @collapse_shape_dynamic_with_non_identity_layout( +// CHECK32: llvm.mlir.constant(1 : index) : i32 +// CHECK32: llvm.mlir.constant(4 : index) : i32 +// CHECK32: llvm.mlir.constant(1 : index) : i32 // ----- @@ -1149,7 +1153,7 @@ func.func @memref_copy_unranked() { // CHECK-LABEL: func @extract_aligned_pointer_as_index func.func @extract_aligned_pointer_as_index(%m: memref) -> index { %0 = memref.extract_aligned_pointer_as_index %m: memref -> index - // CHECK: %[[E:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[E:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[E]] : !llvm.ptr to i64 // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index -- 2.7.4