Summary: Add and pipe through the sqrt operation for Standard and LLVM dialects.
Reviewers: nicolasvasilache, ftynse
Reviewed By: ftynse
Subscribers: frej, ftynse, merge_guards_bot, flaub, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D73571
scalar type, a vector whose element type is float, or a tensor of floats. It
has no standard attributes.
+### 'sqrt' operation
+
+Syntax:
+
+```
+operation ::= ssa-id `=` `sqrt` ssa-use `:` type
+```
+
+Examples:
+
+```mlir
+// Scalar square root value.
+%a = sqrt %b : f64
+// SIMD vector element-wise square root value.
+%f = sqrt %g : vector<4xf32>
+// Tensor element-wise square root value.
+%x = sqrt %y : tensor<4x?xf32>
+```
+
### 'tanh' operation
Syntax:
def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrinsicOp<"fmuladd">;
+def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
def LLVM_LogOp : LLVM_Op<"intr.log", [NoSideEffect]>,
Arguments<(ins LLVM_Type:$in)>,
let hasCanonicalizer = 1;
}
+def SqrtOp : FloatUnaryOp<"sqrt"> { // Unary float op with mnemonic `sqrt`; one operand, one result of the same type.
+ let summary = "sqrt of the specified value";
+ let description = [{
+ The `sqrt` operation computes the square root. It takes one operand and
+ returns one result of the same type. This type may be a float scalar type, a
+ vector whose element type is float, or a tensor of floats. It has no standard
+ attributes.
+ }];
+} // end SqrtOp
+
def TanhOp : FloatUnaryOp<"tanh"> {
let summary = "hyperbolic tangent of the specified value";
let description = [{
: public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> {
using Super::Super;
};
+struct SqrtOpLowering : public UnaryOpLLVMOpLowering<SqrtOp, LLVM::SqrtOp> {
+ using Super::Super; // Reuse the constructors exposed through the base's `Super` alias.
+}; // Conversion pattern pairing the standard SqrtOp with LLVM::SqrtOp (the sqrt intrinsic op).
struct UnsignedDivIOpLowering
: public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> {
using Super::Super;
SignedShiftRightOpLowering,
SplatOpLowering,
SplatNdOpLowering,
+ SqrtOpLowering,
SubFOpLowering,
SubIOpLowering,
TanhOpLowering,
}
// CHECK-LABEL: @ops
-func @ops(f32, f32, i32, i32) -> (f32, i32) {
-^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32):
+func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
+^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
// CHECK-NEXT: %0 = llvm.fsub %arg0, %arg1 : !llvm.float
%0 = subf %arg0, %arg1: f32
// CHECK-NEXT: %1 = llvm.sub %arg2, %arg3 : !llvm.i32
%19 = shift_right_signed %arg2, %arg3 : i32
// CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32
%20 = shift_right_unsigned %arg2, %arg3 : i32
-
+// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
+ %21 = std.sqrt %arg0 : f32
+// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double
+ %22 = std.sqrt %arg4 : f64
return %0, %4 : f32, i32
}
// CHECK: %{{[0-9]+}} = shift_right_unsigned %cst_4, %cst_4 : tensor<42xi32>
%138 = shift_right_unsigned %tci32, %tci32 : tensor<42 x i32>
+ // CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
+ %139 = "std.sqrt"(%f) : (f32) -> f32
+
+ // CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
+ %140 = sqrt %f : f32
+
+ // CHECK: %{{[0-9]+}} = sqrt %cst_8 : vector<4xf32>
+ %141 = sqrt %vcf32 : vector<4xf32>
+
+ // CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32>
+ %142 = sqrt %t : tensor<4x4x?xf32>
+
return
}
llvm.return
}
+// CHECK-LABEL: @sqrt_test
+llvm.func @sqrt_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) { // One scalar float and one 8-element float vector operand.
+ // CHECK: call float @llvm.sqrt.f32
+ "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float // Scalar form of the intrinsic op.
+ // CHECK: call <8 x float> @llvm.sqrt.v8f32
+ "llvm.intr.sqrt"(%arg1) : (!llvm<"<8 x float>">) -> !llvm<"<8 x float>"> // Vector form of the intrinsic op.
+ llvm.return
+}
+
// CHECK-LABEL: @ceil_test
llvm.func @ceil_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
// CHECK: call float @llvm.ceil.f32
// CHECK: declare <8 x float> @llvm.log2.v8f32(<8 x float>) #0
// CHECK: declare float @llvm.fabs.f32(float)
// CHECK: declare <8 x float> @llvm.fabs.v8f32(<8 x float>) #0
+// CHECK: declare float @llvm.sqrt.f32(float)
+// CHECK: declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0
// CHECK: declare float @llvm.ceil.f32(float)
// CHECK: declare <8 x float> @llvm.ceil.v8f32(<8 x float>) #0
// CHECK: declare float @llvm.cos.f32(float)