Fix math.cbrt with vector and f16 arguments.
authorJohannes Reifferscheid <jreiffers@google.com>
Tue, 10 Jan 2023 20:15:49 +0000 (21:15 +0100)
committerJohannes Reifferscheid <jreiffers@google.com>
Tue, 10 Jan 2023 20:32:12 +0000 (21:32 +0100)
Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D141421

mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/test/Conversion/MathToLibm/convert-to-libm.mlir

index 8a8adb5..c48686e 100644 (file)
@@ -154,19 +154,20 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
 void mlir::populateMathToLibmConversionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit,
     llvm::Optional<PatternBenefit> log1pBenefit) {
-  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
-               VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
-               VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
-               VecOpToScalarOp<math::RoundEvenOp>,
+  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::CbrtOp>,
+               VecOpToScalarOp<math::ExpM1Op>, VecOpToScalarOp<math::TanhOp>,
+               VecOpToScalarOp<math::CosOp>, VecOpToScalarOp<math::SinOp>,
+               VecOpToScalarOp<math::ErfOp>, VecOpToScalarOp<math::RoundEvenOp>,
                VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
                VecOpToScalarOp<math::TanOp>, VecOpToScalarOp<math::TruncOp>>(
       patterns.getContext(), benefit);
-  patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
-               PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
-               PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
-               PromoteOpToF32<math::RoundEvenOp>, PromoteOpToF32<math::RoundOp>,
-               PromoteOpToF32<math::AtanOp>, PromoteOpToF32<math::TanOp>,
-               PromoteOpToF32<math::TruncOp>>(patterns.getContext(), benefit);
+  patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::CbrtOp>,
+               PromoteOpToF32<math::ExpM1Op>, PromoteOpToF32<math::TanhOp>,
+               PromoteOpToF32<math::CosOp>, PromoteOpToF32<math::SinOp>,
+               PromoteOpToF32<math::ErfOp>, PromoteOpToF32<math::RoundEvenOp>,
+               PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
+               PromoteOpToF32<math::TanOp>, PromoteOpToF32<math::TruncOp>>(
+      patterns.getContext(), benefit);
   patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
                                                  "atan", benefit);
   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
index b0459d8..eb37505 100644 (file)
@@ -246,13 +246,22 @@ func.func @trunc_caller(%float: f32, %double: f64) -> (f32, f64) {
 // CHECK-LABEL: func @cbrt_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64
-func.func @cbrt_caller(%float: f32, %double: f64) -> (f32, f64)  {
-  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32
+func.func @cbrt_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16,
+                       %float_vec: vector<2xf32>) -> (f32, f64, f16, bf16, vector<2xf32>)  {
+  // CHECK: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32
   %float_result = math.cbrt %float : f32
-  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64
+  // CHECK: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64
   %double_result = math.cbrt %double : f64
+  // Just check that these lower successfully:
+  // CHECK: call @cbrtf
+  %half_result = math.cbrt %half : f16
+  // CHECK: call @cbrtf
+  %bfloat_result = math.cbrt %bfloat : bf16
+  // CHECK: call @cbrtf
+  %vec_result = math.cbrt %float_vec : vector<2xf32>
   // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
-  return %float_result, %double_result : f32, f64
+  return %float_result, %double_result, %half_result, %bfloat_result, %vec_result
+    : f32, f64, f16, bf16, vector<2xf32>
 }
 
 // CHECK-LABEL: func @cos_caller