[mlir][arith] Fold on extension of FP constants using arith.extf
authorVictor Perez <victor.perez@codeplay.com>
Fri, 17 Feb 2023 10:16:42 +0000 (10:16 +0000)
committerVictor Perez <victor.perez@codeplay.com>
Thu, 23 Feb 2023 12:38:55 +0000 (12:38 +0000)
It is safe to fold when extending, as we will not lose precision.

Differential Revision: https://reviews.llvm.org/D144251

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir

index fb790c1..35f4d07 100644 (file)
@@ -1047,6 +1047,7 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
     When operating on vectors, casts elementwise.
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
index d5fb08f..f6308a6 100644 (file)
@@ -1224,6 +1224,16 @@ LogicalResult arith::ExtSIOp::verify() {
 // ExtFOp
 //===----------------------------------------------------------------------===//
 
+/// Always fold extension of FP constants.
+OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
+  auto constOperand = adaptor.getIn().dyn_cast_or_null<FloatAttr>();
+  if (!constOperand)
+    return {};
+
+  // Convert to target type via 'double'.
+  return FloatAttr::get(getType(), constOperand.getValue().convertToDouble());
+}
+
 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
 }
index 355e7a8..eaafa9e 100644 (file)
@@ -502,6 +502,15 @@ func.func @unsignedExtendConstantVector() -> vector<4xi16> {
   return %ext : vector<4xi16>
 }
 
+// CHECK-LABEL: @extFPConstant
+//       CHECK:   %[[cres:.+]] = arith.constant 1.000000e+00 : f64
+//       CHECK:   return %[[cres]]
+func.func @extFPConstant() -> f64 {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = arith.extf %cst : f32 to f64
+  return %0 : f64
+}
+
 // CHECK-LABEL: @truncConstant
 //       CHECK:   %[[cres:.+]] = arith.constant -2 : i16
 //       CHECK:   return %[[cres]]