--- /dev/null
+# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
+
+import numpy as np
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import mlir_pytaco_api as pt
+
+i, j = pt.get_index_vars(2)
+A = pt.tensor([2, 3])
+B = pt.tensor([2, 3])
+A.insert([0, 1], 10.3)
+A.insert([1, 1], 40.7)
+A.insert([0, 2], -11.3)
+A.insert([1, 2], -41.7)
+
+B[i, j] = abs(A[i, j])
+indices, values = B.get_coordinates_and_values()
+passed = np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [10.3, 11.3, 40.7, 41.7])
+
+B[i, j] = pt.ceil(A[i, j])
+indices, values = B.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [11, -11, 41, -41])
+
+B[i, j] = pt.floor(A[i, j])
+indices, values = B.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [10, -12, 40, -42])
+
+B[i, j] = -A[i, j]
+indices, values = B.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [-10.3, 11.3, -40.7, 41.7])
+
+# CHECK: Number of passed: 8
+print("Number of passed:", passed)
_ENTRY_NAME = "main"
# Type aliases for type annotation.
+_UnaryOp = Callable[[Any], Any]
_BinaryOp = Callable[[Any, Any], Any]
_ExprVisitor = Callable[..., None]
_ExprInfoDict = Dict["IndexExpr", "_ExprInfo"]
raise ValueError(f"Expected IndexExpr: {rhs}")
return _BinaryExpr(op, self, rhs)
+ def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
+ """Build a unary expression.
+
+ Args:
+ op: A _UnaryOp object representing the unary operation.
+ """
+ return _UnaryExpr(op, self)
+
def __add__(self, rhs) -> "_BinaryExpr":
"""Defines the operator +.
"""
return self._verify_operand_and_build_expr(rhs, operator.mul)
+ def __abs__(self) -> "_UnaryExpr":
+ """Defines the operator abs.
+
+ Returns:
+ A _UnaryExpr object representing the operation.
+ """
+ return self._build_unary_expr(operator.abs)
+
+ def __neg__(self) -> "_UnaryExpr":
+ """Defines the operator neg.
+
+ Returns:
+ A _UnaryExpr object representing the operation.
+ """
+ return self._build_unary_expr(operator.neg)
+
def __sub__(self, rhs) -> "_BinaryExpr":
"""Defines the operator -.
input_accesses.append(expr)
+def _op_ceil(__a: Any) -> Any:
+ """A _UnaryOp object for operation ceil."""
+ pass
+
+
+def _op_floor(__a: Any) -> Any:
+ """A _UnaryOp object for operation floor."""
+ pass
+
+
+def _op_unary_to_callable(op: _UnaryOp) -> lang.UnaryFnType:
+ """Returns the linalg dialect function object for the given operation."""
+ op_to_callable = {
+ operator.abs: lang.UnaryFn.abs,
+ operator.neg: lang.UnaryFn.negf,
+ _op_ceil: lang.UnaryFn.ceil,
+ _op_floor: lang.UnaryFn.floor,
+ }
+ return op_to_callable[op]
+
+
+@dataclasses.dataclass(frozen=True)
+class _UnaryExpr(IndexExpr):
+ """The representation for a Unary operation.
+
+ Attributes:
+ op: A _UnaryOp representing the operation.
+ a: An IndexExpr representing the operand for the operation.
+ """
+ op: _BinaryOp
+ a: IndexExpr
+
+ def __post_init__(self) -> None:
+ """Verifies that the operand being added is an IndexExpr."""
+ assert isinstance(self.a, IndexExpr)
+
+ def _emit_expression(
+ self,
+ expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+ expr_to_info: _ExprInfoDict,
+ ) -> lang.ScalarExpression:
+ """Emits the expression tree and returns the expression."""
+ # The current expression node is an internal node of the structured op.
+ if self not in expr_to_opnd:
+ a = self.a._emit_expression(expr_to_opnd, expr_to_info)
+ return _op_unary_to_callable(self.op)(a)
+
+ # The current expression is a leaf node of the structured op. That is, it is
+ # a temporary tensor generated by its child structured op.
+ op_info = expr_to_info[self].structop_info
+ assert op_info is not None
+ dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
+ return lang.TensorUse(expr_to_opnd[self], dims)
+
+ def _visit(self,
+ func: _ExprVisitor,
+ args,
+ *,
+ leaf_checker: _SubtreeLeafChecker = None) -> None:
+ """A post-order visitor."""
+ if leaf_checker is None or not leaf_checker(self, *args):
+ self.a._visit(func, args, leaf_checker=leaf_checker)
+ func(self, *args)
+
+ def dtype(self) -> DType:
+ """Returns the data type of the operation."""
+ return self.a.dtype()
+
+
def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
"""Returns the linalg dialect function object for the given operation."""
op_to_callable = {
}
return op_to_callable[op]
-
@dataclasses.dataclass(frozen=True)
class _BinaryExpr(IndexExpr):
"""The representation for a binary operation.
mode_formats = tuple(expr.tensor.format.format_pack.formats)
assert len(src_dims) == len(mode_formats)
dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)])
+ elif isinstance(expr, _UnaryExpr):
+ a_info = expr_to_info[expr.a]
+ index_to_dim_info = {
+ i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
+ }
+ # Here we rely on the fact that dictionaries keep the insertion order for
+ # keys and values.
+ src_indices = tuple(index_to_dim_info.keys())
+ dim_infos = tuple(index_to_dim_info.values())
else:
assert isinstance(expr, _BinaryExpr)
a_info = expr_to_info[expr.a]
expr_info.acc_reduce_indices = (
a_info.acc_reduce_indices | b_info.acc_reduce_indices
| expr_info.reduce_indices)
+ elif isinstance(expr, _UnaryExpr):
+ a_info = expr_to_info[expr.a]
+ expr_info.acc_reduce_indices = (
+ a_info.acc_reduce_indices | expr_info.reduce_indices)
else:
assert isinstance(expr, Access)
# Handle simple reduction expression in the format of A[i] = B[i, j].
opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
op_def.add_operand(name, opnd)
return opnd
+
+
+def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr":
+ """Build a unary operation ceil.
+
+ Args:
+ a: The operand, which could be any Python object from user inputs.
+ op: An _UnaryOp object representing the operation.
+
+ Returns:
+ A _UnaryExpr object representing the operation.
+
+ Raises:
+ ValueError: If a is not an IndexExpr.
+ """
+ if not isinstance(a, Access):
+ raise ValueError(f"Expected an Access Operand: {a}")
+ return a._build_unary_expr(op)
+
+
+def ceil(a: Access) -> "_UnaryExpr":
+ """Defines the operation ceil.
+
+ Args:
+ a: The operand, which could be any Python object from user inputs.
+
+ Returns:
+ A _UnaryExpr object representing the operation.
+
+ Raises:
+ ValueError: If a is not an IndexExpr.
+ """
+ return _check_and_build_unary(a, _op_ceil)
+
+
+def floor(a: Access) -> "_UnaryExpr":
+ """Defines the operation floor.
+
+ Args:
+ a: The operand, which could be any Python object from user inputs.
+
+ Returns:
+ A _UnaryExpr object representing the operation.
+
+ Raises:
+ ValueError: If a is not an IndexExpr.
+ """
+ return _check_and_build_unary(a, _op_floor)