[mlir][sparse][taco] Add a few unary operations.
authorBixia Zheng <bixia@google.com>
Thu, 10 Mar 2022 17:40:54 +0000 (09:40 -0800)
committerBixia Zheng <bixia@google.com>
Fri, 11 Mar 2022 16:08:55 +0000 (08:08 -0800)
Add operations -, abs, ceil and floor to the index notation.

Add test cases.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D121388

mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_unary_ops.py [new file with mode: 0644]
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_unary_ops.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_unary_ops.py
new file mode 100644 (file)
index 0000000..b949542
--- /dev/null
@@ -0,0 +1,40 @@
+# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
+
+import numpy as np
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import mlir_pytaco_api as pt
+
+i, j = pt.get_index_vars(2)
+A = pt.tensor([2, 3])
+B = pt.tensor([2, 3])
+A.insert([0, 1], 10.3)
+A.insert([1, 1], 40.7)
+A.insert([0, 2], -11.3)
+A.insert([1, 2], -41.7)
+
+B[i, j] = abs(A[i, j])
+indices, values = B.get_coordinates_and_values()
+passed = np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [10.3, 11.3, 40.7, 41.7])
+
+B[i, j] = pt.ceil(A[i, j])
+indices, values = B.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [11, -11, 41, -41])
+
+B[i, j] = pt.floor(A[i, j])
+indices, values = B.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [10, -12, 40, -42])
+
+B[i, j] = -A[i, j]
+indices, values = B.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [-10.3, 11.3, -40.7, 41.7])
+
+# CHECK: Number of passed: 8
+print("Number of passed:", passed)
index 2c452f1..46c7ba8 100644 (file)
@@ -53,6 +53,7 @@ _INDEX_BIT_WIDTH = 0
 _ENTRY_NAME = "main"
 
 # Type aliases for type annotation.
+_UnaryOp = Callable[[Any], Any]
 _BinaryOp = Callable[[Any, Any], Any]
 _ExprVisitor = Callable[..., None]
 _ExprInfoDict = Dict["IndexExpr", "_ExprInfo"]
@@ -1223,6 +1224,14 @@ class IndexExpr(abc.ABC):
       raise ValueError(f"Expected IndexExpr: {rhs}")
     return _BinaryExpr(op, self, rhs)
 
+  def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
+    """Build a unary expression.
+
+    Args:
+      op: A _UnaryOp object representing the unary operation.
+    """
+    return _UnaryExpr(op, self)
+
   def __add__(self, rhs) -> "_BinaryExpr":
     """Defines the operator +.
 
@@ -1253,6 +1262,22 @@ class IndexExpr(abc.ABC):
     """
     return self._verify_operand_and_build_expr(rhs, operator.mul)
 
+  def __abs__(self) -> "_UnaryExpr":
+    """Defines the operator abs.
+
+    Returns:
+      A _UnaryExpr object representing the operation.
+    """
+    return self._build_unary_expr(operator.abs)
+
+  def __neg__(self) -> "_UnaryExpr":
+    """Defines the operator neg.
+
+    Returns:
+      A _UnaryExpr object representing the operation.
+    """
+    return self._build_unary_expr(operator.neg)
+
   def __sub__(self, rhs) -> "_BinaryExpr":
     """Defines the operator -.
 
@@ -1603,6 +1628,75 @@ def _gather_input_accesses_index_vars(
     input_accesses.append(expr)
 
 
+def _op_ceil(__a: Any) -> Any:
+  """A _UnaryOp object for operation ceil."""
+  pass
+
+
+def _op_floor(__a: Any) -> Any:
+  """A _UnaryOp object for operation floor."""
+  pass
+
+
+def _op_unary_to_callable(op: _UnaryOp) -> lang.UnaryFnType:
+  """Returns the linalg dialect function object for the given operation."""
+  op_to_callable = {
+      operator.abs: lang.UnaryFn.abs,
+      operator.neg: lang.UnaryFn.negf,
+      _op_ceil: lang.UnaryFn.ceil,
+      _op_floor: lang.UnaryFn.floor,
+  }
+  return op_to_callable[op]
+
+
+@dataclasses.dataclass(frozen=True)
+class _UnaryExpr(IndexExpr):
+  """The representation for a Unary operation.
+
+  Attributes:
+  op: A _UnaryOp representing the operation.
+  a: An IndexExpr representing the operand for the operation.
+  """
+  op: _BinaryOp
+  a: IndexExpr
+
+  def __post_init__(self) -> None:
+    """Verifies that the operand being added is an IndexExpr."""
+    assert isinstance(self.a, IndexExpr)
+
+  def _emit_expression(
+      self,
+      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+      expr_to_info: _ExprInfoDict,
+  ) -> lang.ScalarExpression:
+    """Emits the expression tree and returns the expression."""
+    # The current expression node is an internal node of the structured op.
+    if self not in expr_to_opnd:
+      a = self.a._emit_expression(expr_to_opnd, expr_to_info)
+      return _op_unary_to_callable(self.op)(a)
+
+    # The current expression is a leaf node of the structured op. That is, it is
+    # a temporary tensor generated by its child structured op.
+    op_info = expr_to_info[self].structop_info
+    assert op_info is not None
+    dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
+    return lang.TensorUse(expr_to_opnd[self], dims)
+
+  def _visit(self,
+             func: _ExprVisitor,
+             args,
+             *,
+             leaf_checker: _SubtreeLeafChecker = None) -> None:
+    """A post-order visitor."""
+    if leaf_checker is None or not leaf_checker(self, *args):
+      self.a._visit(func, args, leaf_checker=leaf_checker)
+    func(self, *args)
+
+  def dtype(self) -> DType:
+    """Returns the data type of the operation."""
+    return self.a.dtype()
+
+
 def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
   """Returns the linalg dialect function object for the given operation."""
   op_to_callable = {
@@ -1612,7 +1706,6 @@ def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
   }
   return op_to_callable[op]
 
-
 @dataclasses.dataclass(frozen=True)
 class _BinaryExpr(IndexExpr):
   """The representation for a binary operation.
@@ -1740,6 +1833,15 @@ def _validate_and_collect_expr_info(
       mode_formats = tuple(expr.tensor.format.format_pack.formats)
     assert len(src_dims) == len(mode_formats)
     dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)])
+  elif isinstance(expr, _UnaryExpr):
+    a_info = expr_to_info[expr.a]
+    index_to_dim_info = {
+        i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
+    }
+    # Here we rely on the fact that dictionaries keep the insertion order for
+    # keys and values.
+    src_indices = tuple(index_to_dim_info.keys())
+    dim_infos = tuple(index_to_dim_info.values())
   else:
     assert isinstance(expr, _BinaryExpr)
     a_info = expr_to_info[expr.a]
@@ -1826,6 +1928,10 @@ def _accumulate_reduce_indices(
     expr_info.acc_reduce_indices = (
         a_info.acc_reduce_indices | b_info.acc_reduce_indices
         | expr_info.reduce_indices)
+  elif isinstance(expr, _UnaryExpr):
+    a_info = expr_to_info[expr.a]
+    expr_info.acc_reduce_indices = (
+        a_info.acc_reduce_indices | expr_info.reduce_indices)
   else:
     assert isinstance(expr, Access)
     # Handle simple reduction expression in the format of A[i] = B[i, j].
@@ -1965,3 +2071,51 @@ def _emit_structured_op_input(
   opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
   op_def.add_operand(name, opnd)
   return opnd
+
+
+def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr":
+  """Build a unary operation ceil.
+
+    Args:
+      a: The operand, which could be any Python object from user inputs.
+      op: An _UnaryOp object representing the operation.
+
+    Returns:
+      A _UnaryExpr object representing the operation.
+
+    Raises:
+      ValueError: If a is not an IndexExpr.
+    """
+  if not isinstance(a, Access):
+    raise ValueError(f"Expected an Access Operand: {a}")
+  return a._build_unary_expr(op)
+
+
+def ceil(a: Access) -> "_UnaryExpr":
+  """Defines the operation ceil.
+
+    Args:
+      a: The operand, which could be any Python object from user inputs.
+
+    Returns:
+      A _UnaryExpr object representing the operation.
+
+    Raises:
+      ValueError: If a is not an IndexExpr.
+    """
+  return _check_and_build_unary(a, _op_ceil)
+
+
+def floor(a: Access) -> "_UnaryExpr":
+  """Defines the operation floor.
+
+    Args:
+      a: The operand, which could be any Python object from user inputs.
+
+    Returns:
+      A _UnaryExpr object representing the operation.
+
+    Raises:
+      ValueError: If a is not an IndexExpr.
+    """
+  return _check_and_build_unary(a, _op_floor)
index 05704b9..d6072a4 100644 (file)
@@ -16,6 +16,8 @@ from . import mlir_pytaco
 from . import mlir_pytaco_io
 
 # Functions defined by PyTACO API.
+ceil = mlir_pytaco.ceil
+floor = mlir_pytaco.floor
 get_index_vars = mlir_pytaco.get_index_vars
 from_array = mlir_pytaco.Tensor.from_array
 read = mlir_pytaco_io.read