[mlir][linalg] Extend opdsl to support operations on complex types.
authorbixia1 <bixia@google.com>
Thu, 16 Jun 2022 21:27:26 +0000 (14:27 -0700)
committerbixia1 <bixia@google.com>
Fri, 17 Jun 2022 16:34:26 +0000 (09:34 -0700)
Linalg opdsl now supports negf/add/sub/mul on complex types.

Add a test.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D128010

mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/test/python/dialects/linalg/opdsl/emit_misc.py

index 2e71e561a7f543b16fced59bee7f316e02c074bf..cc99081b440d03fdfa7becc1d7fe74667d9ede91 100644 (file)
@@ -10,6 +10,7 @@ from .... import func
 from .... import linalg
 from .... import math
 from .... import arith
+from .... import complex
 from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
 
 from .scalar_expr import *
@@ -408,6 +409,8 @@ class _BodyBuilder:
   def _unary_negf(self, x: Value) -> Value:
     if _is_floating_point_type(x.type):
       return arith.NegFOp(x).result
+    if _is_complex_type(x.type):
+      return complex.NegOp(x).result
     raise NotImplementedError("Unsupported 'negf' operand: {x}")
 
   def _binary_add(self, lhs: Value, rhs: Value) -> Value:
@@ -415,6 +418,8 @@ class _BodyBuilder:
       return arith.AddFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.AddIOp(lhs, rhs).result
+    if _is_complex_type(lhs.type):
+      return complex.AddOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
 
   def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
@@ -422,6 +427,8 @@ class _BodyBuilder:
       return arith.SubFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.SubIOp(lhs, rhs).result
+    if _is_complex_type(lhs.type):
+      return complex.SubOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
 
   def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
@@ -429,6 +436,8 @@ class _BodyBuilder:
       return arith.MulFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.MulIOp(lhs, rhs).result
+    if _is_complex_type(lhs.type):
+      return complex.MulOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
 
   def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
@@ -512,6 +521,10 @@ def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type,
   block_arg_types.append(element_or_self_type)
 
 
+def _is_complex_type(t: Type) -> bool:
+  return ComplexType.isinstance(t)
+
+
 def _is_floating_point_type(t: Type) -> bool:
   # TODO: Create a FloatType in the Python API and implement the switch
   # there.
index 2d045125f28589288fb50d96da816752975c2610..ddb5cc8248024ceea863b9d5df32166933b1a4ce 100644 (file)
@@ -44,6 +44,7 @@ def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
 with Context() as ctx, Location.unknown():
   module = Module.create()
   f32 = F32Type.get()
+  c32 = ComplexType.get(f32)
   i32 = IntegerType.get_signless(32)
   with InsertionPoint(module.body):
 
@@ -129,6 +130,16 @@ with Context() as ctx, Location.unknown():
     def test_f32_elemwise_neg(input, init_result):
       return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
 
+    # CHECK-LABEL: @test_c32_elemwise_neg
+    # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+    # CHECK-NEXT:   %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
+    # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+    # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+    @func.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32))
+    def test_c32_elemwise_neg(input, init_result):
+      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
     # Just check that we don't assert out on name mismatch.
     # CHECK-LABEL: @test_non_default_op_name
     @func.FuncOp.from_py_func(