[MLIR] Add the sqrt operation to mlir.
authorLubomir Litchev <Lubomir.Litchev@intel.com>
Thu, 30 Jan 2020 15:44:44 +0000 (07:44 -0800)
committerFrank Laub <frank.laub@intel.com>
Thu, 30 Jan 2020 16:07:38 +0000 (08:07 -0800)
Summary: Add and pipe through the sqrt operation for Standard and LLVM dialects.

Reviewers: nicolasvasilache, ftynse

Reviewed By: ftynse

Subscribers: frej, ftynse, merge_guards_bot, flaub, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73571

mlir/docs/Dialects/Standard.md
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/StandardOps/Ops.td
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
mlir/test/IR/core-ops.mlir
mlir/test/Target/llvmir-intrinsics.mlir

index 0d30296..e956b64 100644 (file)
@@ -587,6 +587,25 @@ operand and returns one result of the same type. This type may be a float
 scalar type, a vector whose element type is float, or a tensor of floats. It
 has no standard attributes.
 
+### 'sqrt' operation
+
+Syntax:
+
+```
+operation ::= ssa-id `=` `sqrt` ssa-use `:` type
+```
+
+Examples:
+
+```mlir
+// Scalar square root value.
+%a = sqrt %b : f64
+// SIMD vector element-wise square root value.
+%f = sqrt %g : vector<4xf32>
+// Tensor element-wise square root value.
+%x = sqrt %y : tensor<4x?xf32>
+```
+
 ### 'tanh' operation
 
 Syntax:
index 42635e4..b1abf61 100644 (file)
@@ -716,6 +716,7 @@ def LLVM_FCeilOp : LLVM_UnaryIntrinsicOp<"ceil">;
 def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
 def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
 def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrinsicOp<"fmuladd">;
+def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
 
 def LLVM_LogOp : LLVM_Op<"intr.log", [NoSideEffect]>,
                    Arguments<(ins LLVM_Type:$in)>,
index de205f5..6b49eed 100644 (file)
@@ -1402,6 +1402,16 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
   let hasCanonicalizer = 1;
 }
 
+def SqrtOp : FloatUnaryOp<"sqrt"> {
+  let summary = "sqrt of the specified value";
+  let description = [{
+    The `sqrt` operation computes the square root. It takes one operand and
+    returns one result of the same type. This type may be a float scalar type, a
+    vector whose element type is float, or a tensor of floats. It has no standard
+    attributes.
+  }];
+}
+
 def TanhOp : FloatUnaryOp<"tanh"> {
   let summary = "hyperbolic tangent of the specified value";
   let description = [{
index 1a2d3b0..fa4d33c 100644 (file)
@@ -807,6 +807,9 @@ struct SignedDivIOpLowering
     : public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> {
   using Super::Super;
 };
+struct SqrtOpLowering : public UnaryOpLLVMOpLowering<SqrtOp, LLVM::SqrtOp> {
+  using Super::Super;
+};
 struct UnsignedDivIOpLowering
     : public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> {
   using Super::Super;
@@ -2108,6 +2111,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       SignedShiftRightOpLowering,
       SplatOpLowering,
       SplatNdOpLowering,
+      SqrtOpLowering,
       SubFOpLowering,
       SubIOpLowering,
       TanhOpLowering,
index 4514723..66a99ae 100644 (file)
@@ -398,8 +398,8 @@ func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>
 }
 
 // CHECK-LABEL: @ops
-func @ops(f32, f32, i32, i32) -> (f32, i32) {
-^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32):
+func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
+^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
 // CHECK-NEXT:  %0 = llvm.fsub %arg0, %arg1 : !llvm.float
   %0 = subf %arg0, %arg1: f32
 // CHECK-NEXT:  %1 = llvm.sub %arg2, %arg3 : !llvm.i32
@@ -440,7 +440,10 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) {
   %19 = shift_right_signed %arg2, %arg3 : i32
 // CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32
   %20 = shift_right_unsigned %arg2, %arg3 : i32
-
+// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
+  %21 = std.sqrt %arg0 : f32
+// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double
+  %22 = std.sqrt %arg4 : f64
   return %0, %4 : f32, i32
 }
 
index 3590a28..2318ef5 100644 (file)
@@ -494,6 +494,18 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
   // CHECK: %{{[0-9]+}} = shift_right_unsigned %cst_4, %cst_4 : tensor<42xi32>
   %138 = shift_right_unsigned %tci32, %tci32 : tensor<42 x i32>
 
+  // CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
+  %139 = "std.sqrt"(%f) : (f32) -> f32
+
+  // CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
+  %140 = sqrt %f : f32
+
+  // CHECK: %{{[0-9]+}} = sqrt %cst_8 : vector<4xf32>
+  %141 = sqrt %vcf32 : vector<4xf32>
+
+  // CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32>
+  %142 = sqrt %t : tensor<4x4x?xf32>
+
   return
 }
 
index 9a5e432..343b1f0 100644 (file)
@@ -59,6 +59,15 @@ llvm.func @fabs_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
   llvm.return
 }
 
+// CHECK-LABEL: @sqrt_test
+llvm.func @sqrt_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
+  // CHECK: call float @llvm.sqrt.f32
+  "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
+  // CHECK: call <8 x float> @llvm.sqrt.v8f32
+  "llvm.intr.sqrt"(%arg1) : (!llvm<"<8 x float>">) -> !llvm<"<8 x float>">
+  llvm.return
+}
+
 // CHECK-LABEL: @ceil_test
 llvm.func @ceil_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
   // CHECK: call float @llvm.ceil.f32
@@ -100,6 +109,8 @@ llvm.func @copysign_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<
 // CHECK: declare <8 x float> @llvm.log2.v8f32(<8 x float>) #0
 // CHECK: declare float @llvm.fabs.f32(float)
 // CHECK: declare <8 x float> @llvm.fabs.v8f32(<8 x float>) #0
+// CHECK: declare float @llvm.sqrt.f32(float)
+// CHECK: declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0
 // CHECK: declare float @llvm.ceil.f32(float)
 // CHECK: declare <8 x float> @llvm.ceil.v8f32(<8 x float>) #0
 // CHECK: declare float @llvm.cos.f32(float)