[flang] Handle special case for SHIFTA intrinsic
authorValentin Clement <clementval@gmail.com>
Thu, 1 Sep 2022 14:27:51 +0000 (16:27 +0200)
committerValentin Clement <clementval@gmail.com>
Thu, 1 Sep 2022 14:28:08 +0000 (16:28 +0200)
This patch update the lowering of the shifta intrinsic to match
the behvior of gfortran. When the SHIFT value is equal to the
integer bitwidth then we handle it differently.
This is due to the operation used in lowering (`mlir::arith::ShRSIOp`)
that lowers to `ashr`.

Before this patch we have the following results:

```
SHIFTA(  -1, 8) =  0
SHIFTA(  -2, 8) =  0
SHIFTA( -30, 8) =  0
SHIFTA( -31, 8) =  0
SHIFTA( -32, 8) =  0
SHIFTA( -33, 8) =  0
SHIFTA(-126, 8) =  0
SHIFTA(-127, 8) =  0
SHIFTA(-128, 8) =  0
```

While gfortran is giving this:

```
SHIFTA(  -1, 8) = -1
SHIFTA(  -2, 8) = -1
SHIFTA( -30, 8) = -1
SHIFTA( -31, 8) = -1
SHIFTA( -32, 8) = -1
SHIFTA( -33, 8) = -1
SHIFTA(-126, 8) = -1
SHIFTA(-127, 8) = -1
SHIFTA(-128, 8) = -1
```

With this patch flang and gfortran have the same behavior.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D133104

flang/lib/Lower/IntrinsicCall.cpp
flang/test/Lower/Intrinsics/shifta.f90

index c5a1121..545c2d9 100644 (file)
@@ -557,6 +557,7 @@ struct IntrinsicLibrary {
                              llvm::ArrayRef<mlir::Value> args);
   template <typename Shift>
   mlir::Value genShift(mlir::Type resultType, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genShiftA(mlir::Type resultType, llvm::ArrayRef<mlir::Value>);
   mlir::Value genSign(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genSize(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genSpacing(mlir::Type resultType,
@@ -958,7 +959,7 @@ static constexpr IntrinsicHandler handlers[]{
        {"radix", asAddr, handleDynamicOptional}}},
      /*isElemental=*/false},
     {"set_exponent", &I::genSetExponent},
-    {"shifta", &I::genShift<mlir::arith::ShRSIOp>},
+    {"shifta", &I::genShiftA},
     {"shiftl", &I::genShift<mlir::arith::ShLIOp>},
     {"shiftr", &I::genShift<mlir::arith::ShRUIOp>},
     {"sign", &I::genSign},
@@ -4015,7 +4016,7 @@ mlir::Value IntrinsicLibrary::genSetExponent(mlir::Type resultType,
                                    fir::getBase(args[1])));
 }
 
-// SHIFTA, SHIFTL, SHIFTR
+// SHIFTL, SHIFTR
 template <typename Shift>
 mlir::Value IntrinsicLibrary::genShift(mlir::Type resultType,
                                        llvm::ArrayRef<mlir::Value> args) {
@@ -4041,6 +4042,31 @@ mlir::Value IntrinsicLibrary::genShift(mlir::Type resultType,
   return builder.create<mlir::arith::SelectOp>(loc, outOfBounds, zero, shifted);
 }
 
+// SHIFTA
+mlir::Value IntrinsicLibrary::genShiftA(mlir::Type resultType,
+                                        llvm::ArrayRef<mlir::Value> args) {
+  unsigned bits = resultType.getIntOrFloatBitWidth();
+  mlir::Value bitSize = builder.createIntegerConstant(loc, resultType, bits);
+  mlir::Value shift = builder.createConvert(loc, resultType, args[1]);
+  mlir::Value shiftEqBitSize = builder.create<mlir::arith::CmpIOp>(
+      loc, mlir::arith::CmpIPredicate::eq, shift, bitSize);
+
+  // Lowering of mlir::arith::ShRSIOp is using `ashr`. `ashr` is undefined when
+  // the shift amount is equal to the element size.
+  // So if SHIFT is equal to the bit width then it is handled as a special case.
+  mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
+  mlir::Value minusOne = builder.createIntegerConstant(loc, resultType, -1);
+  mlir::Value valueIsNeg = builder.create<mlir::arith::CmpIOp>(
+      loc, mlir::arith::CmpIPredicate::slt, args[0], zero);
+  mlir::Value specialRes =
+      builder.create<mlir::arith::SelectOp>(loc, valueIsNeg, minusOne, zero);
+
+  mlir::Value shifted =
+      builder.create<mlir::arith::ShRSIOp>(loc, args[0], shift);
+  return builder.create<mlir::arith::SelectOp>(loc, shiftEqBitSize, specialRes,
+                                               shifted);
+}
+
 // SIGN
 mlir::Value IntrinsicLibrary::genSign(mlir::Type resultType,
                                       llvm::ArrayRef<mlir::Value> args) {
index 311206b..92fda39 100644 (file)
@@ -12,13 +12,14 @@ subroutine shifta1_test(a, b, c)
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 8 : i8
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i8
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i8
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i8
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i8
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i8
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i8
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i8
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i8
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i8
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i8
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i8
 end subroutine shifta1_test
 
 ! CHECK-LABEL: shifta2_test
@@ -32,13 +33,14 @@ subroutine shifta2_test(a, b, c)
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 16 : i16
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i16
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i16
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i16
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i16
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i16
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i16
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i16
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i16
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i16
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i16
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i16
 end subroutine shifta2_test
 
 ! CHECK-LABEL: shifta4_test
@@ -52,12 +54,13 @@ subroutine shifta4_test(a, b, c)
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 32 : i32
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i32
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_VAL]], %[[C_0]] : i32
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_VAL]], %[[C_BITS]] : i32
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i32
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_VAL]], %[[C_BITS]] : i32
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i32
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i32
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i32
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i32
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i32
 end subroutine shifta4_test
 
 ! CHECK-LABEL: shifta8_test
@@ -71,13 +74,14 @@ subroutine shifta8_test(a, b, c)
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 64 : i64
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i64
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i64
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i64
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i64
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i64
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i64
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i64
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i64
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i64
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i64
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i64
 end subroutine shifta8_test
 
 ! CHECK-LABEL: shifta16_test
@@ -91,11 +95,12 @@ subroutine shifta16_test(a, b, c)
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 128 : i128
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i128
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i128
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i128
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i128
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i128
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i128
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i128
+  ! CHECK: %[[CM1:.*]] = arith.constant {{.*}} : i128
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i128
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i128
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i128
 end subroutine shifta16_test