From ce6a03ce0b82dbebcf5c752bade0955dc3abb2ba Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 29 Mar 2023 12:49:32 -0400 Subject: [PATCH] [mlir][arith] Fold `index_cast[ui]` of vectors Handle the splat and dense case. I saw this pattern show up in a couple recent SPIR-V-specific bug report. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D147109 --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 26 ++++++----- mlir/test/Dialect/Arith/canonicalize.mlir | 72 +++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index e56f452..d7ce71a 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1455,12 +1455,15 @@ bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) { // index_cast(constant) -> constant - // A little hack because we go through int. Otherwise, the size of the - // constant might need to change. - if (auto value = adaptor.getIn().dyn_cast_or_null()) - return IntegerAttr::get(getType(), value.getInt()); + unsigned resultBitwidth = 64; // Default for index integer attributes. + if (auto intTy = dyn_cast(getElementTypeOrSelf(getType()))) + resultBitwidth = intTy.getWidth(); - return {}; + return constFoldCastOp( + adaptor.getOperands(), getType(), + [resultBitwidth](const APInt &a, bool & /*castStatus*/) { + return a.sextOrTrunc(resultBitwidth); + }); } void arith::IndexCastOp::getCanonicalizationPatterns( @@ -1479,12 +1482,15 @@ bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs, OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) { // index_castui(constant) -> constant - // A little hack because we go through int. Otherwise, the size of the - // constant might need to change. - if (auto value = adaptor.getIn().dyn_cast_or_null()) - return IntegerAttr::get(getType(), value.getValue().getZExtValue()); + unsigned resultBitwidth = 64; // Default for index integer attributes. + if (auto intTy = dyn_cast(getElementTypeOrSelf(getType()))) + resultBitwidth = intTy.getWidth(); - return {}; + return constFoldCastOp( + adaptor.getOperands(), getType(), + [resultBitwidth](const APInt &a, bool & /*castStatus*/) { + return a.zextOrTrunc(resultBitwidth); + }); } void arith::IndexCastUIOp::getCanonicalizationPatterns( diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index b75dfd7..0170620 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -446,6 +446,42 @@ func.func @indexCastFoldIndexToInt() -> i32 { return %int : i32 } +// CHECK-LABEL: @indexCastFoldSplatVector +// CHECK: %[[res:.*]] = arith.constant dense<42> : vector<3xindex> +// CHECK: return %[[res]] : vector<3xindex> +func.func @indexCastFoldSplatVector() -> vector<3xindex> { + %cst = arith.constant dense<42> : vector<3xi32> + %int = arith.index_cast %cst : vector<3xi32> to vector<3xindex> + return %int : vector<3xindex> +} + +// CHECK-LABEL: @indexCastFoldVector +// CHECK: %[[res:.*]] = arith.constant dense<[1, 2, 3]> : vector<3xindex> +// CHECK: return %[[res]] : vector<3xindex> +func.func @indexCastFoldVector() -> vector<3xindex> { + %cst = arith.constant dense<[1, 2, 3]> : vector<3xi32> + %int = arith.index_cast %cst : vector<3xi32> to vector<3xindex> + return %int : vector<3xindex> +} + +// CHECK-LABEL: @indexCastFoldSplatVectorIndexToInt +// CHECK: %[[res:.*]] = arith.constant dense<42> : vector<3xi32> +// CHECK: return %[[res]] : vector<3xi32> +func.func @indexCastFoldSplatVectorIndexToInt() -> vector<3xi32> { + %cst = arith.constant dense<42> : vector<3xindex> + %int = arith.index_cast %cst : vector<3xindex> to vector<3xi32> + return %int : vector<3xi32> +} + +// CHECK-LABEL: @indexCastFoldVectorIndexToInt +// CHECK: %[[res:.*]] = arith.constant dense<[1, 2, 3]> : vector<3xi32> +// CHECK: return %[[res]] : vector<3xi32> +func.func @indexCastFoldVectorIndexToInt() -> vector<3xi32> { + %cst = arith.constant dense<[1, 2, 3]> : vector<3xindex> + %int = arith.index_cast %cst : vector<3xindex> to vector<3xi32> + return %int : vector<3xi32> +} + // CHECK-LABEL: @indexCastUIFold // CHECK: %[[res:.*]] = arith.constant 254 : index // CHECK: return %[[res]] @@ -455,6 +491,24 @@ func.func @indexCastUIFold() -> index { return %idx : index } +// CHECK-LABEL: @indexCastUIFoldSplatVector +// CHECK: %[[res:.*]] = arith.constant dense<42> : vector<3xindex> +// CHECK: return %[[res]] : vector<3xindex> +func.func @indexCastUIFoldSplatVector() -> vector<3xindex> { + %cst = arith.constant dense<42> : vector<3xi32> + %int = arith.index_castui %cst : vector<3xi32> to vector<3xindex> + return %int : vector<3xindex> +} + +// CHECK-LABEL: @indexCastUIFoldVector +// CHECK: %[[res:.*]] = arith.constant dense<[1, 2, 3]> : vector<3xindex> +// CHECK: return %[[res]] : vector<3xindex> +func.func @indexCastUIFoldVector() -> vector<3xindex> { + %cst = arith.constant dense<[1, 2, 3]> : vector<3xi32> + %int = arith.index_castui %cst : vector<3xi32> to vector<3xindex> + return %int : vector<3xindex> +} + // CHECK-LABEL: @indexCastUIFoldIndexToInt // CHECK: %[[res:.*]] = arith.constant 1 : i32 // CHECK: return %[[res]] @@ -464,6 +518,24 @@ func.func @indexCastUIFoldIndexToInt() -> i32 { return %int : i32 } +// CHECK-LABEL: @indexCastUIFoldSplatVectorIndexToInt +// CHECK: %[[res:.*]] = arith.constant dense<42> : vector<3xi32> +// CHECK: return %[[res]] : vector<3xi32> +func.func @indexCastUIFoldSplatVectorIndexToInt() -> vector<3xi32> { + %cst = arith.constant dense<42> : vector<3xindex> + %int = arith.index_castui %cst : vector<3xindex> to vector<3xi32> + return %int : vector<3xi32> +} + +// CHECK-LABEL: @indexCastUIFoldVectorIndexToInt +// CHECK: %[[res:.*]] = arith.constant dense<[1, 2, 3]> : vector<3xi32> +// CHECK: return %[[res]] : vector<3xi32> +func.func @indexCastUIFoldVectorIndexToInt() -> vector<3xi32> { + %cst = arith.constant dense<[1, 2, 3]> : vector<3xindex> + %int = arith.index_castui %cst : vector<3xindex> to vector<3xi32> + return %int : vector<3xi32> +} + // CHECK-LABEL: @signExtendConstant // CHECK: %[[cres:.+]] = arith.constant -2 : i16 // CHECK: return %[[cres]] -- 2.7.4