[mlir][linalg] Add a few unary operations.
authorBixia Zheng <bixia@google.com>
Thu, 10 Mar 2022 17:08:41 +0000 (09:08 -0800)
committerBixia Zheng <bixia@google.com>
Thu, 10 Mar 2022 17:38:58 +0000 (09:38 -0800)
Add operations abs, ceil, floor, and neg to the C++ API and Python API.

Add test cases.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D121339

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/python/dialects/linalg/opdsl/emit_misc.py

index f962eb6..a1a8477 100644 (file)
@@ -61,7 +61,11 @@ def Linalg_Dialect : Dialect {
 // Define the function attribute enums matching the OpDSL functions.
 def UnaryFn : I32EnumAttr<"UnaryFn", "", [
   I32EnumAttrCase<"exp", 0>,
-  I32EnumAttrCase<"log", 1>
+  I32EnumAttrCase<"log", 1>,
+  I32EnumAttrCase<"abs", 2>,
+  I32EnumAttrCase<"ceil", 3>,
+  I32EnumAttrCase<"floor", 4>,
+  I32EnumAttrCase<"negf", 5>
 ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
index 02ed755..8880c16 100644 (file)
@@ -144,6 +144,14 @@ public:
       return builder.create<math::ExpOp>(arg.getLoc(), arg);
     case UnaryFn::log:
       return builder.create<math::LogOp>(arg.getLoc(), arg);
+    case UnaryFn::abs:
+      return builder.create<math::AbsOp>(arg.getLoc(), arg);
+    case UnaryFn::ceil:
+      return builder.create<math::CeilOp>(arg.getLoc(), arg);
+    case UnaryFn::floor:
+      return builder.create<math::FloorOp>(arg.getLoc(), arg);
+    case UnaryFn::negf:
+      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
     }
     llvm_unreachable("unsupported unary function");
   }
index 47083de..135f55e 100644 (file)
@@ -274,6 +274,10 @@ class UnaryFn:
   """Unary function namespace."""
   exp = UnaryFnType("exp")
   log = UnaryFnType("log")
+  abs = UnaryFnType("abs")
+  ceil = UnaryFnType("ceil")
+  floor = UnaryFnType("floor")
+  negf = UnaryFnType("negf")
 
 
 class BinaryFnType:
index 93baef1..2e71e56 100644 (file)
@@ -390,6 +390,26 @@ class _BodyBuilder:
       return math.LogOp(x).result
     raise NotImplementedError("Unsupported 'log' operand: {x}")
 
+  def _unary_abs(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.AbsOp(x).result
+    raise NotImplementedError("Unsupported 'abs' operand: {x}")
+
+  def _unary_ceil(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.CeilOp(x).result
+    raise NotImplementedError("Unsupported 'ceil' operand: {x}")
+
+  def _unary_floor(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.FloorOp(x).result
+    raise NotImplementedError("Unsupported 'floor' operand: {x}")
+
+  def _unary_negf(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return arith.NegFOp(x).result
+    raise NotImplementedError("Unsupported 'negf' operand: {x}")
+
   def _binary_add(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.AddFOp(lhs, rhs).result
index ebb9a87..3ac2d75 100644 (file)
@@ -298,6 +298,54 @@ func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>)
 
 // -----
 
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_abs(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>}
+                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_abs
+// CHECK:        = math.abs
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_ceil(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<ceil>}
+                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_ceil
+// CHECK:        = math.ceil
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_floor(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<floor>}
+                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_floor
+// CHECK:        = math.floor
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_negf(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
+                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_negf
+// CHECK:        = arith.negf
+
+// -----
+
 // Verifies the default value of the fun attribute is an add op.
 func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
   %0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
index e69d71d..e57a49b 100644 (file)
@@ -11,7 +11,7 @@ from mlir.dialects.linalg.opdsl.lang import *
 # fill, matmul, convolution, or pooling tests. The features include:
 # - constant defined in the body
 # - fix/predefined types
-# - exponential functions
+# - some math/arith functions, including abs, ceil, exp, floor, log, and negf
 # - custom op names.
 
 
@@ -89,6 +89,46 @@ with Context() as ctx, Location.unknown():
     def test_f32_elemwise_log(input, init_result):
       return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
 
+    # CHECK-LABEL: @test_f32_elemwise_abs
+    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+    # CHECK-NEXT:   %[[EXP:.+]] = math.abs %[[IN]] : f32
+    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+    # CHECK-NEXT: -> tensor<4x16xf32>
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+    def test_f32_elemwise_abs(input, init_result):
+      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
+    # CHECK-LABEL: @test_f32_elemwise_ceil
+    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+    # CHECK-NEXT:   %[[EXP:.+]] = math.ceil %[[IN]] : f32
+    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+    # CHECK-NEXT: -> tensor<4x16xf32>
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+    def test_f32_elemwise_ceil(input, init_result):
+      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)
+
+    # CHECK-LABEL: @test_f32_elemwise_floor
+    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+    # CHECK-NEXT:   %[[EXP:.+]] = math.floor %[[IN]] : f32
+    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+    # CHECK-NEXT: -> tensor<4x16xf32>
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+    def test_f32_elemwise_floor(input, init_result):
+      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)
+
+    # CHECK-LABEL: @test_f32_elemwise_neg
+    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+    # CHECK-NEXT:   %[[EXP:.+]] = arith.negf %[[IN]] : f32
+    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+    # CHECK-NEXT: -> tensor<4x16xf32>
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+    def test_f32_elemwise_neg(input, init_result):
+      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
     # Just check that we don't assert out on name mismatch.
     # CHECK-LABEL: @test_non_default_op_name
     @builtin.FuncOp.from_py_func(