[mlir] Add std op for X raised to the power of Y
authorTres Popp <tpopp@google.com>
Thu, 10 Dec 2020 22:49:42 +0000 (23:49 +0100)
committerTres Popp <tpopp@google.com>
Tue, 15 Dec 2020 16:06:26 +0000 (17:06 +0100)
Proposal:
https://llvm.discourse.group/t/rfc-standard-add-powop-to-std-dialect/2377

Differential Revision: https://reviews.llvm.org/D93119

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/test/IR/core-ops.mlir

index c218618..481dfaf 100644 (file)
@@ -2353,6 +2353,39 @@ def OrOp : IntArithmeticOp<"or", [Commutative]> {
 }
 
 //===----------------------------------------------------------------------===//
+// PowFOp
+//===----------------------------------------------------------------------===//
+
+def PowFOp : FloatArithmeticOp<"powf"> {
+  let summary = "floating point raised to the power of operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `std.powf` ssa-use `,` ssa-use `:` type
+    ```
+
+    The `powf` operation takes two operands and returns one result, each of
+    these is required to be the same type. This type may be a floating point
+    scalar type, a vector whose element type is a floating point type, or a
+    floating point tensor.
+
+    Example:
+
+    ```mlir
+    // Scalar exponentiation.
+    %a = powf %b, %c : f64
+
+    // SIMD pointwise vector exponentiation
+    %f = powf %g, %h : vector<4xf32>
+
+    // Tensor pointwise exponentiation.
+    %x = powf %y, %z : tensor<4x?xbf16>
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // PrefetchOp
 //===----------------------------------------------------------------------===//
 
index 9af0c01..502e7fb 100644 (file)
@@ -86,6 +86,9 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
   // CHECK: %[[I6:.*]] = muli %[[I2]], %[[I2]] : i32
   %i6 = muli %i2, %i2 : i32
 
+  // CHECK: %[[F7:.*]] = powf %[[F2]], %[[F2]] : f32
+  %f7 = powf %f2, %f2 : f32
+
   // CHECK: %[[C0:.*]] = create_complex %[[F2]], %[[F2]] : complex<f32>
   %c0 = "std.create_complex"(%f2, %f2) : (f32, f32) -> complex<f32>