from .... import linalg
from .... import math
from .... import arith
+from .... import complex
from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
from .scalar_expr import *
def _unary_negf(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return arith.NegFOp(x).result
+ if _is_complex_type(x.type):
+ return complex.NegOp(x).result
raise NotImplementedError("Unsupported 'negf' operand: {x}")
def _binary_add(self, lhs: Value, rhs: Value) -> Value:
return arith.AddFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.AddIOp(lhs, rhs).result
+ if _is_complex_type(lhs.type):
+ return complex.AddOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
return arith.SubFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.SubIOp(lhs, rhs).result
+ if _is_complex_type(lhs.type):
+ return complex.SubOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
return arith.MulFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MulIOp(lhs, rhs).result
+ if _is_complex_type(lhs.type):
+ return complex.MulOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
block_arg_types.append(element_or_self_type)
+def _is_complex_type(t: Type) -> bool:
+ return ComplexType.isinstance(t)
+
+
def _is_floating_point_type(t: Type) -> bool:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
+ c32 = ComplexType.get(f32)
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
def test_f32_elemwise_neg(input, init_result):
return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+ # CHECK-LABEL: @test_c32_elemwise_neg
+ # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+ # CHECK-NEXT: %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[EXP]] : complex<f32>
+ # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32))
+ def test_c32_elemwise_neg(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
# Just check that we don't assert out on name mismatch.
# CHECK-LABEL: @test_non_default_op_name
@func.FuncOp.from_py_func(