Split arithmetic function into unary and binary functions. The revision prepares the introduction of unary and binary function attributes that work similar to type function attributes.
Depends On D120108
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D120109
Reduction dimensions are inferred to be any dimensions on the RHS that are not
on the LHS.
-A number of arithmetic functions are supported:
-
-* `ArithFn.add(a, b)` (also via overloading the binary `+` operator)
-* `ArithFn.exp(a)`
-* `ArithFn.log(a)`
-* `ArithFn.mul(a, b)` (also via overloading the binary `*` operator)
-* `ArithFn.max(a, b)`
-* `ArithFn.min(a, b)`
-* `ArithFn.sub(a, b)` (also via overloading the binary `-` operator)
-* `ArithFn.max_unsigned(a, b)`
-* `ArithFn.min_unsigned(a, b)`
+A number of unary and binary arithmetic functions are supported:
+
+* `BinaryFn.add(a, b)` (also via overloading the binary `+` operator)
+* `BinaryFn.mul(a, b)` (also via overloading the binary `*` operator)
+* `BinaryFn.max(a, b)`
+* `BinaryFn.min(a, b)`
+* `BinaryFn.sub(a, b)` (also via overloading the binary `-` operator)
+* `BinaryFn.max_unsigned(a, b)`
+* `BinaryFn.min_unsigned(a, b)`
+* `UnaryFn.exp(a)`
+* `UnaryFn.log(a)`
As the integer types are signless, signedness is implement by different
functions that treat integers as signed or unsigned values.
arg: C
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: C
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: C
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: sub
operands:
- !ScalarExpression
scalar_arg: AZp
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: sub
operands:
- !ScalarExpression
arg: accum
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: accum
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: C
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: C
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: sub
operands:
- !ScalarExpression
scalar_arg: AZp
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: sub
operands:
- !ScalarExpression
arg: x
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: x
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: C
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: C
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: sub
operands:
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: sub
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: sub
operands:
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: sub
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: sub
operands:
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: sub
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: max
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: max_unsigned
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: max
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: min
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: min_unsigned
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: max
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: min
operands:
- !ScalarExpression
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_index: 1
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_const: '12345 : i64'
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: sub
operands:
- !ScalarExpression
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: unary
fn_name: log
operands:
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_const: '1.000000e+00 : f64'
- !ScalarExpression
scalar_fn:
- kind: arith
+ kind: unary
fn_name: exp
operands:
- !ScalarExpression
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code
-// and bind by name to math and type conversion functions in the DSL as:
-// `arithfn__{fnName}`
-// `typefn__{fnName}`
+// and bind by name to math functions in the DSL as:
+// `unary__{fnName}`
+// `binary__{fnName}`
// Examples:
-// `arithfn__add`
-// `arithfn__mul`
-// `typefn__cast`
+// `binary__add`
+// `binary__mul`
+// `unary__exp`
+// `unary__log`
// The naming convention is intentional in order to match snake-cased DSL names.
// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
//
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value arithfn__add(Value lhs, Value rhs) {
+ Value binary__add(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value arithfn__exp(Value x) {
+ Value unary__exp(Value x) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(x))
return builder.create<math::ExpOp>(x.getLoc(), x);
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value arithfn__log(Value x) {
+ Value unary__log(Value x) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(x))
return builder.create<math::LogOp>(x.getLoc(), x);
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value arithfn__sub(Value lhs, Value rhs) {
+ Value binary__sub(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value arithfn__mul(Value lhs, Value rhs) {
+ Value binary__mul(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value arithfn__max(Value lhs, Value rhs) {
+ Value binary__max(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value arithfn__max_unsigned(Value lhs, Value rhs) {
+ Value binary__max_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value arithfn__min(Value lhs, Value rhs) {
+ Value binary__min(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value arithfn__min_unsigned(Value lhs, Value rhs) {
+ Value binary__min_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
self.visit_tensor_exprs(visit_scalar_def)
def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
- return ArithFn.add(self, rhs)
+ return BinaryFn.add(self, rhs)
def __mul__(self, rhs) -> "TensorExpression":
- return ArithFn.mul(self, rhs)
+ return BinaryFn.mul(self, rhs)
def __sub__(self, rhs) -> "TensorExpression":
- return ArithFn.sub(self, rhs)
+ return BinaryFn.sub(self, rhs)
def __hash__(self):
return hash(id(self))
return rhs_dims - lhs_dims
def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
- return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
+ return ReduceFnUse(BinaryFn.add, *self._compute_reduce_dims(rhs))(rhs)
def __repr__(self):
return (f"{self.operand_def.name}"
f"bound to its lhs: {self}")
full_args = [self.lhs.to_scalar_expression()
] + [arg.to_scalar_expression() for arg in self.args]
- return ScalarFn(FunctionKind.ARITH, self.reduce_use.arith_fn.fn_name, None,
- None, full_args).expr()
+ return ScalarFn(FunctionKind.BINARY, self.reduce_use.binary_fn.fn_name,
+ None, None, full_args).expr()
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
for arg in self.args:
class FunctionKind(Enum):
- ARITH = 0
- TYPE = 1
+ UNARY = 0
+ BINARY = 1
+ TYPE = 2
-class TypeFnType:
- """Type conversion function.
+class UnaryFnType:
+ """Unary function.
- A type conversion function takes a target type and a tensor expression and
- returns the casted tensor expression.
+ A unary function takes one tensor expression and returns the
+ function evaluation result.
"""
def __init__(self, fn_name: str):
self.fn_name = fn_name
- def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
- return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
+ def __call__(self, exp: TensorExpression) -> "TensorFn":
+ return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [exp])
def __repr__(self):
return f"{self.fn_name}"
-class TypeFn:
- """Type conversion function namespace.
-
- As the integer types are signless, signedness is implement by different cast
- functions that treat integers as signed (`cast`) or unsigned
- (`cast_unsigned`) values.
-
- Examples:
- - cast(I32 -> I64) -> `arith.ExtSIOp`
- - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
- """
- cast = TypeFnType("cast")
- cast_unsigned = TypeFnType("cast_unsigned")
+class UnaryFn:
+ """Unary function namespace."""
+ exp = UnaryFnType("exp")
+ log = UnaryFnType("log")
-class ArithFnType:
- """Arithmetic function.
+class BinaryFnType:
+ """Binary function.
- An arithmetic function takes one ore more tensor expressions and returns the
+ A binary function takes two tensor expressions and returns the
function evaluation result.
"""
def __init__(self, fn_name: str):
self.fn_name = fn_name
- def __call__(self, *args) -> "TensorFn":
- return TensorFn(FunctionKind.ARITH, self.fn_name, None, None, args)
+ def __call__(self, arg0: TensorExpression,
+ arg1: TensorExpression) -> "TensorFn":
+ return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])
def __repr__(self):
return f"{self.fn_name}"
-class ArithFn:
- """Arithmetic function namespace.
+class BinaryFn:
+ """Binary function namespace.
As the integer types are signless, signedness is implement by different
functions that treat integers as signed or unsigned values.
- max -> `arith.MaxSIOp`
- max_unsinged -> `arith.MaxUIOp`
"""
- add = ArithFnType("add")
- exp = ArithFnType("exp")
- log = ArithFnType("log")
- mul = ArithFnType("mul")
- max = ArithFnType("max")
- min = ArithFnType("min")
- sub = ArithFnType("sub")
- max_unsigned = ArithFnType("max_unsigned")
- min_unsigned = ArithFnType("min_unsigned")
+ add = BinaryFnType("add")
+ mul = BinaryFnType("mul")
+ max = BinaryFnType("max")
+ min = BinaryFnType("min")
+ sub = BinaryFnType("sub")
+ max_unsigned = BinaryFnType("max_unsigned")
+ min_unsigned = BinaryFnType("min_unsigned")
+
+
+class TypeFnType:
+ """Type conversion function.
+
+ A type conversion function takes a target type and a tensor expression and
+ returns the casted tensor expression.
+ """
+
+ def __init__(self, fn_name: str):
+ self.fn_name = fn_name
+
+ def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
+ return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
+
+ def __repr__(self):
+ return f"{self.fn_name}"
+
+
+class TypeFn:
+ """Type conversion function namespace.
+
+ As the integer types are signless, signedness is implement by different cast
+ functions that treat integers as signed (`cast`) or unsigned
+ (`cast_unsigned`) values.
+
+ Examples:
+ - cast(I32 -> I64) -> `arith.ExtSIOp`
+ - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
+ """
+ cast = TypeFnType("cast")
+ cast_unsigned = TypeFnType("cast_unsigned")
class ReduceFnUse:
A reduction use specifies the reduction function and dimensions.
"""
- def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef):
- self.arith_fn = arith_fn
+ def __init__(self, binary_fn: BinaryFnType, *reduce_dims: DimDef):
+ self.binary_fn = binary_fn
self.reduce_dims = reduce_dims
- def __call__(self, *args: TensorExpression):
+ def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
return TensorReduceFn(self, args)
def __repr__(self):
- return (f"reduce_{self.arith_fn.fn_name}"
+ return (f"reduce_{self.binary_fn.fn_name}"
f"({', '.join(repr(d) for d in self.reduce_dims)})")
class ReduceFnType:
"""Reduction function.
- An arithmetic function that reduces its RHS into its LHS.
+ A binary function that reduces its RHS into its LHS.
"""
- def __init__(self, arith_fn: ArithFnType):
- if not isinstance(arith_fn, ArithFnType):
- raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}")
- self.arith_fn = arith_fn
+ def __init__(self, binary_fn: BinaryFnType):
+ if not isinstance(binary_fn, BinaryFnType):
+ raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
+ self.binary_fn = binary_fn
def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
- return ReduceFnUse(self.arith_fn, *reduce_dims)
+ return ReduceFnUse(self.binary_fn, *reduce_dims)
def __repr__(self):
- return (f"reduce_{self.arith_fn.fn_name}")
+ return (f"reduce_{self.binary_fn.fn_name}")
class ReduceFn:
- add = ReduceFnType(ArithFn.add)
- mul = ReduceFnType(ArithFn.mul)
- max = ReduceFnType(ArithFn.max)
- min = ReduceFnType(ArithFn.min)
- max_unsigned = ReduceFnType(ArithFn.max_unsigned)
- min_unsigned = ReduceFnType(ArithFn.min_unsigned)
+ add = ReduceFnType(BinaryFn.add)
+ mul = ReduceFnType(BinaryFn.mul)
+ max = ReduceFnType(BinaryFn.max)
+ min = ReduceFnType(BinaryFn.min)
+ max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
+ min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
###############################################################################
dim_attr = IntegerAttr.get(
IntegerType.get_signless(64), expr.scalar_index.dim)
return linalg.IndexOp(dim_attr).result
- elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.ARITH:
- fn = self._get_function(f"_arithfn_{expr.scalar_fn.fn_name}")
+ elif expr.scalar_fn and expr.scalar_fn.kind is not FunctionKind.TYPE:
+ kind = expr.scalar_fn.kind.name.lower()
+ fn = self._get_function(f"_{kind}_{expr.scalar_fn.fn_name}")
operand_values = [
self.expression(operand) for operand in expr.scalar_fn.operands
]
return fn(*operand_values)
- elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.TYPE:
+ elif expr.scalar_fn and expr.scalar_fn.kind is FunctionKind.TYPE:
+ kind = expr.scalar_fn.kind.name.lower()
fn_name = expr.scalar_fn.fn_name
if expr.scalar_fn.attr_name:
fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name]
- fn = self._get_function(f"_typefn_{fn_name}")
+ fn = self._get_function(f"_{kind}_{fn_name}")
operand_value = self.expression(expr.scalar_fn.operands[0])
return fn(expr.scalar_fn.type_var.name, operand_value)
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
raise ValueError(f"Unable to cast body expression from {operand_type} to "
f"{to_type}")
- def _typefn_cast(self, type_var_name: str, operand: Value) -> Value:
+ def _type_cast(self, type_var_name: str, operand: Value) -> Value:
return self._cast(type_var_name, operand, False)
- def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
+ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
return self._cast(type_var_name, operand, True)
- def _arithfn_add(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
- return arith.AddFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return arith.AddIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'add' operand: {lhs}")
-
- def _arithfn_exp(self, x: Value) -> Value:
+ def _unary_exp(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.ExpOp(x).result
raise NotImplementedError("Unsupported 'exp' operand: {x}")
- def _arithfn_log(self, x: Value) -> Value:
+ def _unary_log(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")
- def _arithfn_sub(self, lhs: Value, rhs: Value) -> Value:
+ def _binary_add(self, lhs: Value, rhs: Value) -> Value:
+ if _is_floating_point_type(lhs.type):
+ return arith.AddFOp(lhs, rhs).result
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ return arith.AddIOp(lhs, rhs).result
+ raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
+
+ def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.SubFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.SubIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'sub' operand: {lhs}")
+ raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
- def _arithfn_mul(self, lhs: Value, rhs: Value) -> Value:
+ def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MulFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MulIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
+ raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
- def _arithfn_max(self, lhs: Value, rhs: Value) -> Value:
+ def _binary_max(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MaxFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MaxSIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'max' operand: {lhs}")
+ raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
- def _arithfn_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
+ def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MaxFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MaxUIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}")
+ raise NotImplementedError(
+ "Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
- def _arithfn_min(self, lhs: Value, rhs: Value) -> Value:
+ def _binary_min(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MinFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MinSIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'min' operand: {lhs}")
+ raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
- def _arithfn_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
+ def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MinFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MinUIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}")
+ raise NotImplementedError(
+ "Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
def _infer_structured_outs(
"""
domain(D.m, D.n)
O[D.m, D.n] = \
- ArithFn.log(TypeFn.cast(U, const(1.0)) + ArithFn.exp(TypeFn.cast(U, I[D.m, D.n])))
+ UnaryFn.log(TypeFn.cast(U, const(1.0)) + UnaryFn.exp(TypeFn.cast(U, I[D.m, D.n])))
input_accesses.append(expr)
-def _op_to_callable(op: _BinaryOp) -> lang.ArithFnType:
+def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
"""Returns the linalg dialect function object for the given operation."""
op_to_callable = {
- operator.add: lang.ArithFn.add,
- operator.sub: lang.ArithFn.sub,
- operator.mul: lang.ArithFn.mul,
+ operator.add: lang.BinaryFn.add,
+ operator.sub: lang.BinaryFn.sub,
+ operator.mul: lang.BinaryFn.mul,
}
return op_to_callable[op]
arg: O
value: !ScalarExpression
scalar_fn:
- kind: arith
+ kind: binary
fn_name: add
operands:
- !ScalarExpression
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
-# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]);
+# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.binary__add([[VAL1]], [[VAL3]]);
# @linalg_structured_op
# IMPL-NEXT: MLIRContext *context = getContext();
# IMPL-NEXT: AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
# IMPL-NEXT: AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
+
+
+# @linalg_structured_op
+# def test4(O=TensorDef(T, S.M, S.N, output=True)):
+# """Title.
+
+# Detailed description.
+# """
+# O[D.m, D.n] = BinaryFn.add(UnaryFn.exp(O[D.m, D.n]), O[D.m, D.n])
+
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: test4
+ cpp_class_name: Test4Op
+ doc: |-
+ Title.
+
+ Detailed description.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T
+ shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+ iterator_types:
+ - parallel
+ - parallel
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: unary
+ fn_name: exp
+ operands:
+ - !ScalarExpression
+ scalar_arg: O
+ - !ScalarExpression
+ scalar_arg: O
+
+# IMPL-LABEL: void Test4Op::regionBuilder(ImplicitLocOpBuilder &b,
+# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
+
+# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.unary__exp(block.getArgument(0))
+# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.binary__add([[VAL0]], block.getArgument(0))
+# IMPL-NEXT: yields.push_back([[VAL1]])
# CHECK: -
# CHECK: arg: O
# CHECK: scalar_fn:
-# CHECK: kind: arith
+# CHECK: kind: binary
# CHECK: fn_name: sub
# CHECK: operands:
# CHECK: scalar_fn:
-# CHECK: kind: arith
+# CHECK: kind: binary
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: scalar_fn:
-# CHECK: kind: type
-# CHECK: type_var: T
+# CHECK: kind: unary
+# CHECK: fn_name: exp
# CHECK: operands:
-# CHECK: scalar_const: '3.1415926535897931 : f64'
+# CHECK: scalar_fn:
+# CHECK: kind: type
+# CHECK: type_var: T
+# CHECK: operands:
+# CHECK: scalar_const: '3.1415926535897931 : f64'
# CHECK: scalar_fn:
# CHECK: kind: type
# CHECK: fn_name: cast
pi = TypeFn.cast(T, const(3.1415926535897931))
cst42 = TypeFn.cast(T, const(42))
cst1000 = TypeFn.cast(T, const(1e+3))
- O[D.m, D.n] = pi + cst42 - cst1000
-
+ O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
# CHECK: ---
# CHECK-LABEL: indices
# CHECK: -
# CHECK: arg: O
# CHECK: scalar_fn:
-# CHECK: kind: arith
+# CHECK: kind: binary
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: scalar_index: 1
@linalg_structured_op
def soft_plus_poly(
I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
- O[D.m, D.n] = ArithFn.log(
- TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, ArithFn.exp(I[D.m, D.n])))
+ O[D.m, D.n] = UnaryFn.log(
+ TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, UnaryFn.exp(I[D.m, D.n])))
@linalg_structured_op(op_name="custom_op_name")
struct ScalarExpression;
-enum class ScalarFnKind { Arith, Type };
+enum class ScalarFnKind { Unary, Binary, Type };
struct ScalarFn {
ScalarFnKind kind;
template <>
struct ScalarEnumerationTraits<ScalarFnKind> {
static void enumeration(IO &io, ScalarFnKind &value) {
- io.enumCase(value, "arith", ScalarFnKind::Arith);
+ io.enumCase(value, "unary", ScalarFnKind::Unary);
+ io.enumCase(value, "binary", ScalarFnKind::Binary);
io.enumCase(value, "type", ScalarFnKind::Type);
}
};
return cppIdent;
}
if (expression.scalarFn &&
- expression.scalarFn->kind == ScalarFnKind::Arith) {
+ expression.scalarFn->kind != ScalarFnKind::Type) {
// Apply function.
// Recursively generate operands.
SmallVector<std::string> operandCppValues;
return None;
operandCppValues.push_back(*operandCppValue);
}
+
+ std::string prefix = expression.scalarFn->kind == ScalarFnKind::Unary
+ ? "unary"
+ : "binary";
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(
- llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent,
- expression.scalarFn->fnName,
+ llvm::formatv("Value {0} = helper.{1}__{2}({3});", cppIdent,
+ prefix, expression.scalarFn->fnName,
interleaveToString(operandCppValues, ", ")));
return cppIdent;
}