[mlir] Add unsigned version of index_cast
authorThomas Raoux <thomasraoux@google.com>
Mon, 3 Oct 2022 18:46:59 +0000 (18:46 +0000)
committerThomas Raoux <thomasraoux@google.com>
Mon, 3 Oct 2022 18:51:15 +0000 (18:51 +0000)
This is required to be able to cast integer type to a potential larger index using zero-extend cast.

There is a larger change under discussion to move index ops in a separate dialect: https://discourse.llvm.org/t/rfc-index-dialect/65540/
Based on timing of this work this patch can be included as part of this effort but as a short term solution we may want to add this op to arithmetic dialect for now in order to fill the gap.

Reviewed By: Mogball, stellaraccident

Differential Revision: https://reviews.llvm.org/D135089

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/Arith/ops.mlir

index c143f9a..b59d9fd 100644 (file)
@@ -1037,6 +1037,25 @@ def Arith_IndexCastOp
 }
 
 //===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_IndexCastUIOp
+  : Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
+                 [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+  let summary = "unsigned cast between index and integer types";
+  let description = [{
+    Casts between scalar or vector integers and corresponding 'index' scalar or
+    vectors. Index is an integer of platform-specific bit width. If casting to
+    a wider integer, the value is zero-extended. If casting to a narrower
+    integer, the value is truncated.
+  }];
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
 // BitcastOp
 //===----------------------------------------------------------------------===//
 
index c77edd8..1610e5c 100644 (file)
@@ -104,14 +104,20 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
 /// becomes an integer.  If the bit width of the source and target integer
 /// types is the same, just erase the cast.  If the target type is wider,
 /// sign-extend the value, otherwise truncate it.
-struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+template <typename OpTy, typename ExtCastTy>
+struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
+  using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+using IndexCastOpSILowering =
+    IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
+using IndexCastOpUILowering =
+    IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
+
 struct AddUICarryOpLowering
     : public ConvertOpToLLVMPattern<arith::AddUICarryOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -155,14 +161,15 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
 // IndexCastOpLowering
 //===----------------------------------------------------------------------===//
 
-LogicalResult IndexCastOpLowering::matchAndRewrite(
-    arith::IndexCastOp op, OpAdaptor adaptor,
+template <typename OpTy, typename ExtCastTy>
+LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
+    OpTy op, typename OpTy::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Type resultType = op.getResult().getType();
   Type targetElementType =
-      typeConverter->convertType(getElementTypeOrSelf(resultType));
+      this->typeConverter->convertType(getElementTypeOrSelf(resultType));
   Type sourceElementType =
-      typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
+      this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
   unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
   unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
 
@@ -174,13 +181,12 @@ LogicalResult IndexCastOpLowering::matchAndRewrite(
   // Handle the scalar and 1D vector cases.
   Type operandType = adaptor.getIn().getType();
   if (!operandType.isa<LLVM::LLVMArrayType>()) {
-    Type targetType = typeConverter->convertType(resultType);
+    Type targetType = this->typeConverter->convertType(resultType);
     if (targetBits < sourceBits)
       rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
                                                  adaptor.getIn());
     else
-      rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
-                                                adaptor.getIn());
+      rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
     return success();
   }
 
@@ -188,15 +194,15 @@ LogicalResult IndexCastOpLowering::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "expected vector result type");
 
   return LLVM::detail::handleMultidimensionalVectors(
-      op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+      op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
       [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
-        OpAdaptor adaptor(operands);
+        typename OpTy::Adaptor adaptor(operands);
         if (targetBits < sourceBits) {
           return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
                                                 adaptor.getIn());
         }
-        return rewriter.create<LLVM::SExtOp>(op.getLoc(), llvm1DVectorTy,
-                                             adaptor.getIn());
+        return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
+                                          adaptor.getIn());
       },
       rewriter);
 }
@@ -366,7 +372,8 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     ExtUIOpLowering,
     FPToSIOpLowering,
     FPToUIOpLowering,
-    IndexCastOpLowering,
+    IndexCastOpSILowering,
+    IndexCastOpUILowering,
     MaxFOpLowering,
     MaxSIOpLowering,
     MaxUIOpLowering,
index 034abcc..24bead8 100644 (file)
@@ -990,6 +990,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+    TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern,
index f57c4d8..2cb5a55 100644 (file)
@@ -152,6 +152,21 @@ def IndexCastOfExtSI :
     Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x)), (Arith_IndexCastOp $x)>;
 
 //===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+// index_castui(index_castui(x)) -> x, if dstType == srcType.
+def IndexCastUIOfIndexCastUI :
+    Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x)),
+        (replaceWithValue $x),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
+
+// index_castui(extui(x)) -> index_castui(x)
+def IndexCastUIOfExtUI :
+    Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x)), (Arith_IndexCastUIOp $x)>;
+
+
+//===----------------------------------------------------------------------===//
 // BitcastOp
 //===----------------------------------------------------------------------===//
 
index d62e3f3..190a1ef 100644 (file)
@@ -1261,8 +1261,7 @@ OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
-                                           TypeRange outputs) {
+static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (!areValidCastInputsAndOutputs(inputs, outputs))
     return false;
 
@@ -1275,6 +1274,11 @@ bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
          (srcType.isSignlessInteger() && dstType.isIndex());
 }
 
+bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
+                                           TypeRange outputs) {
+  return areIndexCastCompatible(inputs, outputs);
+}
+
 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
   // index_cast(constant) -> constant
   // A little hack because we go through int. Otherwise, the size of the
@@ -1291,6 +1295,30 @@ void arith::IndexCastOp::getCanonicalizationPatterns(
 }
 
 //===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
+                                             TypeRange outputs) {
+  return areIndexCastCompatible(inputs, outputs);
+}
+
+OpFoldResult arith::IndexCastUIOp::fold(ArrayRef<Attribute> operands) {
+  // index_castui(constant) -> constant
+  // A little hack because we go through int. Otherwise, the size of the
+  // constant might need to change.
+  if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
+    return IntegerAttr::get(getType(), value.getUInt());
+
+  return {};
+}
+
+void arith::IndexCastUIOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
+}
+
+//===----------------------------------------------------------------------===//
 // BitcastOp
 //===----------------------------------------------------------------------===//
 
index e59469c..243c3ef 100644 (file)
@@ -466,13 +466,18 @@ void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // ExtUIOp
 //===----------------------------------------------------------------------===//
 
+static ConstantIntRanges extUIRange(const ConstantIntRanges &range,
+                                    Type destType) {
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+  APInt smin = range.umin().zext(destWidth);
+  APInt smax = range.umax().zext(destWidth);
+  return ConstantIntRanges::fromSigned(smin, smax);
+}
+
 void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
   Type destType = getResult().getType();
-  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
-  APInt umin = argRanges[0].umin().zext(destWidth);
-  APInt umax = argRanges[0].umax().zext(destWidth);
-  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+  setResultRange(getResult(), extUIRange(argRanges[0], destType));
 }
 
 //===----------------------------------------------------------------------===//
@@ -560,6 +565,25 @@ void arith::IndexCastOp::inferResultRanges(
 }
 
 //===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::IndexCastUIOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  Type sourceType = getOperand().getType();
+  Type destType = getResult().getType();
+  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+
+  if (srcWidth < destWidth)
+    setResultRange(getResult(), extUIRange(argRanges[0], destType));
+  else if (srcWidth > destWidth)
+    setResultRange(getResult(), truncIRange(argRanges[0], destType));
+  else
+    setResultRange(getResult(), argRanges[0]);
+}
+
+//===----------------------------------------------------------------------===//
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
index c476d43..05706d8 100644 (file)
@@ -92,6 +92,23 @@ func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
   return
 }
 
+func.func @index_castui(%arg0: index, %arg1: i1) {
+// CHECK: = llvm.trunc %0 : i{{.*}} to i1
+  %0 = arith.index_castui %arg0: index to i1
+// CHECK-NEXT: = llvm.zext %arg1 : i1 to i{{.*}}
+  %1 = arith.index_castui %arg1: i1 to index
+  return
+}
+
+// CHECK-LABEL: @vector_index_castui
+func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
+// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
+  %0 = arith.index_castui %arg0: vector<2xindex> to vector<2xi1>
+// CHECK-NEXT: = llvm.zext %{{.*}} : vector<2xi1> to vector<2xi{{.*}}>
+  %1 = arith.index_castui %arg1: vector<2xi1> to vector<2xindex>
+  return
+}
+
 // Checking conversion of signed integer types to floating point.
 // CHECK-LABEL: @sitofp
 func.func @sitofp(%arg0 : i32, %arg1 : i64) {
index 16a7967..bd5238f 100644 (file)
@@ -690,7 +690,35 @@ func.func @index_cast3(%arg0: i32) {
 
 // CHECK-LABEL: index_cast4
 func.func @index_cast4(%arg0: index) {
-  // CHECK-NOT: spirv.SConvert
+  // CHECK-NOT: spirv.UConvert
+  %0 = arith.index_cast %arg0 : index to i32
+  return
+}
+
+// CHECK-LABEL: index_castui1
+func.func @index_castui1(%arg0: i16) {
+  // CHECK: spirv.UConvert %{{.+}} : i16 to i32
+  %0 = arith.index_castui %arg0 : i16 to index
+  return
+}
+
+// CHECK-LABEL: index_castui2
+func.func @index_castui2(%arg0: index) {
+  // CHECK: spirv.UConvert %{{.+}} : i32 to i16
+  %0 = arith.index_castui %arg0 : index to i16
+  return
+}
+
+// CHECK-LABEL: index_castui3
+func.func @index_castui3(%arg0: i32) {
+  // CHECK-NOT: spirv.UConvert
+  %0 = arith.index_castui %arg0 : i32 to index
+  return
+}
+
+// CHECK-LABEL: index_castui4
+func.func @index_castui4(%arg0: index) {
+  // CHECK-NOT: spirv.UConvert
   %0 = arith.index_cast %arg0 : index to i32
   return
 }
index 632e7af..be680ac 100644 (file)
@@ -308,6 +308,15 @@ func.func @indexCastOfSignExtend(%arg0: i8) -> index {
   return %idx : index
 }
 
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend
+//       CHECK:   %[[res:.+]] = arith.index_castui %arg0 : i8 to index
+//       CHECK:   return %[[res]]
+func.func @indexCastUIOfUnsignedExtend(%arg0: i8) -> index {
+  %ext = arith.extui %arg0 : i8 to i16
+  %idx = arith.index_castui %ext : i16 to index
+  return %idx : index
+}
+
 // CHECK-LABEL: @signExtendConstant
 //       CHECK:   %[[cres:.+]] = arith.constant -2 : i16
 //       CHECK:   return %[[cres]]
index 56e17c7..c34850f 100644 (file)
@@ -793,6 +793,55 @@ func.func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector
   return %0 : vector<[8]xi64>
 }
 
+
+// CHECK-LABEL: test_index_castui0
+func.func @test_index_castui0(%arg0 : i32) -> index {
+  %0 = arith.index_castui %arg0 : i32 to index
+  return %0 : index
+}
+
+// CHECK-LABEL: test_index_castui_tensor0
+func.func @test_index_castui_tensor0(%arg0 : tensor<8x8xi32>) -> tensor<8x8xindex> {
+  %0 = arith.index_castui %arg0 : tensor<8x8xi32> to tensor<8x8xindex>
+  return %0 : tensor<8x8xindex>
+}
+
+// CHECK-LABEL: test_index_castui_vector0
+func.func @test_index_castui_vector0(%arg0 : vector<8xi32>) -> vector<8xindex> {
+  %0 = arith.index_castui %arg0 : vector<8xi32> to vector<8xindex>
+  return %0 : vector<8xindex>
+}
+
+// CHECK-LABEL: test_index_castui_scalable_vector0
+func.func @test_index_castui_scalable_vector0(%arg0 : vector<[8]xi32>) -> vector<[8]xindex> {
+  %0 = arith.index_castui %arg0 : vector<[8]xi32> to vector<[8]xindex>
+  return %0 : vector<[8]xindex>
+}
+
+// CHECK-LABEL: test_indexui_cast1
+func.func @test_indexui_cast1(%arg0 : index) -> i64 {
+  %0 = arith.index_castui %arg0 : index to i64
+  return %0 : i64
+}
+
+// CHECK-LABEL: test_index_castui_tensor1
+func.func @test_index_castui_tensor1(%arg0 : tensor<8x8xindex>) -> tensor<8x8xi64> {
+  %0 = arith.index_castui %arg0 : tensor<8x8xindex> to tensor<8x8xi64>
+  return %0 : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_index_castui_vector1
+func.func @test_index_castui_vector1(%arg0 : vector<8xindex>) -> vector<8xi64> {
+  %0 = arith.index_castui %arg0 : vector<8xindex> to vector<8xi64>
+  return %0 : vector<8xi64>
+}
+
+// CHECK-LABEL: test_index_castui_scalable_vector1
+func.func @test_index_castui_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector<[8]xi64> {
+  %0 = arith.index_castui %arg0 : vector<[8]xindex> to vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_bitcast0
 func.func @test_bitcast0(%arg0 : i64) -> f64 {
   %0 = arith.bitcast %arg0 : i64 to f64