--- /dev/null
+# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
+
+import numpy as np
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import mlir_pytaco_api as pt
+
+compressed = pt.compressed
+
+i, j = pt.get_index_vars(2)
+A = pt.tensor([2, 3])
+S = pt.tensor(3) # S is a scalar tensor.
+B = pt.tensor([2, 3], compressed)
+A.insert([0, 1], 10)
+A.insert([1, 2], 40)
+
+# Use [0] to index the scalar tensor.
+B[i, j] = A[i, j] * S[0]
+
+indices, values = B.get_coordinates_and_values()
+passed = np.array_equal(indices, [[0, 1], [1, 2]])
+passed += np.array_equal(values, [30.0, 120.0])
+
+# CHECK: Number of passed: 2
+print("Number of passed:", passed)
return ir.RankedTensorType.get(shape, ir_type, attr)
-def _verify_and_normalize_indices(indices) -> Tuple[IndexVar, ...]:
- """Verifies and normalizes the indices for a tensor access.
-
- Args:
- indices: The index expression used to access a tensor, which could be any
- Python object from user inputs.
-
- Returns:
- A tuple of IndexVar.
-
- Raises:
- ValueError: If indices is not an IndexVar or a tuple of IndexVar.
- """
- if isinstance(indices, IndexVar):
- return (indices,)
- elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
- return indices
-
- raise ValueError(f"Expected IndexVars: {indices}")
-
-
@dataclasses.dataclass(frozen=True)
class _StructOpInfo:
"""Information for generating a structured op in the linalg dialect.
def is_dense(self) -> bool:
"""Returns true if the tensor doesn't have sparsity annotation."""
- return self._format is None
+ return self.order == 0 or self._format is None
def to_array(self) -> np.ndarray:
"""Returns the numpy array for the Tensor.
"""Returns the shape of the Tensor."""
return self._shape
+ def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]:
+ """Verifies and normalizes the indices to access the tensor.
+
+ Args:
+ indices: The index expression used to access a tensor, which could be any
+ Python object from user inputs.
+
+ Returns:
+ A tuple of IndexVar.
+
+ Raises:
+ ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or
+ a tuple of IndexVar for other tensors.
+ """
+ if self.order == 0:
+ if not isinstance(indices, int) or indices != 0:
+ raise ValueError(f"Expected 0 to index scalar tensors: {indices}")
+ return ()
+
+ if isinstance(indices, IndexVar):
+ return (indices,)
+ elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
+ return indices
+
+ raise ValueError(f"Expected IndexVars: {indices}")
+
def __getitem__(self, key) -> "Access":
"""Verifies and processes a tensor access.
Raises:
ValueError: If key is not an IndexVar or a tuple of IndexVar.
"""
- indices = _verify_and_normalize_indices(key)
+ indices = self._verify_and_normalize_indices(key)
return Access(self, indices)
def __setitem__(self, key, value) -> None:
or a tuple of IndexVar, or the length of the indices is not the same as
the rank of the tensor.
"""
- indices = _verify_and_normalize_indices(key)
+ indices = self._verify_and_normalize_indices(key)
if len(indices) != self.order:
raise ValueError("Mismatch between indices and tensor rank: "
f"len({indices}) != {self.order}.")
def mlir_tensor_type(self) -> ir.RankedTensorType:
"""Returns the MLIR type for the tensor."""
- mlir_attr = None if (
- self._format is None) else self._format.mlir_tensor_attr()
+ mlir_attr = (None if (self._format is None or self.order == 0) else
+ self._format.mlir_tensor_attr())
return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr)
def dense_dst_ctype_pointer(self) -> ctypes.pointer: