From f8fafe99a4ee2c047acf5a79d1033da8024f1f26 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Mon, 3 Oct 2022 18:46:59 +0000
Subject: [PATCH] [mlir] Add unsigned version of index_cast

This is required to be able to cast an integer type to a potentially
larger index type using a zero-extend cast. There is a larger change
under discussion to move the index ops into a separate dialect:
https://discourse.llvm.org/t/rfc-index-dialect/65540/
Depending on the timing of that work, this patch could be included as
part of that effort, but as a short-term solution we may want to add
this op to the arithmetic dialect for now in order to fill the gap.

Reviewed By: Mogball, stellaraccident

Differential Revision: https://reviews.llvm.org/D135089
---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td     | 19 +++++++++
 mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp    | 37 +++++++++-------
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp  |  1 +
 mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td | 15 +++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp             | 32 +++++++++++++-
 .../Arith/IR/InferIntRangeInterfaceImpls.cpp       | 32 ++++++++++++--
 .../test/Conversion/ArithToLLVM/arith-to-llvm.mlir | 17 ++++++++
 .../Conversion/ArithToSPIRV/arith-to-spirv.mlir    | 30 ++++++++++++-
 mlir/test/Dialect/Arith/canonicalize.mlir          |  9 ++++
 mlir/test/Dialect/Arith/ops.mlir                   | 49 ++++++++++++++++++++++
 10 files changed, 219 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c143f9a..b59d9fd 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1037,6 +1037,25 @@ def Arith_IndexCastOp
 }
 
 //===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_IndexCastUIOp
+  : Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
+                 [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+  let summary = "unsigned cast between index and integer types";
+  let description = [{
+    Casts between scalar or vector integers and corresponding 'index' scalar or
+    vectors. Index is an integer of platform-specific bit width. If casting to
+    a wider integer, the value is zero-extended. If casting to a narrower
+    integer, the value is truncated.
+  }];
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
 // BitcastOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index c77edd8..1610e5c 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -104,14 +104,20 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
 /// becomes an integer. If the bit width of the source and target integer
 /// types is the same, just erase the cast. If the target type is wider,
 /// sign-extend the value, otherwise truncate it.
-struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+template <typename OpTy, typename ExtCastTy>
+struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
+  using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+using IndexCastOpSILowering =
+    IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
+using IndexCastOpUILowering =
+    IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
+
 struct AddUICarryOpLowering
     : public ConvertOpToLLVMPattern<arith::AddUICarryOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -155,14 +161,15 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
 // IndexCastOpLowering
 //===----------------------------------------------------------------------===//
 
-LogicalResult IndexCastOpLowering::matchAndRewrite(
-    arith::IndexCastOp op, OpAdaptor adaptor,
+template <typename OpTy, typename ExtCastTy>
+LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
+    OpTy op, typename OpTy::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Type resultType = op.getResult().getType();
   Type targetElementType =
-      typeConverter->convertType(getElementTypeOrSelf(resultType));
+      this->typeConverter->convertType(getElementTypeOrSelf(resultType));
   Type sourceElementType =
-      typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
+      this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
   unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
   unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
 
@@ -174,13 +181,12 @@ LogicalResult IndexCastOpLowering::matchAndRewrite(
   // Handle the scalar and 1D vector cases.
   Type operandType = adaptor.getIn().getType();
   if (!operandType.isa<LLVM::LLVMArrayType>()) {
-    Type targetType = typeConverter->convertType(resultType);
+    Type targetType = this->typeConverter->convertType(resultType);
     if (targetBits < sourceBits)
       rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
                                                  adaptor.getIn());
     else
-      rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
-                                                adaptor.getIn());
+      rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
     return success();
   }
 
@@ -188,15 +194,15 @@ LogicalResult IndexCastOpLowering::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "expected vector result type");
 
   return LLVM::detail::handleMultidimensionalVectors(
-      op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+      op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
       [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
-        OpAdaptor adaptor(operands);
+        typename OpTy::Adaptor adaptor(operands);
         if (targetBits < sourceBits) {
           return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
                                                 adaptor.getIn());
         }
-        return rewriter.create<LLVM::SExtOp>(op.getLoc(), llvm1DVectorTy,
-                                             adaptor.getIn());
+        return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
+                                          adaptor.getIn());
       },
       rewriter);
 }
@@ -366,7 +372,8 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     ExtUIOpLowering,
     FPToSIOpLowering,
     FPToUIOpLowering,
-    IndexCastOpLowering,
+    IndexCastOpSILowering,
+    IndexCastOpUILowering,
     MaxFOpLowering,
     MaxSIOpLowering,
     MaxUIOpLowering,
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 034abcc..24bead8 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -990,6 +990,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+    TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f57c4d8..2cb5a55 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -152,6 +152,21 @@ def IndexCastOfExtSI :
     Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x)), (Arith_IndexCastOp $x)>;
 
 //===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+// index_castui(index_castui(x)) -> x, if dstType == srcType.
+def IndexCastUIOfIndexCastUI :
+    Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x)),
+        (replaceWithValue $x),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
+
+// index_castui(extui(x)) -> index_castui(x)
+def IndexCastUIOfExtUI :
+    Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x)), (Arith_IndexCastUIOp $x)>;
+
+
+//===----------------------------------------------------------------------===//
 // BitcastOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d62e3f3..190a1ef 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1261,8 +1261,7 @@ OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
-                                           TypeRange outputs) {
+static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (!areValidCastInputsAndOutputs(inputs, outputs))
     return false;
 
@@ -1275,6 +1274,11 @@ bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
          (srcType.isSignlessInteger() && dstType.isIndex());
 }
 
+bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
+                                           TypeRange outputs) {
+  return areIndexCastCompatible(inputs, outputs);
+}
+
 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
   // index_cast(constant) -> constant
   // A little hack because we go through int. Otherwise, the size of the
@@ -1291,6 +1295,30 @@ void arith::IndexCastOp::getCanonicalizationPatterns(
 }
 
 //===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
+                                             TypeRange outputs) {
+  return areIndexCastCompatible(inputs, outputs);
+}
+
+OpFoldResult arith::IndexCastUIOp::fold(ArrayRef<Attribute> operands) {
+  // index_castui(constant) -> constant
+  // A little hack because we go through int. Otherwise, the size of the
+  // constant might need to change.
+  if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
+    return IntegerAttr::get(getType(), value.getUInt());
+
+  return {};
+}
+
+void arith::IndexCastUIOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
+}
+
+//===----------------------------------------------------------------------===//
 // BitcastOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index e59469c..243c3ef 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -466,13 +466,18 @@ void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // ExtUIOp
 //===----------------------------------------------------------------------===//
 
+static ConstantIntRanges extUIRange(const ConstantIntRanges &range,
+                                    Type destType) {
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+  APInt smin = range.umin().zext(destWidth);
+  APInt smax = range.umax().zext(destWidth);
+  return ConstantIntRanges::fromSigned(smin, smax);
+}
+
 void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
   Type destType = getResult().getType();
-  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
-  APInt umin = argRanges[0].umin().zext(destWidth);
-  APInt umax = argRanges[0].umax().zext(destWidth);
-  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+  setResultRange(getResult(), extUIRange(argRanges[0], destType));
 }
 
 //===----------------------------------------------------------------------===//
@@ -560,6 +565,25 @@ void arith::IndexCastOp::inferResultRanges(
 }
 
 //===----------------------------------------------------------------------===//
+// IndexCastUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::IndexCastUIOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  Type sourceType = getOperand().getType();
+  Type destType = getResult().getType();
+  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+
+  if (srcWidth < destWidth)
+    setResultRange(getResult(), extUIRange(argRanges[0], destType));
+  else if (srcWidth > destWidth)
+    setResultRange(getResult(), truncIRange(argRanges[0], destType));
+  else
+    setResultRange(getResult(), argRanges[0]);
+}
+
+//===----------------------------------------------------------------------===//
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index c476d43..05706d8 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -92,6 +92,23 @@ func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
   return
 }
 
+func.func @index_castui(%arg0: index, %arg1: i1) {
+// CHECK: = llvm.trunc %0 : i{{.*}} to i1
+  %0 = arith.index_castui %arg0: index to i1
+// CHECK-NEXT: = llvm.zext %arg1 : i1 to i{{.*}}
+  %1 = arith.index_castui %arg1: i1 to index
+  return
+}
+
+// CHECK-LABEL: @vector_index_castui
+func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
+// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
+  %0 = arith.index_castui %arg0: vector<2xindex> to vector<2xi1>
+// CHECK-NEXT: = llvm.zext %{{.*}} : vector<2xi1> to vector<2xi{{.*}}>
+  %1 = arith.index_castui %arg1: vector<2xi1> to vector<2xindex>
+  return
+}
+
 // Checking conversion of signed integer types to floating point.
 // CHECK-LABEL: @sitofp
 func.func @sitofp(%arg0 : i32, %arg1 : i64) {
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 16a7967..bd5238f 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -690,7 +690,35 @@ func.func @index_cast3(%arg0: i32) {
 
 // CHECK-LABEL: index_cast4
 func.func @index_cast4(%arg0: index) {
-  // CHECK-NOT: spirv.SConvert
+  // CHECK-NOT: spirv.UConvert
   %0 = arith.index_cast %arg0 : index to i32
   return
 }
+
+// CHECK-LABEL: index_castui1
+func.func @index_castui1(%arg0: i16) {
+  // CHECK: spirv.UConvert %{{.+}} : i16 to i32
+  %0 = arith.index_castui %arg0 : i16 to index
+  return
+}
+
+// CHECK-LABEL: index_castui2
+func.func @index_castui2(%arg0: index) {
+  // CHECK: spirv.UConvert %{{.+}} : i32 to i16
+  %0 = arith.index_castui %arg0 : index to i16
+  return
+}
+
+// CHECK-LABEL: index_castui3
+func.func @index_castui3(%arg0: i32) {
+  // CHECK-NOT: spirv.UConvert
+  %0 = arith.index_castui %arg0 : i32 to index
+  return
+}
+
+// CHECK-LABEL: index_castui4
+func.func @index_castui4(%arg0: index) {
+  // CHECK-NOT: spirv.UConvert
+  %0 = arith.index_cast %arg0 : index to i32
+  return
+}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 632e7af..be680ac 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -308,6 +308,15 @@ func.func @indexCastOfSignExtend(%arg0: i8) -> index {
   return %idx : index
 }
 
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend
+// CHECK: %[[res:.+]] = arith.index_castui %arg0 : i8 to index
+// CHECK: return %[[res]]
+func.func @indexCastUIOfUnsignedExtend(%arg0: i8) -> index {
+  %ext = arith.extui %arg0 : i8 to i16
+  %idx = arith.index_castui %ext : i16 to index
+  return %idx : index
+}
+
 // CHECK-LABEL: @signExtendConstant
 // CHECK: %[[cres:.+]] = arith.constant -2 : i16
 // CHECK: return %[[cres]]
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 56e17c7..c34850f 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -793,6 +793,55 @@ func.func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector<[8]xi64> {
   return %0 : vector<[8]xi64>
 }
 
+
+// CHECK-LABEL: test_index_castui0
+func.func @test_index_castui0(%arg0 : i32) -> index {
+  %0 = arith.index_castui %arg0 : i32 to index
+  return %0 : index
+}
+
+// CHECK-LABEL: test_index_castui_tensor0
+func.func @test_index_castui_tensor0(%arg0 : tensor<8x8xi32>) -> tensor<8x8xindex> {
+  %0 = arith.index_castui %arg0 : tensor<8x8xi32> to tensor<8x8xindex>
+  return %0 : tensor<8x8xindex>
+}
+
+// CHECK-LABEL: test_index_castui_vector0
+func.func @test_index_castui_vector0(%arg0 : vector<8xi32>) -> vector<8xindex> {
+  %0 = arith.index_castui %arg0 : vector<8xi32> to vector<8xindex>
+  return %0 : vector<8xindex>
+}
+
+// CHECK-LABEL: test_index_castui_scalable_vector0
+func.func @test_index_castui_scalable_vector0(%arg0 : vector<[8]xi32>) -> vector<[8]xindex> {
+  %0 = arith.index_castui %arg0 : vector<[8]xi32> to vector<[8]xindex>
+  return %0 : vector<[8]xindex>
+}
+
+// CHECK-LABEL: test_indexui_cast1
+func.func @test_indexui_cast1(%arg0 : index) -> i64 {
+  %0 = arith.index_castui %arg0 : index to i64
+  return %0 : i64
+}
+
+// CHECK-LABEL: test_index_castui_tensor1
+func.func @test_index_castui_tensor1(%arg0 : tensor<8x8xindex>) -> tensor<8x8xi64> {
+  %0 = arith.index_castui %arg0 : tensor<8x8xindex> to tensor<8x8xi64>
+  return %0 : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_index_castui_vector1
+func.func @test_index_castui_vector1(%arg0 : vector<8xindex>) -> vector<8xi64> {
+  %0 = arith.index_castui %arg0 : vector<8xindex> to vector<8xi64>
+  return %0 : vector<8xi64>
+}
+
+// CHECK-LABEL: test_index_castui_scalable_vector1
+func.func @test_index_castui_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector<[8]xi64> {
+  %0 = arith.index_castui %arg0 : vector<[8]xindex> to vector<[8]xi64>
+  return %0 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_bitcast0
 func.func @test_bitcast0(%arg0 : i64) -> f64 {
   %0 = arith.bitcast %arg0 : i64 to f64
-- 
2.7.4
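For reference, a minimal usage sketch of the new op. This is illustrative
only: the function names below are hypothetical (not part of the patch), and
the lowering comments assume a target where 'index' converts to i64.

  func.func @widen_to_index(%arg0: i32) -> index {
    // Widening uses zero-extension; under the ArithToLLVM patterns above this
    // lowers to llvm.zext ... : i32 to i64 when index maps to i64, whereas
    // arith.index_cast would lower to llvm.sext for the same widening.
    %0 = arith.index_castui %arg0 : i32 to index
    return %0 : index
  }

  func.func @narrow_from_index(%arg0: index) -> i16 {
    // Narrowing truncates, exactly as arith.index_cast does; only the
    // widening direction differs between the signed and unsigned casts.
    %0 = arith.index_castui %arg0 : index to i16
    return %0 : i16
  }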