[mlir] Fix bitwidth of memref-to-llvm constant.
authorTobias Gysi <tobias.gysi@nextsilicon.com>
Wed, 12 Oct 2022 14:12:21 +0000 (17:12 +0300)
committerTobias Gysi <tobias.gysi@nextsilicon.com>
Wed, 12 Oct 2022 14:13:01 +0000 (17:13 +0300)
One constant generated in MemRefToLLVM had a hardcoded bitwidth of
64 bits. The fix uses the typeConverter to create a constant that
matches the bitwidth of the provided by the data layout. The issue was
detected in an attempt to add a verifier to the LLVM ICmp operation that
checks that the types of the compared arguments match.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D135775

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

index 6ff0af9..0f1f644 100644 (file)
@@ -1518,6 +1518,7 @@ static void fillInStridesForCollapsedMemDescriptor(
     ConversionPatternRewriter &rewriter, Location loc, Operation *op,
     TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
     MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
+  auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
   // See comments for computeCollapsedLayoutMap for details on how the strides
   // are calculated.
   auto srcShape = srcType.getShape();
@@ -1579,8 +1580,8 @@ static void fillInStridesForCollapsedMemDescriptor(
           rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
           break;
         }
-        Value one = rewriter.create<LLVM::ConstantOp>(
-            loc, rewriter.getI64Type(), rewriter.getI32IntegerAttr(1));
+        Value one = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
+                                                      rewriter.getIndexAttr(1));
         Value predNeOne = rewriter.create<LLVM::ICmpOp>(
             loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
             one);
index 07979b1..c66cd58 100644 (file)
@@ -776,7 +776,7 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
 //       CHECK:      llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:      llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
 //       CHECK:      llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:      llvm.mlir.constant(1 : i32) : i64
+//       CHECK:      llvm.mlir.constant(1 : index) : i64
 //       CHECK:      llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:      llvm.icmp "ne" %{{.*}}, %{{.*}} : i64
 //       CHECK:      llvm.cond_br %{{.*}}, ^bb2(%{{.*}} : i64), ^bb1
@@ -785,6 +785,10 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
 //       CHECK:      llvm.br ^bb2(%{{.*}} : i64)
 //       CHECK:      ^bb2(%[[STRIDE:.*]]: i64):
 //       CHECK:      llvm.insertvalue %[[STRIDE]], %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK32-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
+//       CHECK32:      llvm.mlir.constant(1 : index) : i32
+//       CHECK32:      llvm.mlir.constant(4 : index) : i32
+//       CHECK32:      llvm.mlir.constant(1 : index) : i32
 
 // -----
 
@@ -1149,7 +1153,7 @@ func.func @memref_copy_unranked() {
 // CHECK-LABEL: func @extract_aligned_pointer_as_index
 func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
   %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32> -> index
-  // CHECK: %[[E:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: %[[E:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[E]] : !llvm.ptr<f32> to i64
   // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index