[mlir][arith] Fold `index_cast[ui]` of vectors
authorJakub Kuderski <kubak@google.com>
Wed, 29 Mar 2023 16:49:32 +0000 (12:49 -0400)
committerJakub Kuderski <kubak@google.com>
Wed, 29 Mar 2023 16:51:51 +0000 (12:51 -0400)
Handle the splat and dense case.

I saw this pattern show up in a couple recent SPIR-V-specific bug
report.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D147109

mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir

index e56f452..d7ce71a 100644 (file)
@@ -1455,12 +1455,15 @@ bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
 
 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
   // index_cast(constant) -> constant
-  // A little hack because we go through int. Otherwise, the size of the
-  // constant might need to change.
-  if (auto value = adaptor.getIn().dyn_cast_or_null<IntegerAttr>())
-    return IntegerAttr::get(getType(), value.getInt());
+  unsigned resultBitwidth = 64; // Default for index integer attributes.
+  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
+    resultBitwidth = intTy.getWidth();
 
-  return {};
+  return constFoldCastOp<IntegerAttr, IntegerAttr>(
+      adaptor.getOperands(), getType(),
+      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
+        return a.sextOrTrunc(resultBitwidth);
+      });
 }
 
 void arith::IndexCastOp::getCanonicalizationPatterns(
@@ -1479,12 +1482,15 @@ bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
 
 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
   // index_castui(constant) -> constant
-  // A little hack because we go through int. Otherwise, the size of the
-  // constant might need to change.
-  if (auto value = adaptor.getIn().dyn_cast_or_null<IntegerAttr>())
-    return IntegerAttr::get(getType(), value.getValue().getZExtValue());
+  unsigned resultBitwidth = 64; // Default for index integer attributes.
+  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
+    resultBitwidth = intTy.getWidth();
 
-  return {};
+  return constFoldCastOp<IntegerAttr, IntegerAttr>(
+      adaptor.getOperands(), getType(),
+      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
+        return a.zextOrTrunc(resultBitwidth);
+      });
 }
 
 void arith::IndexCastUIOp::getCanonicalizationPatterns(
index b75dfd7..0170620 100644 (file)
@@ -446,6 +446,42 @@ func.func @indexCastFoldIndexToInt() -> i32 {
   return %int : i32
 }
 
+// CHECK-LABEL: @indexCastFoldSplatVector
+//       CHECK:   %[[res:.*]] = arith.constant dense<42> : vector<3xindex>
+//       CHECK:   return %[[res]] : vector<3xindex>
+func.func @indexCastFoldSplatVector() -> vector<3xindex> {
+  %cst = arith.constant dense<42> : vector<3xi32>
+  %int = arith.index_cast %cst : vector<3xi32> to vector<3xindex>
+  return %int : vector<3xindex>
+}
+
+// CHECK-LABEL: @indexCastFoldVector
+//       CHECK:   %[[res:.*]] = arith.constant dense<[1, 2, 3]> : vector<3xindex>
+//       CHECK:   return %[[res]] : vector<3xindex>
+func.func @indexCastFoldVector() -> vector<3xindex> {
+  %cst = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+  %int = arith.index_cast %cst : vector<3xi32> to vector<3xindex>
+  return %int : vector<3xindex>
+}
+
+// CHECK-LABEL: @indexCastFoldSplatVectorIndexToInt
+//       CHECK:   %[[res:.*]] = arith.constant dense<42> : vector<3xi32>
+//       CHECK:   return %[[res]] : vector<3xi32>
+func.func @indexCastFoldSplatVectorIndexToInt() -> vector<3xi32> {
+  %cst = arith.constant dense<42> : vector<3xindex>
+  %int = arith.index_cast %cst : vector<3xindex> to vector<3xi32>
+  return %int : vector<3xi32>
+}
+
+// CHECK-LABEL: @indexCastFoldVectorIndexToInt
+//       CHECK:   %[[res:.*]] = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+//       CHECK:   return %[[res]] : vector<3xi32>
+func.func @indexCastFoldVectorIndexToInt() -> vector<3xi32> {
+  %cst = arith.constant dense<[1, 2, 3]> : vector<3xindex>
+  %int = arith.index_cast %cst : vector<3xindex> to vector<3xi32>
+  return %int : vector<3xi32>
+}
+
 // CHECK-LABEL: @indexCastUIFold
 //       CHECK:   %[[res:.*]] = arith.constant 254 : index
 //       CHECK:   return %[[res]]
@@ -455,6 +491,24 @@ func.func @indexCastUIFold() -> index {
   return %idx : index
 }
 
+// CHECK-LABEL: @indexCastUIFoldSplatVector
+//       CHECK:   %[[res:.*]] = arith.constant dense<42> : vector<3xindex>
+//       CHECK:   return %[[res]] : vector<3xindex>
+func.func @indexCastUIFoldSplatVector() -> vector<3xindex> {
+  %cst = arith.constant dense<42> : vector<3xi32>
+  %int = arith.index_castui %cst : vector<3xi32> to vector<3xindex>
+  return %int : vector<3xindex>
+}
+
+// CHECK-LABEL: @indexCastUIFoldVector
+//       CHECK:   %[[res:.*]] = arith.constant dense<[1, 2, 3]> : vector<3xindex>
+//       CHECK:   return %[[res]] : vector<3xindex>
+func.func @indexCastUIFoldVector() -> vector<3xindex> {
+  %cst = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+  %int = arith.index_castui %cst : vector<3xi32> to vector<3xindex>
+  return %int : vector<3xindex>
+}
+
 // CHECK-LABEL: @indexCastUIFoldIndexToInt
 //       CHECK:   %[[res:.*]] = arith.constant 1 : i32
 //       CHECK:   return %[[res]]
@@ -464,6 +518,24 @@ func.func @indexCastUIFoldIndexToInt() -> i32 {
   return %int : i32
 }
 
+// CHECK-LABEL: @indexCastUIFoldSplatVectorIndexToInt
+//       CHECK:   %[[res:.*]] = arith.constant dense<42> : vector<3xi32>
+//       CHECK:   return %[[res]] : vector<3xi32>
+func.func @indexCastUIFoldSplatVectorIndexToInt() -> vector<3xi32> {
+  %cst = arith.constant dense<42> : vector<3xindex>
+  %int = arith.index_castui %cst : vector<3xindex> to vector<3xi32>
+  return %int : vector<3xi32>
+}
+
+// CHECK-LABEL: @indexCastUIFoldVectorIndexToInt
+//       CHECK:   %[[res:.*]] = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+//       CHECK:   return %[[res]] : vector<3xi32>
+func.func @indexCastUIFoldVectorIndexToInt() -> vector<3xi32> {
+  %cst = arith.constant dense<[1, 2, 3]> : vector<3xindex>
+  %int = arith.index_castui %cst : vector<3xindex> to vector<3xi32>
+  return %int : vector<3xi32>
+}
+
 // CHECK-LABEL: @signExtendConstant
 //       CHECK:   %[[cres:.+]] = arith.constant -2 : i16
 //       CHECK:   return %[[cres]]