[mlir][Python][Linalg] Adding const, capture, and index support to the OpDSL.
author    Tobias Gysi <gysit@google.com>
          Thu, 29 Apr 2021 06:45:34 +0000 (06:45 +0000)
committer Tobias Gysi <gysit@google.com>
          Thu, 29 Apr 2021 07:24:47 +0000 (07:24 +0000)
The patch extends the OpDSL with support for:
- Constant values
- Captured scalar parameters
- Accessing the iteration indices using the index operation
- Predefined floating point and integer types

For now, the patch only supports emitting the new nodes; the C++/YAML path is not yet fully implemented. The fill_rng_2d operation defined in emit_structured_generic.py makes use of the new DSL constructs.
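
As an illustration, the new constructs compose as in the following sketch. The
op and its name are hypothetical and not part of this patch; the import path
follows the layout of the Python bindings touched below:

```python
from mlir.dialects.linalg.opdsl.lang import *

@linalg_structured_op
def shift_by_index(I=TensorDef(T, S.M, S.N),
                   O=TensorDef(T, S.M, S.N, output=True),
                   offset=CaptureDef(I32)):
  """Adds the row index and a captured offset to each element (illustrative)."""
  shift = cast(I32, index(D.m)) + offset + const(I32, 1)
  O[D.m, D.n] = I[D.m, D.n] + cast(T, shift)
```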

Differential Revision: https://reviews.llvm.org/D101364

mlir/docs/Tools/LinalgOpDsl.md
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py
mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py

index 140c2ea..4ef6fb1 100644 (file)
@@ -72,6 +72,34 @@ The docstring will be transferred to the op definition verbatim.
 Special identifying op interfaces can be declared for the op via
 `implements(interface1[, interface2...])`.
 
+## Parameters
+
+Structured operations can take two kinds of parameters, namely input/output
+tensors and captures. Assignment expressions index the tensor parameters to
+access their individual elements, while captures are scalars that can be
+accessed directly.
+
+The following example demonstrates the use of the two parameter types:
+
+```python
+@linalg_structured_op
+def copy_and_scale(I=TensorDef(T, S.M, S.K),
+                   O=TensorDef(T, S.M, S.K, output=True),
+                   val=CaptureDef(T)):
+  """Scale the input by the captured value and store the result"""
+  O[D.m, D.n] = I[D.m, D.n] * val
+```
+
+The operation scales the elements of the input tensor `I` by the value
+`val` and writes the result to the output tensor `O`. The capture `val` is
+bound to a `CaptureDef`, which specifies the type of the captured value. The
+tensors are bound to `TensorDef`s, as demonstrated by the matmul example. All
+parameters appear in the parameter list of the operation:
+
+```python
+copy_and_scale(in_tensor, outs=[out_tensor], captures=[captured_val])
+```
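+
+The capture is an ordinary scalar SSA value created at the call site. The
+following sketch assumes an `f32` element type, that `in_tensor` and
+`out_tensor` already exist, and that it runs inside a function body with the
+MLIR Python bindings (`mlir.ir`, `mlir.dialects.std`) imported:
+
+```python
+# Build the scalar capture and invoke the op (illustrative fragment).
+f32 = F32Type.get()
+scale = std.ConstantOp(f32, FloatAttr.get(f32, 2.0)).result
+copy_and_scale(in_tensor, outs=[out_tensor], captures=[scale])
+```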
+
 ## Assignments
 
 The bulk of language consists of assignment expressions of the form above.
@@ -99,22 +127,30 @@ Reduction functions can appear as the outer-most function on the RHS:
 
 There are also special forms:
 
-* `cast(TypeVar, operand)`
+* `cast(TypeVar, operand)` casts the `operand` to the target type `TypeVar`.
+* `const(TypeVar, value)` returns a constant value of type `TypeVar`.
+* `index(dim)` returns the iteration index in the given dimension `dim`.
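+
+The following hypothetical example (not part of this patch) combines these
+forms to fill an output tensor with a linearized iteration index; the stride
+`16` is arbitrary and purely illustrative:
+
+```python
+@linalg_structured_op
+def fill_linear_index(O=TensorDef(T, S.M, S.N, output=True)):
+  """Fills the output with a linearized iteration index (illustrative)."""
+  linear = cast(I64, index(D.m)) * const(I64, 16) + cast(I64, index(D.n))
+  O[D.m, D.n] = cast(T, linear)
+```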
 
 ## Types
 
 All types in assignment expressions are late bound based on actual input
-and output types of constructed ops. Assignment expressions with no `cast`
-calls will generally require uniform types throughout and will fail to
-verify if violated. The presence of a `cast` allows for a limited form of
-numeric type conversion between element types that can be derived from inputs
-and outputs (and in the future, attributes). `cast` calls with a `TypeVar`
-first argument are emitted as `symbolic_cast` primitives in the YAML definition.
-
-Casting will perform `int<->float` type conversions and will perform any
-necessary extension or truncation within type family. Note that presently,
-any integer type is assumed to be signed for the purpose of determing how to
-extend or truncate. Supporting unsigned integer types is left for future work.
+and output types of constructed ops. The exceptions are the predefined types
+`I32`, `I64`, `F32`, and `F64`. These hardwired types enable intermediate
+computations with a type that is independent of the input and output types.
+For example, parts of a floating point computation may require double precision
+arithmetic despite all inputs and outputs being single precision values.
+Assignment expressions with no `cast` calls will generally require uniform
+types throughout and will fail to verify if violated. The presence of a
+`cast` allows for a limited form of numeric type conversion between element
+types that can be derived from inputs and outputs (and in the future,
+attributes). `cast` calls with a `TypeVar` first argument are emitted as
+`symbolic_cast` primitives in the YAML definition.
+
+Casting will perform `int<->float` and `index->int` type conversions and will
+perform any necessary extension or truncation within the type family. Note that
+presently, any integer type is assumed to be signed for the purpose of
+determining how to extend or truncate. Supporting unsigned integer types is
+left for future work.
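+
+For instance, a hypothetical op (not part of this patch) can compute its
+intermediate product in double precision regardless of the input and output
+element types:
+
+```python
+@linalg_structured_op
+def scale_by_half(I=TensorDef(T, S.M, S.N),
+                  O=TensorDef(U, S.M, S.N, output=True)):
+  """Scales the input using double precision intermediates (illustrative)."""
+  O[D.m, D.n] = cast(U, cast(F64, I[D.m, D.n]) * const(F64, 0.5))
+```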
 
 Not all functions are applicable for all numeric types, and on mismatch, op
 verification will fail.
index 34a8d6d..6db3bcf 100644 (file)
@@ -232,7 +232,6 @@ class DimDef(AffineExprDef):
 
   """
   ALL_DIMS = dict()  # type: Dict[str, "DimDef"]
-  dimname: str
 
   def __new__(cls, dimname: str):
     existing = cls.ALL_DIMS.get(dimname)
@@ -276,7 +275,6 @@ class SymbolDef(AffineExprDef):
     True
   """
   ALL_SYMBOLS = dict()  # type: Dict[str, "SymbolDef"]
-  symname: str
 
   def __new__(cls, symname: str):
     existing = cls.ALL_SYMBOLS.get(symname)
index 85da332..9b93d33 100644 (file)
@@ -8,7 +8,7 @@ about it typically involves processing this form into config objects that
 represent actual op definitions (i.e. YAML).
 """
 
-from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
 
 from mlir import ir as _ir
 
@@ -27,24 +27,49 @@ class TensorExpression:
   def to_scalar_expression(self) -> ScalarExpression:
     raise NotImplementedError()
 
-  def visit_affine_exprs(self, callback):
-    """Visits all affine expressions reachable by the expression."""
-    pass
+  def visit_tensor_exprs(self, callback):
+    """Visits all tensor expression reachable by the expression."""
+    callback(self)
 
   def _get_all_dim_defs(self) -> Set[DimDef]:
     """Recursively gets all DimDef affine expressions that are referenced."""
     results = set()
 
-    def visitor(affine_expr):
-      if isinstance(affine_expr, DimDef):
-        results.add(affine_expr)
+    def visit_dim_def(dim_def):
+      if isinstance(dim_def, DimDef):
+        results.add(dim_def)
 
-    self.visit_affine_exprs(visitor)
+    def visit_affine_exprs(expr):
+      if isinstance(expr, TensorUse):
+        for ind in expr.indices:
+          ind.visit_affine_exprs(visit_dim_def)
+      if isinstance(expr, ReduceApply):
+        for ind in expr.reduce.reduce_dims:
+          ind.visit_affine_exprs(visit_dim_def)
+
+    self.visit_tensor_exprs(visit_affine_exprs)
     return results
 
   def collect_uses(self, uses: Set["TensorUse"]):
     """Collects all TensorUses reachable through this expression."""
-    pass
+    def visit_tensor_use(expr):
+      if isinstance(expr, TensorUse):
+        uses.add(expr)
+    self.visit_tensor_exprs(visit_tensor_use)
+
+  def collect_indices(self, indices: Set["index"]):
+    """Collects all index accesses reachable through this expression."""
+    def visit_index(expr):
+      if isinstance(expr, index):
+        indices.add(expr)
+    self.visit_tensor_exprs(visit_index)
+
+  def collect_captures(self, captures: Set["CaptureDef"]):
+    """Collects all CaptureDefs reachable through this expression."""
+    def visit_capture_def(expr):
+      if isinstance(expr, CaptureDef):
+        captures.add(expr)
+    self.visit_tensor_exprs(visit_capture_def)
 
   def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
     return PrimFn.add(self, rhs)
@@ -84,13 +109,6 @@ class TensorUse(TensorExpression):
     assert n is not None, "TensorDef not attached"
     return n
 
-  def visit_affine_exprs(self, callback):
-    for ind in self.indices:
-      ind.visit_affine_exprs(callback)
-
-  def collect_uses(self, uses: Set["TensorUse"]):
-    uses.add(self)
-
   def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
     return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)
 
@@ -178,6 +196,35 @@ class TensorDef:
     return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, "
             f"shape={self.shape})")
 
+class CaptureDef(TensorExpression):
+  """Defines an SSA value captured by the operation.
+
+  The captured SSA values are not indexed by the indexing_maps of the
+  structured op (as opposed to memrefs and tensors). A unique name
+  identifies the captures and an index determines their position the
+  operation's parameter list.
+  """
+
+  def __init__(self, type_var: TypeVar):
+    if not isinstance(type_var, TypeVar):
+      raise ValueError(f"CaptureDef requires a TypeVar. Got: {repr(type_var)}")
+    self.owner = None  # type: Optional["LinalgOpDef"]
+    self.type_var = type_var
+    self.capture_name = None  # type: Optional[str]
+    self.registered_index = -1  # type: int
+
+  def attach(self, index: int, capture_name: str, owner: "LinalgOpDef"):
+    if self.owner:
+      raise ValueError(f"CaptureDef already registered with op: {self}")
+    self.registered_index = index
+    self.capture_name = capture_name
+    self.owner = owner
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    return ScalarCapture(self.capture_name).expr()
+
+  def __repr__(self):
+    return (f"{self.capture_name}:CaptureDef({repr(self.type_var)})")
 
 class Comprehension:
   """Represents a single comprehension."""
@@ -279,17 +326,52 @@ class PrimApply(TensorExpression):
                          *[arg.to_scalar_expression() for arg in self.args
                           ]).expr()
 
-  def visit_affine_exprs(self, callback):
-    for arg in self.args:
-      arg.visit_affine_exprs(callback)
-
-  def collect_uses(self, uses: Set["TensorUse"]):
+  def visit_tensor_exprs(self, callback):
+    super().visit_tensor_exprs(callback)
     for arg in self.args:
-      arg.collect_uses(uses)
+      arg.visit_tensor_exprs(callback)
 
   def __repr__(self):
     return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"
 
+class const(TensorExpression):
+  """Returns the given constant floating point or integer value."""
+
+  def __init__(self, type_var: TypeVar, value: Any):
+    if not isinstance(type_var, TypeVar):
+      raise ValueError(f"const requires a TypeVar. Got: {repr(type_var)}")
+    if not (isinstance(value, float) or isinstance(value, int)):
+      raise ValueError(f"const requires int or float. Got: {type(value)}")
+    self.type_var = type_var
+    self.value = value
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    return ScalarConst(self.type_var, self.value).expr()
+
+  def __repr__(self):
+    return f"const({self.type_var}, {self.value})"
+
+class index(TensorExpression):
+  """Returns the iteration index for a given dimension name.
+
+  Resolves the given dimension name to obtain its position in the iteration
+  domain of the operation.
+  """
+
+  def __init__(self, dim: DimDef):
+    self.dim_def = dim
+    self.dim = -1
+
+  def resolve_dimension_name(self, affine_state: AffineBuildState):
+    self.dim = affine_state.get_dim(self.dim_def.dimname)
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    assert self.dim != -1, "Dimension name not resolved"
+    return ScalarIndex(self.dim).expr()
+
+  def __repr__(self):
+    return f"index({repr(self.dim)})"
+
 
 class cast(TensorExpression):
   """Casts the element type to a type (typically symbolic TypeVar)."""
@@ -302,11 +384,9 @@ class cast(TensorExpression):
     return ScalarSymbolicCast(self.to_type,
                               self.operand.to_scalar_expression()).expr()
 
-  def visit_affine_exprs(self, callback):
-    self.operand.visit_affine_exprs(callback)
-
-  def collect_uses(self, uses: Set["TensorUse"]):
-    self.operand.collect_uses(uses)
+  def visit_tensor_exprs(self, callback):
+    super().visit_tensor_exprs(callback)
+    self.operand.visit_tensor_exprs(callback)
 
   def __repr__(self):
     return f"cast({self.to_type}, {repr(self.operand)})"
@@ -331,15 +411,9 @@ class ReduceApply(TensorExpression):
                 ] + [arg.to_scalar_expression() for arg in self.args]
     return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr()
 
-  def visit_affine_exprs(self, callback):
-    for ind in self.reduce.reduce_dims:
-      ind.visit_affine_exprs(callback)
-    for arg in self.args:
-      arg.visit_affine_exprs(callback)
-
-  def collect_uses(self, uses: Set["TensorUse"]):
+  def visit_tensor_exprs(self, callback):
     for arg in self.args:
-      arg.collect_uses(uses)
+      arg.visit_tensor_exprs(callback)
 
   def __repr__(self):
     return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})"
@@ -385,6 +459,7 @@ class LinalgOpDef:
                doc: Optional[str] = None):
     self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc)
     self.registered_tensors = dict()  # type: Dict[str, TensorDef]
+    self.registered_captures = dict()  # type: Dict[str, CaptureDef]
     self.comprehensions = list()  # type: List[Comprehension]
     self._affine_state = AffineBuildState()
 
@@ -404,12 +479,13 @@ class LinalgOpDef:
     tensor.attach(len(self.registered_tensors), tensor_name, self)
     self.registered_tensors[tensor_name] = tensor
 
-  def tensor(self, name):
-    """Gets a registered tensor by name."""
-    try:
-      return self.registered_tensors[name]
-    except KeyError:
-      raise KeyError(f"Tensor {name} is not registered")
+  def add_capture(self, capture_name: str, capture: CaptureDef):
+    """Registers a capture."""
+    if capture_name in self.registered_captures:
+      raise ValueError(f"Capture {capture_name} is already registered "
+                       f"to {self.registered_captures['capture_name']}")
+    capture.attach(len(self.registered_captures), capture_name, self)
+    self.registered_captures[capture_name] = capture
 
   def __repr__(self):
     lines = [
@@ -417,6 +493,8 @@ class LinalgOpDef:
     ]
     for name, tensor in self.registered_tensors.items():
       lines.append(f"  {tensor}")
+    for name, capture in self.registered_captures.items():
+      lines.append(f"  {capture}")
     if self.comprehensions:
       lines[-1] += " {"
       for comprehension in self.comprehensions:
index fdc6cfd..a67d18c 100644 (file)
@@ -70,6 +70,22 @@ class TensorDefConfig(YAMLObject):
   def __repr__(self):
     return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})"
 
+class CaptureDefConfig(YAMLObject):
+  """Wrapper around a CaptureDef."""
+  yaml_tag = "LinalgCaptureDef"
+
+  def __init__(self, capture_def: CaptureDef):
+    self.capture_def = capture_def
+
+  def to_yaml_custom_dict(self):
+    return dict(
+        name=self.capture_def.capture_name,
+        type_var=self.capture_def.type_var.name,
+    )
+
+  def __repr__(self):
+    return f"Def({self.capture_def})"
+
 
 class LinalgIndexingMapsConfig(YAMLObject):
   """Abstracts the style of indexing maps that the op exports.
@@ -109,10 +125,14 @@ class LinalgStructuredOpConfig(YAMLObject):
     self.affine_state = AffineBuildState()
     self.writes = list()  # type: List[Tuple[TensorUse, TensorExpression]]
     self.tensor_args = dict()  # type: Dict[TensorDef, TensorDefConfig]
+    self.capture_args = dict()  # type: Dict[CaptureDef, CaptureDefConfig]
     self.uses = dict()  # type: Dict[TensorUse, TensorUseConfig]
 
-    # Compute the ordered set of writes.
+    # Compute the ordered set of writes and collect the tensor, capture, and
+    # index uses.
     collected_uses = set()
+    collected_captures = set()
+    collected_indices = set()
     for write_use, read_use in zip(comprehension.definitions,
                                    comprehension.values):
       self.writes.append((write_use, read_use))
@@ -120,10 +140,14 @@ class LinalgStructuredOpConfig(YAMLObject):
     for write_use, read_use in self.writes:
       collected_uses.add(write_use)
       read_use.collect_uses(collected_uses)
+      read_use.collect_captures(collected_captures)
+      read_use.collect_indices(collected_indices)
 
     # Need to add all definitions before uses, so process twice.
     for use in collected_uses:
       self.add_tensor_arg(use.tensor_def)
+    for capture in collected_captures:
+      self.add_capture_arg(capture)
     for use in collected_uses:
       self.add_use(use)
 
@@ -170,6 +194,14 @@ class LinalgStructuredOpConfig(YAMLObject):
           f"dims. Got: {all_reduction_dims}")
     self.reduction_dims = next(iter(all_reduction_dims))
 
+    # Check that the index dimensions exist and resolve them.
+    for index in collected_indices:
+      if index.dim_def.dimname not in self.affine_state.all_dims:
+        raise ValueError(
+          f"The dimension {index.dim_def.dimname} is not part of the iteration "
+          f"domain {self.affine_state.all_dims}")
+      index.resolve_dimension_name(self.affine_state)
+
     # Generate the scalar assignments (used to build a body).
     self.assignments = [
         ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
@@ -187,6 +219,11 @@ class LinalgStructuredOpConfig(YAMLObject):
                   key=lambda tuc: tuc.tensor_use.tensor_def.registered_index)
 
   @property
+  def ordered_capture_args(self) -> Sequence[CaptureDefConfig]:
+    return sorted(self.capture_args.values(),
+                  key=lambda cdc: cdc.capture_def.registered_index)
+
+  @property
   def ordered_dims(self) -> Sequence[Tuple[str, int]]:
     """Gets the ordered list of dim bindings (symbolic name, position).
 
@@ -245,6 +282,12 @@ class LinalgStructuredOpConfig(YAMLObject):
       use_config = TensorUseConfig(tensor_use, indexing_map)
       self.uses[tensor_use] = use_config
 
+  def add_capture_arg(self, capture_def: CaptureDef):
+    if capture_def in self.capture_args:
+      return
+    def_config = CaptureDefConfig(capture_def)
+    self.capture_args[capture_def] = def_config
+
   def _normalize_affine_map(self,
                             affine_map: _ir.AffineMap,
                             with_dims: bool = True) -> _ir.AffineMap:
@@ -258,6 +301,7 @@ class LinalgStructuredOpConfig(YAMLObject):
   def to_yaml_custom_dict(self):
     self_dict = dict(
         args=self.ordered_tensor_args,
+        captures=self.ordered_capture_args,
         # TODO: Refactor the hierarchy internally when supporting more
         # than static (preserving this serialized form).
         indexing_maps=LinalgIndexingMapsConfig(
@@ -272,6 +316,9 @@ class LinalgStructuredOpConfig(YAMLObject):
     lines.append("tensor_args=[")
     for def_config in self.ordered_tensor_args:
       lines.append(f"  {repr(def_config)}")
+    lines.append("], capture_args=[")
+    for def_config in self.ordered_capture_args:
+      lines.append(f"  {repr(def_config)}")
     lines.append("], indexing_maps=[")
     for m in self.indexing_maps:
       lines.append(f"  {repr(m)}")
index 002ae51..428eadf 100644 (file)
@@ -105,11 +105,15 @@ def linalg_structured_op(dsl_func=None,
   sig = inspect.signature(dsl_func)
   for param_name, param in sig.parameters.items():
     param_default = param.default
-    if not isinstance(param_default, TensorDef):
+    if isinstance(param_default, TensorDef):
+      tc_model.add_tensor(param_name, param_default)
+    elif isinstance(param_default, CaptureDef):
+      tc_model.add_capture(param_name, param_default)
+    else:
       raise ValueError(f"@tc_def_op function parameters must be defaulted as "
-                       f"TensorDef(...): Found {param_name}: {param_default}")
+                       f"TensorDef(...) or CaptureDef(...): Found {param_name}"
+                       f": {param_default}")
     dsl_func_args.append(param_default)
-    tc_model.add_tensor(param_name, param_default)
 
   # Invoke the DSL func to finish populating the model.
   with bind_op_def(tc_model):
index 682f191..4a03702 100644 (file)
@@ -2,7 +2,7 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Dict, Sequence
+from typing import Any, Dict, Sequence
 
 from mlir.ir import *
 from mlir.dialects import linalg
@@ -28,10 +28,20 @@ def isa(cls : Type, ty : Type):
 
 def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
                                  *ins: Value,
-                                 outs: Value):
+                                 outs: Sequence[Value],
+                                 captures: Sequence[Value]):
   all_arg_defs = op_config.ordered_tensor_args
   in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
   out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"]
+  capture_arg_defs = op_config.ordered_capture_args
+
+  # Verify outs and captures are sequences.
+  if not isinstance(outs, Sequence):
+    raise ValueError(f"Expected named argument outs to have type Sequence "
+                     f"but got {type(outs)}")
+  if not isinstance(captures, Sequence):
+    raise ValueError(f"Expected named argument captures to have type Sequence "
+                     f"but got {type(outs)}")
 
   # Arity validation.
   if len(ins) != len(in_arg_defs):
@@ -40,19 +50,35 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
   if outs and len(outs) != len(out_arg_defs):
     raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
                      f"{len(outs)} for {op_config}")
+  if captures and len(captures) != len(capture_arg_defs):
+    raise ValueError(f"Expected {len(capture_arg_defs)} captures but got "
+                     f"{len(captures)} for {op_config}")
 
   outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
                                            out_arg_defs, outs)
 
   result_types = [t for t in out_types if isa(RankedTensorType, t)]
 
-  # Extract type vars for input/output based types.
+  # Initialize the type dictionary with the predefined types.
   type_mapping = dict()  # type: Dict[str, Type]
+  type_mapping["F32"] = F32Type.get()
+  type_mapping["F64"] = F64Type.get()
+  type_mapping["I32"] = IntegerType.get_signless(32)
+  type_mapping["I64"] = IntegerType.get_signless(64)
+
+  # Extract type vars for input/output based types.
   for arg_def, arg_element_type in zip(
       in_arg_defs + out_arg_defs,
       _get_shaped_element_types_from_values(*ins, *outs)):
-    tv_name = arg_def.tensor_def.type_var.name
-    type_mapping[tv_name] = arg_element_type
+    _add_type_mapping(arg_def.tensor_def.type_var.name, arg_element_type,
+                      type_mapping)
+
+  # Extract type vars for captures and compute capture argument mapping.
+  capture_arg_mapping = dict()  # type: Dict[str, Value]
+  for arg_def, capture_value in zip(capture_arg_defs, captures):
+    _add_type_mapping(arg_def.capture_def.type_var.name, capture_value.type,
+                      type_mapping)
+    capture_arg_mapping[arg_def.capture_def.capture_name] = capture_value
 
   # Emit the generic op.
   # TODO: Support emission of pure memref form.
@@ -63,21 +89,22 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
        for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, Context.current)])
   iterator_types_attr = ArrayAttr.get(
       [StringAttr.get(s) for s in op_config.iterator_types])
-  sparse_attr = ArrayAttr.get(
-      [BoolAttr.get(False) for s in list(ins) + list(outs) if isa(RankedTensorType, s.type)])
-  if len(sparse_attr) == 0:
-    sparse_attr = None
+  # TODO: Add support for sparse operands once there is a stable interface.
+  sparse_attr = None
 
   return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
-          type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr)
+          type_mapping, capture_arg_mapping, indexing_maps_attr,
+          iterator_types_attr, sparse_attr)
 
 
 def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
                                *ins: Value,
-                               outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
-  type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr =   \
-     prepare_common_structured_op(op_config, *ins, outs = outs)
+                               outs: Sequence[Value] = (),
+                               captures: Sequence[Value] = ()):
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
+  capture_arg_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
+     prepare_common_structured_op(op_config, *ins, outs = outs,
+                                  captures=captures)
 
   generic_op = linalg.GenericOp(
       result_tensors=result_types,
@@ -95,7 +122,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
   block = generic_op.regions[0].blocks.append(*block_arg_types)
   block_arg_mapping = dict(zip(block_arg_names, block.arguments))
   with InsertionPoint(block):
-    body_builder = _BodyBuilder(type_mapping, block_arg_mapping)
+    body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
+                                capture_arg_mapping)
     for assignment in op_config.assignments:
       body_builder.assign(assignment)
     body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs))
@@ -110,10 +138,12 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
                              op_name: str,
                              op_class_name: str,
                              *ins: Value,
-                             outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
-  type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr =   \
-     prepare_common_structured_op(op_config, *ins, outs = outs)
+                             outs: Sequence[Value] = (),
+                             captures: Sequence[Value] = ()):
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
+  capture_arg_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
+     prepare_common_structured_op(op_config, *ins, outs = outs,
+                                  captures = captures)
 
   # If we get here, there must exist a builtin class `op_class_name`.
   ctx = Context.current
@@ -127,7 +157,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
   linalgDialect = ctx.get_dialect_descriptor("linalg")
   fill_builtin_region(linalgDialect, named_op.operation)
   # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps
-  # attribute that the non-yaml path does not. The non-yaml path hardcodes the 
+  # attribute that the non-yaml path does not. The non-yaml path hardcodes the
   # indexing_maps in C++ directly.
   named_op.operation.attributes["linalg.memoized_indexing_maps"] = indexing_maps_attr
   # iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
@@ -141,10 +171,13 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
 class _BodyBuilder:
   """Constructs a structured op body by evaluating assignments."""
 
-  def __init__(self, type_mapping: Dict[str, Type],
-               block_arg_mapping: Dict[str, Value]):
+  def __init__(self,
+               type_mapping: Dict[str, Type],
+               block_arg_mapping: Dict[str, Value],
+               capture_arg_mapping: Dict[str, Value]):
     self.type_mapping = type_mapping
     self.block_arg_mapping = block_arg_mapping
+    self.capture_arg_mapping = capture_arg_mapping
     self.yield_mapping = dict()  # type: Dict[str, Value]
 
   def assign(self, assignment: ScalarAssign):
@@ -161,6 +194,16 @@ class _BodyBuilder:
       except KeyError:
         raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for "
                          f"this structured op.")
+    elif expr.scalar_capture:
+      try:
+        return self.capture_arg_mapping[expr.scalar_capture.capture]
+      except KeyError:
+        raise ValueError(f"Capture {expr.scalar_capture.capture} is not bound for "
+                         f"this structured op.")
+    elif expr.scalar_const:
+      return self.constant(expr.scalar_const.type_var.name, expr.scalar_const.value)
+    elif expr.scalar_index:
+      return self.index(expr.scalar_index.dim)
     elif expr.scalar_apply:
       try:
         fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}")
@@ -177,6 +220,25 @@ class _BodyBuilder:
       return self.cast(expr.symbolic_cast.to_type.name, operand_value)
     raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
 
+  def constant(self, type_var_name: str, value: Any) -> Value:
+    try:
+      type = self.type_mapping[type_var_name]
+    except KeyError:
+      raise ValueError(f"Unbound type variable '{type_var_name}' ("
+                       f"expected one of {self.type_mappings.keys()}")
+    try:
+      if _is_floating_point_type(type):
+        return std.ConstantOp(type, FloatAttr.get(type, float(value))).result
+      elif _is_integer_type(type):
+        return std.ConstantOp(type, IntegerAttr.get(type, int(value))).result
+    except ValueError:
+      raise ValueError(f"Unable to cast value {value} to type {type}")
+    raise NotImplementedError(f"Unimplemented constant type {type}")
+
+  def index(self, dim: int) -> Value:
+    dim_attr = IntegerAttr.get(IntegerType.get_signless(64), dim)
+    return linalg.IndexOp(IndexType.get(), dim_attr).result
+
   def cast(self, type_var_name: str, operand: Value) -> Value:
     try:
       to_type = self.type_mapping[type_var_name]
@@ -189,15 +251,13 @@ class _BodyBuilder:
       return self._cast_to_integer(to_type, operand)
     elif _is_floating_point_type(to_type):
       return self._cast_to_floating_point(to_type, operand)
-
-    raise ValueError(f"Unable to cast body expression from {operand.type} to "
-                     f"{to_type}")
-
   def _cast_to_integer(self, to_type: Type, operand: Value) -> Value:
     to_width = IntegerType(to_type).width
     operand_type = operand.type
     if _is_floating_point_type(operand_type):
       return std.FPToSIOp(to_type, operand).result
+    if _is_index_type(operand_type):
+      return std.IndexCastOp(to_type, operand).result
     # Assume integer.
     from_width = IntegerType(operand_type).width
     if to_width > from_width:
@@ -234,14 +294,21 @@ class _BodyBuilder:
   def _eval_add(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return std.AddFOp(lhs.type, lhs, rhs).result
-    if _is_integer_type(lhs.type):
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return std.AddIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'add' operand: {lhs}")
 
+  def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
+    if _is_floating_point_type(lhs.type):
+      return std.SubFOp(lhs.type, lhs, rhs).result
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+      return std.SubIOp(lhs.type, lhs, rhs).result
+    raise NotImplementedError("Unsupported 'sub' operand: {lhs}")
+
   def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return std.MulFOp(lhs.type, lhs, rhs).result
-    if _is_integer_type(lhs.type):
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return std.MulIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
 
@@ -281,6 +348,12 @@ def _get_tensor_def_names(
     *tensor_def_configs: TensorDefConfig) -> Sequence[str]:
   return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs]
 
+def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]):
+  if name in type_mapping:
+    if type_mapping[name] != type:
+        raise ValueError(f"Cannot overwrite type mapping {name} = "
+                         f"{type_mapping[name]} by type {type}")
+  type_mapping[name] = type
 
 def _is_floating_point_type(t: Type) -> bool:
   # TODO: Create a FloatType in the Python API and implement the switch
@@ -288,10 +361,11 @@ def _is_floating_point_type(t: Type) -> bool:
   return (F64Type.isinstance(t) or F32Type.isinstance(t) or
           F16Type.isinstance(t) or BF16Type.isinstance(t))
 
-
 def _is_integer_type(t: Type) -> bool:
   return IntegerType.isinstance(t)
 
+def _is_index_type(t: Type) -> bool:
+  return IndexType.isinstance(t)
 
 def _get_floating_point_width(t: Type) -> int:
   # TODO: Create a FloatType in the Python API and implement the switch
index 9ebf7a9..bb1938d 100644 (file)
@@ -13,7 +13,7 @@ op body. The class hierarchy is laid out to map well to a form of YAML that
 can be easily consumed from the C++ side, not necessarily for ergonomics.
 """
 
-from typing import Optional, Sequence
+from typing import Any, Optional, Sequence
 
 from .yaml_helper import *
 from .types import *
@@ -22,6 +22,9 @@ __all__ = [
     "ScalarAssign",
     "ScalarApplyFn",
     "ScalarArg",
+    "ScalarCapture",
+    "ScalarConst",
+    "ScalarIndex",
     "ScalarExpression",
     "ScalarSymbolicCast",
 ]
@@ -53,6 +56,42 @@ class ScalarArg:
   def __repr__(self):
     return f"(ScalarArg({self.arg})"
 
+class ScalarCapture:
+  """A type of ScalarExpression that references a named capture."""
+
+  def __init__(self, capture: str):
+    self.capture = capture
+
+  def expr(self) -> "ScalarExpression":
+    return ScalarExpression(scalar_capture=self)
+
+  def __repr__(self):
+    return f"(ScalarCapture({self.capture})"
+
+class ScalarConst:
+  """A type of ScalarExpression representing a constant."""
+
+  def __init__(self, type_var: TypeVar, value: Any):
+    self.type_var = type_var
+    self.value = value
+
+  def expr(self) -> "ScalarExpression":
+    return ScalarExpression(scalar_const=self)
+
+  def __repr__(self):
+    return f"(ScalarConst({self.type_var}, {self.value})"
+
+class ScalarIndex:
+  """A type of ScalarExpression accessing an iteration index."""
+
+  def __init__(self, dim: int):
+    self.dim = dim
+
+  def expr(self) -> "ScalarExpression":
+    return ScalarExpression(scalar_index=self)
+
+  def __repr__(self):
+    return f"(ScalarIndex({self.dim})"
 
 class ScalarSymbolicCast:
   """A type of ScalarExpression that symbolically casts an operand to a TypeVar.
@@ -75,6 +114,9 @@ class ScalarExpression(YAMLObject):
   Can be one of:
     - ScalarApplyFn
     - ScalarArg
+    - ScalarCapture
+    - ScalarConst
+    - ScalarIndex
     - ScalarSymbolicCast
   """
   yaml_tag = "!ScalarExpression"
@@ -82,13 +124,20 @@ class ScalarExpression(YAMLObject):
   def __init__(self,
                scalar_apply: Optional[ScalarApplyFn] = None,
                scalar_arg: Optional[ScalarArg] = None,
+               scalar_capture: Optional[ScalarCapture] = None,
+               scalar_const: Optional[ScalarConst] = None,
+               scalar_index: Optional[ScalarIndex] = None,
                symbolic_cast: Optional[ScalarSymbolicCast] = None):
-    if (bool(scalar_apply) + bool(scalar_arg) + bool(symbolic_cast)) != 1:
+    if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_capture) +
+        bool(scalar_const) + bool(scalar_index) + bool(symbolic_cast)) != 1:
       raise ValueError(
-          "One of 'scalar_apply', 'scalar_block_arg', 'symbolic_cast' must be "
-          "specified")
+          "One of 'scalar_apply', 'scalar_arg', 'scalar_capture', 'scalar_const', "
+          "'scalar_index', 'symbolic_cast' must be specified")
     self.scalar_apply = scalar_apply
     self.scalar_arg = scalar_arg
+    self.scalar_capture = scalar_capture
+    self.scalar_const = scalar_const
+    self.scalar_index = scalar_index
     self.symbolic_cast = symbolic_cast
 
   def to_yaml_custom_dict(self):
@@ -99,6 +148,13 @@ class ScalarExpression(YAMLObject):
       ))
     elif self.scalar_arg:
       return dict(scalar_arg=self.scalar_arg.arg)
+    elif self.scalar_capture:
+      return dict(scalar_capture=self.scalar_capture.capture)
+    elif self.scalar_const:
+      return dict(scalar_const=dict(type_var=self.scalar_const.type_var.name,
+                                    attributes=[self.scalar_const.value]))
+    elif self.scalar_index:
+      return dict(scalar_index=self.scalar_index.dim)
     elif self.symbolic_cast:
       # Note that even though operands must be arity 1, we write it the
       # same way as for apply because it allows handling code to be more
index 35bbfe7..ddac872 100644 (file)
@@ -22,6 +22,12 @@ __all__ = [
     "TypeVar",
     "TV",
 
+    # Predefined types.
+    "I32",
+    "I64",
+    "F32",
+    "F64",
+
     # TypeVar aliases.
     "T",
     "U",
@@ -63,6 +69,12 @@ class TypeVar:
 # Expando access via TV.foo
 TV = TypeVar.create_expando()
 
+# Predefined types.
+I32 = TV.I32
+I64 = TV.I64
+F32 = TV.F32
+F64 = TV.F64
+
 # Some common type name aliases.
 T = TV.T
 U = TV.U
index 5445dae..91274dd 100644 (file)
@@ -23,6 +23,18 @@ def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
                 C=TensorDef(U, S.M, S.N, output=True)):
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
+@linalg_structured_op
+def fill_rng_2d(A=TensorDef(T, S.M, S.N, output=True),
+                min=CaptureDef(F64),
+                max=CaptureDef(F64),
+                seed=CaptureDef(I32)):
+  multiplier = const(I32, 1103515245)
+  increment = const(I32, 12345)
+  temp1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
+  temp2 = (cast(I32, index(D.n)) + temp1) * multiplier + increment
+  inv_randmax = const(F64, 2.3283064e-10)
+  scaling = (max - min) * inv_randmax
+  A[D.m, D.n] = cast(T, cast(F64, temp2) * scaling + min)
 
 with Context() as ctx, Location.unknown():
   module = Module.create()
@@ -142,5 +154,27 @@ with Context() as ctx, Location.unknown():
     def test_f64f64f32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
+    # CHECK-LABEL: @test_fill_rng_2d
+    # CHECK-SAME:  %{{.*}} tensor<4x16xi32>, %[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32
+    # CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
+    # CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
+    # CHECK-DAG:    %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
+    # CHECK-DAG:    %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
+    # CHECK-DAG:    %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
+    # CHECK-DAG:    %[[CST0:.+]] = constant 1103515245 : i32
+    # CHECK-DAG:    %[[CST1:.+]] = constant 12345 : i32
+    # CHECK-DAG:    %[[RND1:.+]] = muli %[[RND0]], %[[CST0]] : i32
+    # CHECK-DAG:    %[[RND2:.+]] = addi %[[RND1]], %[[CST1]] : i32
+    # CHECK:        %[[RND3:.+]] = sitofp %{{.*}} : i32 to f64
+    # CHECK-DAG:    %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
+    # CHECK-DAG:    %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64
+    # CHECK-DAG:    %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64
+    # CHECK-DAG:    %[[RND4:.+]] = mulf %[[RND3]], %[[FACT]] : f64
+    # CHECK-DAG:    %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64
+    # CHECK-DAG:    %{{.*}} = fptosi %[[RND5]] : f64 to i32
+    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32),
+                                 f64, f64, i32)
+    def test_fill_rng_2d(init_result, min, max, seed):
+      return fill_rng_2d(outs=[init_result], captures=[min, max, seed])
 
 print(module)