[mir][Python][linalg] Support OpDSL extensions in C++.
authorTobias Gysi <gysit@google.com>
Wed, 19 May 2021 13:10:28 +0000 (13:10 +0000)
committerTobias Gysi <gysit@google.com>
Wed, 19 May 2021 13:36:56 +0000 (13:36 +0000)
The patch extends the yaml code generation to support the following new OpDSL constructs:
- captures
- constants
- iteration index accesses
- predefined types
These changes have been introduced by revision
https://reviews.llvm.org/D101364.

Differential Revision: https://reviews.llvm.org/D102075

16 files changed:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/CMakeLists.txt
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/lit.cfg.py
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml [new file with mode: 0644]
mlir/test/python/dialects/linalg/opdsl/arguments.py [new file with mode: 0644]
mlir/test/python/dialects/linalg/opdsl/assignments.py
mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
mlir/test/python/dialects/linalg/opsrun.py
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

index 085eaed..7e8d560 100644 (file)
@@ -1,7 +1,7 @@
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
-  cpp_op_name: MatmulOp
+  cpp_class_name: MatmulOp
   doc: |-
     Performs a matrix multiplication of two 2D inputs.
 
@@ -63,7 +63,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul
-  cpp_op_name: BatchMatmulOp
+  cpp_class_name: BatchMatmulOp
   doc: |-
     Performs a batched matrix multiplication of two 3D inputs.
 
@@ -126,7 +126,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matvec
-  cpp_op_name: MatvecOp
+  cpp_class_name: MatvecOp
   doc: |-
     Performs a matrix-vector multiplication.
 
@@ -187,7 +187,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: vecmat
-  cpp_op_name: VecmatOp
+  cpp_class_name: VecmatOp
   doc: |-
     Performs a vector-matrix multiplication.
 
@@ -248,7 +248,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: dot
-  cpp_op_name: DotOp
+  cpp_class_name: DotOp
   doc: |-
     Performs a dot product of two vectors to a scalar result.
 
@@ -305,4 +305,160 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: fill_rng_2d
+  cpp_class_name: FillRng2DOp
+  doc: |-
+    Fills the output tensor with pseudo random numbers.
+
+    The operation generations pseudo random numbers using a linear congruential
+    generator. It provides no guarantees regarding the distribution of the
+    generated random numbers. Instead of generating the random numbers
+    sequentially, it instantiates one random number generator per data element
+    and runs them in parallel. The seed operand and the indices of the data
+    element seed the random number generation. The min and max operands limit
+    the range of the generated random numbers.
 
+    Note: The captures are hard-coded till there is capture support on the C++
+    side.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: O
+    usage: output
+    shape: affine_map<()[s0, s1] -> (s0, s1)>
+    element_type_var: T
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      symbolic_cast:
+        type_var: T
+        operands:
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: add
+            operands:
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: mul
+                operands:
+                - !ScalarExpression
+                  scalar_apply:
+                    fn_name: add
+                    operands:
+                    - !ScalarExpression
+                      symbolic_cast:
+                        type_var: F64
+                        operands:
+                        - !ScalarExpression
+                          scalar_const: '2147483647 : i64'
+                    - !ScalarExpression
+                      symbolic_cast:
+                        type_var: F64
+                        operands:
+                        - !ScalarExpression
+                          scalar_apply:
+                            fn_name: add
+                            operands:
+                            - !ScalarExpression
+                              scalar_apply:
+                                fn_name: mul
+                                operands:
+                                - !ScalarExpression
+                                  scalar_apply:
+                                    fn_name: add
+                                    operands:
+                                    - !ScalarExpression
+                                      symbolic_cast:
+                                        type_var: I32
+                                        operands:
+                                        - !ScalarExpression
+                                          scalar_index: 1
+                                    - !ScalarExpression
+                                      scalar_apply:
+                                        fn_name: add
+                                        operands:
+                                        - !ScalarExpression
+                                          scalar_apply:
+                                            fn_name: mul
+                                            operands:
+                                            - !ScalarExpression
+                                              scalar_apply:
+                                                fn_name: add
+                                                operands:
+                                                - !ScalarExpression
+                                                  symbolic_cast:
+                                                    type_var: I32
+                                                    operands:
+                                                    - !ScalarExpression
+                                                      scalar_index: 0
+                                                - !ScalarExpression
+                                                  symbolic_cast:
+                                                    type_var: I32
+                                                    operands:
+                                                    - !ScalarExpression
+                                                      scalar_const: '42 : i64'
+                                            - !ScalarExpression
+                                              symbolic_cast:
+                                                type_var: I32
+                                                operands:
+                                                - !ScalarExpression
+                                                  scalar_const: '1103515245 : i64'
+                                        - !ScalarExpression
+                                          symbolic_cast:
+                                            type_var: I32
+                                            operands:
+                                            - !ScalarExpression
+                                              scalar_const: '12345 : i64'
+                                - !ScalarExpression
+                                  symbolic_cast:
+                                    type_var: I32
+                                    operands:
+                                    - !ScalarExpression
+                                      scalar_const: '1103515245 : i64'
+                            - !ScalarExpression
+                              symbolic_cast:
+                                type_var: I32
+                                operands:
+                                - !ScalarExpression
+                                  scalar_const: '12345 : i64'
+                - !ScalarExpression
+                  scalar_apply:
+                    fn_name: mul
+                    operands:
+                    - !ScalarExpression
+                      scalar_apply:
+                        fn_name: sub
+                        operands:
+                        - !ScalarExpression
+                          symbolic_cast:
+                            type_var: F64
+                            operands:
+                            - !ScalarExpression
+                              scalar_const: '1000 : i64'
+                        - !ScalarExpression
+                          symbolic_cast:
+                            type_var: F64
+                            operands:
+                            - !ScalarExpression
+                              scalar_const: '-1000 : i64'
+                    - !ScalarExpression
+                      symbolic_cast:
+                        type_var: F64
+                        operands:
+                        - !ScalarExpression
+                          scalar_const: '2.3283063999999999E-10 : f64'
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: F64
+                operands:
+                - !ScalarExpression
+                  scalar_const: '-1000 : i64'
index 1801f27..ee2136a 100644 (file)
@@ -220,14 +220,15 @@ namespace {
 
 class RegionBuilderHelper {
 public:
-  RegionBuilderHelper(Block &block) : block(block) {}
+  RegionBuilderHelper(MLIRContext *context, Block &block)
+      : context(context), block(block) {}
 
   // Generates operations to cast the given operand to a specified type.
   // If the cast cannot be performed, a warning will be issued and the
   // operand returned as-is (which will presumably yield a verification
   // issue downstream).
   Value cast(Type toType, Value operand) {
-    OpBuilder builder = getBuilder(operand);
+    OpBuilder builder = getBuilder();
     auto loc = operand.getLoc();
 
     if (operand.getType() == toType)
@@ -236,11 +237,14 @@ public:
       // If operand is floating point, cast directly to the int type.
       if (operand.getType().isa<FloatType>())
         return builder.create<FPToSIOp>(loc, toType, operand);
+      // Cast index operands directly to the int type.
+      if (operand.getType().isIndex())
+        return builder.create<IndexCastOp>(loc, toType, operand);
       if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
         // Either sign extend or truncate.
         if (toIntType.getWidth() > fromIntType.getWidth())
           return builder.create<SignExtendIOp>(loc, toType, operand);
-        else if (toIntType.getWidth() < fromIntType.getWidth())
+        if (toIntType.getWidth() < fromIntType.getWidth())
           return builder.create<TruncateIOp>(loc, toType, operand);
       }
     } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
@@ -251,7 +255,7 @@ public:
       if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
         if (toFloatType.getWidth() > fromFloatType.getWidth())
           return builder.create<FPExtOp>(loc, toFloatType, operand);
-        else if (toFloatType.getWidth() < fromFloatType.getWidth())
+        if (toFloatType.getWidth() < fromFloatType.getWidth())
           return builder.create<FPTruncOp>(loc, toFloatType, operand);
       }
     }
@@ -262,19 +266,28 @@ public:
   }
 
   Value applyfn__add(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder(lhs);
+    OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
       return builder.create<AddFOp>(lhs.getLoc(), lhs, rhs);
-    else if (isInteger(lhs))
+    if (isInteger(lhs))
       return builder.create<AddIOp>(lhs.getLoc(), lhs, rhs);
     llvm_unreachable("unsupported non numeric type");
   }
 
+  Value applyfn__sub(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
+    if (isFloatingPoint(lhs))
+      return builder.create<SubFOp>(lhs.getLoc(), lhs, rhs);
+    if (isInteger(lhs))
+      return builder.create<SubIOp>(lhs.getLoc(), lhs, rhs);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
   Value applyfn__mul(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder(lhs);
+    OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
       return builder.create<MulFOp>(lhs.getLoc(), lhs, rhs);
-    else if (isInteger(lhs))
+    if (isInteger(lhs))
       return builder.create<MulIOp>(lhs.getLoc(), lhs, rhs);
     llvm_unreachable("unsupported non numeric type");
   }
@@ -284,18 +297,39 @@ public:
     if (values.empty())
       return;
     Value first = values.front();
-    OpBuilder builder = getBuilder(first);
+    OpBuilder builder = getBuilder();
     builder.create<YieldOp>(first.getLoc(), values);
   }
 
+  Value constant(std::string value) {
+    OpBuilder builder = getBuilder();
+    Location loc = builder.getUnknownLoc();
+    Attribute valueAttr = parseAttribute(value, builder.getContext());
+    return builder.create<ConstantOp>(loc, valueAttr.getType(), valueAttr);
+  }
+
+  Value index(int64_t dim) {
+    OpBuilder builder = getBuilder();
+    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
+  }
+
+  Type getIntegerType(unsigned width) {
+    return IntegerType::get(context, width);
+  }
+
+  Type getFloat32Type() { return Float32Type::get(context); }
+
+  Type getFloat64Type() { return Float64Type::get(context); }
+
 private:
+  MLIRContext *context;
   Block &block;
 
   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
 
-  OpBuilder getBuilder(Value value) {
-    OpBuilder builder(value.getContext());
+  OpBuilder getBuilder() {
+    OpBuilder builder(context);
     builder.setInsertionPointToEnd(&block);
     return builder;
   }
@@ -1476,7 +1510,6 @@ computeReshapeCollapsedType(MemRefType type,
       MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
 }
 
-
 template <typename AffineExprTy>
 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
   unsigned pos = 0;
index 9b93d33..2ac0641 100644 (file)
@@ -8,7 +8,7 @@ about it typically involves processing this form into config objects that
 represent actual op definitions (i.e. YAML).
 """
 
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
 
 from mlir import ir as _ir
 
@@ -36,8 +36,8 @@ class TensorExpression:
     results = set()
 
     def visit_dim_def(dim_def):
-        if isinstance(dim_def, DimDef):
-          results.add(dim_def)
+      if isinstance(dim_def, DimDef):
+        results.add(dim_def)
 
     def visit_affine_exprs(expr):
       if isinstance(expr, TensorUse):
@@ -52,23 +52,29 @@ class TensorExpression:
 
   def collect_uses(self, uses: Set["TensorUse"]):
     """Collects all TensorUses reachable through this expression."""
+
     def visit_tensor_use(expr):
       if isinstance(expr, TensorUse):
         uses.add(expr)
+
     self.visit_tensor_exprs(visit_tensor_use)
 
   def collect_indices(self, indices: Set["index"]):
     """Collects all index accesses reachable through this expression."""
+
     def visit_index(expr):
       if isinstance(expr, index):
         indices.add(expr)
+
     self.visit_tensor_exprs(visit_index)
 
   def collect_captures(self, captures: Set["CaptureDef"]):
     """Collects all CaptureDefs reachable through this expression."""
+
     def visit_capture_def(expr):
       if isinstance(expr, CaptureDef):
         captures.add(expr)
+
     self.visit_tensor_exprs(visit_capture_def)
 
   def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
@@ -159,8 +165,8 @@ class TensorDef:
 
   def __getitem__(self, dims) -> TensorUse:
     assert self.owner, "TensorDef is not attached to an op"
-    state = AffineBuildState(global_state=self.owner._affine_state,
-                             allow_new_symbols=False)
+    state = AffineBuildState(
+        global_state=self.owner._affine_state, allow_new_symbols=False)
     if not isinstance(dims, tuple):
       dims = (dims,)  # Handle single subscript case.
     # Special case: (None) is a 0d-scalar use.
@@ -196,6 +202,7 @@ class TensorDef:
     return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, "
             f"shape={self.shape})")
 
+
 class CaptureDef(TensorExpression):
   """Defines an SSA value captured by the operation.
 
@@ -226,6 +233,7 @@ class CaptureDef(TensorExpression):
   def __repr__(self):
     return (f"{self.capture_name}:CaptureDef({repr(self.type_var)})")
 
+
 class Comprehension:
   """Represents a single comprehension."""
 
@@ -334,23 +342,27 @@ class PrimApply(TensorExpression):
   def __repr__(self):
     return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"
 
+
 class const(TensorExpression):
   """Returns the given constant floating point or integer value."""
 
-  def __init__(self, type_var: TypeVar, value: Any):
-    if not isinstance(type_var, TypeVar):
-      raise ValueError(f"const requires a TypeVar. Got: {repr(type_var)}")
-    if not (isinstance(value, float) or isinstance(value, int)):
-      raise ValueError(f"const requires int or float. Got: {type(value)}")
-    self.type_var = type_var
-    self.value = value
+  def __init__(self, value: Any):
+    with _ir.Context():
+      if isinstance(value, float):
+        self.value = str(_ir.FloatAttr.get_f64(float(value)))
+      elif isinstance(value, int):
+        self.value = str(
+            _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
+      else:
+        raise ValueError(f"const requires int or float. Got: {type(value)}")
 
   def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarConst(self.type_var, self.value).expr()
+    return ScalarConst(self.value).expr()
 
   def __repr__(self):
     return f"const({self.type_var}, {self.value})"
 
+
 class index(TensorExpression):
   """Returns the iteration index for a given dimension name.
 
@@ -358,7 +370,7 @@ class index(TensorExpression):
   domain of the operation.
   """
 
-  def __init__(self, dim : DimDef):
+  def __init__(self, dim: DimDef):
     self.dim_def = dim
     self.dim = -1
 
@@ -433,7 +445,8 @@ class OpMetadataDef(YAMLObject):
   """Metadata about the op (generally not behavior impacting)."""
   yaml_tag = "!LinalgOpMetadata"
 
-  def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]):
+  def __init__(self, name: str, cpp_class_name: Optional[str],
+               doc: Optional[str]):
     self.name = name
     self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
     self.doc = doc
@@ -457,7 +470,8 @@ class LinalgOpDef:
                name: str,
                cpp_class_name: Optional[str] = None,
                doc: Optional[str] = None):
-    self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc)
+    self.metadata = OpMetadataDef(
+        name=name, cpp_class_name=cpp_class_name, doc=doc)
     self.registered_tensors = dict()  # type: Dict[str, TensorDef]
     self.registered_captures = dict()  # type: Dict[str, CaptureDef]
     self.comprehensions = list()  # type: List[Comprehension]
index a67d18c..9026e20 100644 (file)
@@ -11,7 +11,7 @@ currently encode too many details of how the language is interpreted. Move this
 to helpers on the comprehension objects themselves.
 """
 
-from typing import Any, Dict, Optional
+from typing import Dict, Optional
 
 from mlir import ir as _ir
 
@@ -70,6 +70,7 @@ class TensorDefConfig(YAMLObject):
   def __repr__(self):
     return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})"
 
+
 class CaptureDefConfig(YAMLObject):
   """Wrapper around a CaptureDef."""
   yaml_tag = "LinalgCaptureDef"
@@ -113,8 +114,7 @@ class LinalgIndexingMapsConfig(YAMLObject):
 
 
 class LinalgStructuredOpConfig(YAMLObject):
-  """Configuration for metadata sufficient to construct a linalg single
-  contraction named op."""
+  """Configuration for metadata sufficient to construct a linalg named op."""
 
   yaml_tag = "!LinalgStructuredOpConfig"
 
@@ -156,8 +156,8 @@ class LinalgStructuredOpConfig(YAMLObject):
     for cuse in self.uses.values():
       cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
     for cdef in self.tensor_args.values():
-      cdef.shape_map = self._normalize_affine_map(cdef.shape_map,
-                                                  with_dims=False)
+      cdef.shape_map = self._normalize_affine_map(
+          cdef.shape_map, with_dims=False)
 
     # Now for each write use, propagate the indexing maps from the use to the
     # tensor, ensuring that there are not conflicts.
@@ -198,8 +198,8 @@ class LinalgStructuredOpConfig(YAMLObject):
     for index in collected_indices:
       if index.dim_def.dimname not in self.affine_state.all_dims:
         raise ValueError(
-          f"The dimension {index.dim.dimname} is not part of the iteration "
-          f"domain {self.affine_state.all_dims}")
+            f"The dimension {index.dim.dimname} is not part of the iteration "
+            f"domain {self.affine_state.all_dims}")
       index.resolve_dimension_name(self.affine_state)
 
     # Generate the scalar assignments (used to build a body).
@@ -210,18 +210,21 @@ class LinalgStructuredOpConfig(YAMLObject):
 
   @property
   def ordered_tensor_args(self) -> Sequence[TensorDefConfig]:
-    return sorted(self.tensor_args.values(),
-                  key=lambda tdc: tdc.tensor_def.registered_index)
+    return sorted(
+        self.tensor_args.values(),
+        key=lambda tdc: tdc.tensor_def.registered_index)
 
   @property
   def ordered_tensor_uses(self) -> Sequence[TensorUseConfig]:
-    return sorted(self.uses.values(),
-                  key=lambda tuc: tuc.tensor_use.tensor_def.registered_index)
+    return sorted(
+        self.uses.values(),
+        key=lambda tuc: tuc.tensor_use.tensor_def.registered_index)
 
   @property
   def ordered_capture_args(self) -> Sequence[CaptureDefConfig]:
-    return sorted(self.capture_args.values(),
-                  key=lambda cdc: cdc.capture_def.registered_index)
+    return sorted(
+        self.capture_args.values(),
+        key=lambda cdc: cdc.capture_def.registered_index)
 
   @property
   def ordered_dims(self) -> Sequence[Tuple[str, int]]:
@@ -252,15 +255,14 @@ class LinalgStructuredOpConfig(YAMLObject):
     if tensor_def in self.tensor_args:
       return
     with self.context:
-      local_state = AffineBuildState(global_state=self.affine_state,
-                                     allow_new_dims=False)
+      local_state = AffineBuildState(
+          global_state=self.affine_state, allow_new_dims=False)
       exprs = []
       for expr in tensor_def.shape:
         exprs.append(expr.build(state=local_state))
       assert local_state.local_dim_count == 0
-      indexing_map = _ir.AffineMap.get(dim_count=0,
-                                       symbol_count=local_state.symbol_count,
-                                       exprs=exprs)
+      indexing_map = _ir.AffineMap.get(
+          dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
 
       def_config = TensorDefConfig(tensor_def, indexing_map)
       self.tensor_args[tensor_def] = def_config
@@ -269,15 +271,16 @@ class LinalgStructuredOpConfig(YAMLObject):
     if tensor_use in self.uses:
       return
     with self.context:
-      local_state = AffineBuildState(global_state=self.affine_state,
-                                     allow_new_symbols=False)
+      local_state = AffineBuildState(
+          global_state=self.affine_state, allow_new_symbols=False)
       exprs = []
       for expr in tensor_use.indices:
         exprs.append(expr.build(state=local_state))
       assert local_state.local_symbol_count == 0
-      indexing_map = _ir.AffineMap.get(dim_count=local_state.dim_count,
-                                       symbol_count=local_state.symbol_count,
-                                       exprs=exprs)
+      indexing_map = _ir.AffineMap.get(
+          dim_count=local_state.dim_count,
+          symbol_count=local_state.symbol_count,
+          exprs=exprs)
 
       use_config = TensorUseConfig(tensor_use, indexing_map)
       self.uses[tensor_use] = use_config
@@ -299,16 +302,15 @@ class LinalgStructuredOpConfig(YAMLObject):
           exprs=list(affine_map.results))
 
   def to_yaml_custom_dict(self):
-    self_dict = dict(
-        args=self.ordered_tensor_args,
-        captures=self.ordered_capture_args,
-        # TODO: Refactor the hierarchy internally when supporting more
-        # than static (preserving this serialized form).
-        indexing_maps=LinalgIndexingMapsConfig(
-            static_indexing_maps=self.indexing_maps),
-        iterator_types=self.iterator_types,
-        assignments=self.assignments,
-    )
+    self_dict = dict(args=self.ordered_tensor_args)
+    if self.ordered_capture_args:
+      self_dict["captures"] = self.ordered_capture_args
+    # TODO: Refactor the hierarchy internally when supporting more
+    # than static (preserving this serialized form).
+    self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
+        static_indexing_maps=self.indexing_maps)
+    self_dict["iterator_types"] = self.iterator_types
+    self_dict["assignments"] = self.assignments
     return self_dict
 
   def __repr__(self):
@@ -359,9 +361,10 @@ class LinalgOpConfig(YAMLObject):
     assert len(
         tc_op_def.comprehensions) == 1, "Only one comprehension supported"
     return [
-        LinalgOpConfig(tc_op_def.metadata,
-                       structured_op=LinalgStructuredOpConfig(
-                           tc_op_def.comprehensions[0], context)),
+        LinalgOpConfig(
+            tc_op_def.metadata,
+            structured_op=LinalgStructuredOpConfig(tc_op_def.comprehensions[0],
+                                                   context)),
     ]
 
   def __repr__(self):
index 85c77d5..5538a9e 100644 (file)
@@ -2,7 +2,7 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Any, Dict, Sequence
+from typing import Dict, Sequence
 
 from mlir.ir import *
 from mlir.dialects import linalg
@@ -19,16 +19,17 @@ __all__ = [
     "emit_named_structured_op",
 ]
 
-def isa(cls : Type, ty : Type):
+
+def isa(cls: Type, ty: Type):
   try:
     cls(ty)
     return True
   except ValueError:
     return False
 
+
 def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
-                                 *ins: Value,
-                                 outs: Sequence[Value],
+                                 *ins: Value, outs: Sequence[Value],
                                  captures: Sequence[Value]):
   all_arg_defs = op_config.ordered_tensor_args
   in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
@@ -82,11 +83,13 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
 
   # Emit the generic op.
   # TODO: Support emission of pure memref form.
-  indexing_maps_attr = ArrayAttr.get(
-      [AffineMapAttr.get(am)
-       # TODO: linalg verification does not currently allow symbols.
-       # Compress them for now.
-       for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, Context.current)])
+  indexing_maps_attr = ArrayAttr.get([
+      AffineMapAttr.get(am)
+      # TODO: linalg verification does not currently allow symbols.
+      # Compress them for now.
+      for am in AffineMap.compress_unused_symbols(op_config.indexing_maps,
+                                                  Context.current)
+  ])
   iterator_types_attr = ArrayAttr.get(
       [StringAttr.get(s) for s in op_config.iterator_types])
 
@@ -144,7 +147,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
 
   # If we get here, there must exist a builtin class `op_class_name`.
   ctx = Context.current
-  fully_qualified_name = 'linalg.' + op_name
+  fully_qualified_name = "linalg." + op_name
   if (not ctx.is_registered_operation(fully_qualified_name) or
       not op_class_name in linalg.__dict__.keys()):
     raise NotImplementedError(
@@ -156,7 +159,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
   # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps
   # attribute that the non-yaml path does not. The non-yaml path hardcodes the
   # indexing_maps in C++ directly.
-  named_op.operation.attributes["linalg.memoized_indexing_maps"] = indexing_maps_attr
+  named_op.operation.attributes[
+      "linalg.memoized_indexing_maps"] = indexing_maps_attr
   # iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
 
   if len(result_types) == 1:
@@ -168,8 +172,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
 class _BodyBuilder:
   """Constructs a structured op body by evaluating assignments."""
 
-  def __init__(self,
-               type_mapping: Dict[str, Type],
+  def __init__(self, type_mapping: Dict[str, Type],
                block_arg_mapping: Dict[str, Value],
                capture_arg_mapping: Dict[str, Value]):
     self.type_mapping = type_mapping
@@ -195,12 +198,16 @@ class _BodyBuilder:
       try:
         return self.capture_arg_mapping[expr.scalar_capture.capture]
       except KeyError:
-        raise ValueError(f"Capture {expr.scalar_capture.capture} is not bound for "
-                         f"this structured op.")
+        raise ValueError(
+            f"Capture {expr.scalar_capture.capture} is not bound for "
+            f"this structured op.")
     elif expr.scalar_const:
-      return self.constant(expr.scalar_const.type_var.name, expr.scalar_const.value)
+      value_attr = Attribute.parse(expr.scalar_const.value)
+      return std.ConstantOp(value_attr.type, value_attr).result
     elif expr.scalar_index:
-      return self.index(expr.scalar_index.dim)
+      dim_attr = IntegerAttr.get(
+          IntegerType.get_signless(64), expr.scalar_index.dim)
+      return linalg.IndexOp(IndexType.get(), dim_attr).result
     elif expr.scalar_apply:
       try:
         fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}")
@@ -217,25 +224,6 @@ class _BodyBuilder:
       return self.cast(expr.symbolic_cast.to_type.name, operand_value)
     raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
 
-  def constant(self, type_var_name: str, value: Any) -> Value:
-    try:
-      type = self.type_mapping[type_var_name]
-    except KeyError:
-      raise ValueError(f"Unbound type variable '{type_var_name}' ("
-                       f"expected one of {self.type_mappings.keys()}")
-    try:
-      if(_is_floating_point_type(type)):
-        return std.ConstantOp(type, FloatAttr.get(type, float(value))).result
-      elif(_is_integer_type(type)):
-        return std.ConstantOp(type, IntegerAttr.get(type, int(value))).result
-    except ValueError:
-      raise ValueError(f"Unable to cast value {value} to type {type}")
-    raise NotImplementedError(f"Unimplemented constant type {type}")
-
-  def index(self, dim: int) -> Value:
-    dim_attr = IntegerAttr.get(IntegerType.get_signless(64), dim)
-    return linalg.IndexOp(IndexType.get(), dim_attr).result
-
   def cast(self, type_var_name: str, operand: Value) -> Value:
     try:
       to_type = self.type_mapping[type_var_name]
@@ -248,6 +236,7 @@ class _BodyBuilder:
       return self._cast_to_integer(to_type, operand)
     elif _is_floating_point_type(to_type):
       return self._cast_to_floating_point(to_type, operand)
+
   def _cast_to_integer(self, to_type: Type, operand: Value) -> Value:
     to_width = IntegerType(to_type).width
     operand_type = operand.type
@@ -345,6 +334,7 @@ def _get_tensor_def_names(
     *tensor_def_configs: TensorDefConfig) -> Sequence[str]:
   return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs]
 
+
 def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]):
   if name in type_mapping:
     if type_mapping[name] != type:
@@ -352,18 +342,22 @@ def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]):
                        f"{type_mapping[name]} by type {type}")
   type_mapping[name] = type
 
+
 def _is_floating_point_type(t: Type) -> bool:
   # TODO: Create a FloatType in the Python API and implement the switch
   # there.
   return (F64Type.isinstance(t) or F32Type.isinstance(t) or
           F16Type.isinstance(t) or BF16Type.isinstance(t))
 
+
 def _is_integer_type(t: Type) -> bool:
   return IntegerType.isinstance(t)
 
+
 def _is_index_type(t: Type) -> bool:
   return IndexType.isinstance(t)
 
+
 def _get_floating_point_width(t: Type) -> int:
   # TODO: Create a FloatType in the Python API and implement the switch
   # there.
index bb1938d..2cc426b 100644 (file)
@@ -13,7 +13,7 @@ op body. The class hierarchy is laid out to map well to a form of YAML that
 can be easily consumed from the C++ side, not necessarily for ergonomics.
 """
 
-from typing import Any, Optional, Sequence
+from typing import Optional, Sequence
 
 from .yaml_helper import *
 from .types import *
@@ -56,6 +56,7 @@ class ScalarArg:
   def __repr__(self):
     return f"(ScalarArg({self.arg})"
 
+
 class ScalarCapture:
   """A type of ScalarExpression that references a named capture."""
 
@@ -68,23 +69,24 @@ class ScalarCapture:
   def __repr__(self):
     return f"(ScalarCapture({self.capture})"
 
+
 class ScalarConst:
   """A type of ScalarExpression representing a constant."""
 
-  def __init__(self, type_var: TypeVar, value: Any):
-    self.type_var = type_var
+  def __init__(self, value: str):
     self.value = value
 
   def expr(self) -> "ScalarExpression":
     return ScalarExpression(scalar_const=self)
 
   def __repr__(self):
-    return f"(ScalarConst({self.type_var}, {self.value})"
+    return f"(ScalarConst({self.value})"
+
 
 class ScalarIndex:
   """A type of ScalarExpression accessing an iteration index."""
 
-  def __init__(self, dim : int):
+  def __init__(self, dim: int):
     self.dim = dim
 
   def expr(self) -> "ScalarExpression":
@@ -93,9 +95,9 @@ class ScalarIndex:
   def __repr__(self):
     return f"(ScalarIndex({self.dim})"
 
+
 class ScalarSymbolicCast:
-  """A type of ScalarExpression that symbolically casts an operand to a TypeVar.
-  """
+  """A type of ScalarExpression that symbolically casts an operand to a TypeVar."""
 
   def __init__(self, to_type: TypeVar, operand: "ScalarExpression"):
     self.to_type = to_type
@@ -142,25 +144,27 @@ class ScalarExpression(YAMLObject):
 
   def to_yaml_custom_dict(self):
     if self.scalar_apply:
-      return dict(scalar_apply=dict(
-          fn_name=self.scalar_apply.fn_name,
-          operands=list(self.scalar_apply.operands),
-      ))
+      return dict(
+          scalar_apply=dict(
+              fn_name=self.scalar_apply.fn_name,
+              operands=list(self.scalar_apply.operands),
+          ))
     elif self.scalar_arg:
       return dict(scalar_arg=self.scalar_arg.arg)
     elif self.scalar_capture:
       return dict(scalar_capture=self.scalar_capture.capture)
     elif self.scalar_const:
-      return dict(scalar_const=dict(type_var=self.scalar_const.type_var.name,
-                                    attributes=[self.scalar_const.value]))
+      return dict(scalar_const=self.scalar_const.value)
     elif self.scalar_index:
       return dict(scalar_index=self.scalar_index.dim)
     elif self.symbolic_cast:
       # Note that even though operands must be arity 1, we write it the
       # same way as for apply because it allows handling code to be more
       # generic vs having a special form.
-      return dict(symbolic_cast=dict(type_var=self.symbolic_cast.to_type.name,
-                                     operands=[self.symbolic_cast.operand]))
+      return dict(
+          symbolic_cast=dict(
+              type_var=self.symbolic_cast.to_type.name,
+              operands=[self.symbolic_cast.operand]))
     else:
       raise ValueError(f"Unexpected ScalarExpression type: {self}")
 
index b52a0e2..ad79963 100644 (file)
@@ -7,9 +7,10 @@ Batch = S.Batch
 
 
 @linalg_structured_op
-def matmul(A=TensorDef(T1, S.M, S.K),
-           B=TensorDef(T2, S.K, S.N),
-           C=TensorDef(U, S.M, S.N, output=True)):
+def matmul(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True)):
   """Performs a matrix multiplication of two 2D inputs.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -20,9 +21,10 @@ def matmul(A=TensorDef(T1, S.M, S.K),
 
 
 @linalg_structured_op
-def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
-                 B=TensorDef(T2, Batch, S.K, S.N),
-                 C=TensorDef(U, Batch, S.M, S.N, output=True)):
+def batch_matmul(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    C=TensorDef(U, Batch, S.M, S.N, output=True)):
   """Performs a batched matrix multiplication of two 3D inputs.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -33,9 +35,10 @@ def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
 
 
 @linalg_structured_op
-def matvec(A=TensorDef(T1, S.M, S.N),
-           y=TensorDef(T2, S.N),
-           x=TensorDef(U, S.M, output=True)):
+def matvec(
+    A=TensorDef(T1, S.M, S.N),
+    y=TensorDef(T2, S.N),
+    x=TensorDef(U, S.M, output=True)):
   """Performs a matrix-vector multiplication.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -46,9 +49,10 @@ def matvec(A=TensorDef(T1, S.M, S.N),
 
 
 @linalg_structured_op
-def vecmat(y=TensorDef(T1, S.M),
-           A=TensorDef(T2, S.M, S.N),
-           x=TensorDef(U, S.N, output=True)):
+def vecmat(
+    y=TensorDef(T1, S.M),
+    A=TensorDef(T2, S.M, S.N),
+    x=TensorDef(U, S.N, output=True)):
   """Performs a vector-matrix multiplication.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -59,8 +63,8 @@ def vecmat(y=TensorDef(T1, S.M),
 
 
 @linalg_structured_op
-def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U,
-                                                                output=True)):
+def dot(
+    A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
   """Performs a dot product of two vectors to a scalar result.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -68,3 +72,31 @@ def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U,
   """
   implements(ContractionOpInterface)
   C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
+
+
+@linalg_structured_op
+def fill_rng_2d(O=TensorDef(T, S.M, S.N, output=True)):
+  """Fills the output tensor with pseudo random numbers.
+
+  The operation generations pseudo random numbers using a linear congruential
+  generator. It provides no guarantees regarding the distribution of the
+  generated random numbers. Instead of generating the random numbers
+  sequentially, it instantiates one random number generator per data element
+  and runs them in parallel. The seed operand and the indices of the data
+  element seed the random number generation. The min and max operands limit
+  the range of the generated random numbers.
+
+  Note: The captures are hard-coded till there is capture support on the C++
+  side.
+  """
+  min = cast(F64, const(-1000))
+  max = cast(F64, const(+1000))
+  seed = cast(I32, const(42))
+  multiplier = cast(I32, const(1103515245))
+  increment = cast(I32, const(12345))
+  rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
+  rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment
+  inv_range = cast(F64, const(2.3283064e-10))
+  offset = cast(F64, const(2147483647))
+  scaling = (max - min) * inv_range
+  O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
index 614a990..79b8d89 100644 (file)
@@ -63,6 +63,7 @@ set(MLIR_TEST_DEPENDS
   mlir-capi-sparse-tensor-test
   mlir-cpu-runner
   mlir-linalg-ods-gen
+  mlir-linalg-ods-yaml-gen
   mlir-lsp-server
   mlir-opt
   mlir-reduce
index 251dfe6..4a431bd 100644 (file)
@@ -29,6 +29,54 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
 // CHECK-NEXT: -> tensor<16x32xi32>
 
 // -----
+
+func @generalize_fill_rng_2d_f32(%O: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.fill_rng_2d outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_fill_rng_2d_f32
+// CHECK-SAME: (%[[O:.+]]: tensor<16x32xf32>)
+// CHECK-DAG:    %[[MIN:.+]] = constant -1000 : i64
+// CHECK-DAG:    %[[MAX:.+]] = constant 1000 : i64
+// CHECK-DAG:    %[[SEED:.+]] = constant 42 : i32
+// CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK-DAG:    %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
+// CHECK-DAG:    %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
+// CHECK-DAG:    %[[VAL0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
+// CHECK-DAG:    %[[CST0:.+]] = constant 1103515245 : i32
+// CHECK-DAG:    %[[CST1:.+]] = constant 12345 : i32
+// CHECK-DAG:    %[[VAL1:.+]] = muli %[[VAL0]], %[[CST0]] : i32
+// CHECK-DAG:    %[[VAL2:.+]] = addi %[[VAL1]], %[[CST1]] : i32
+// Skip random number computation for the second index.
+// CHECK-DAG:    %[[MIN_CAST1:.+]] = sitofp %[[MIN]] : i64 to f64
+// CHECK-DAG:    %[[MAX_CAST:.+]] = sitofp %[[MAX]] : i64 to f64
+// CHECK-DAG:    %[[DIFF:.+]] = subf %[[MAX_CAST]], %[[MIN_CAST1]] : f64
+// CHECK-DAG:    %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64
+// CHECK-DAG:    %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64
+// CHECK-DAG:    %[[VAL4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
+// CHECK-DAG:    %[[MIN_CAST2:.+]] = sitofp %[[MIN]] : i64 to f64
+// CHECK-DAG:    %[[VAL5:.+]] = addf %[[VAL4]], %[[MIN_CAST2]] : f64
+// CHECK-DAG:    %[[VAL6:.+]] = fptrunc %[[VAL5]] : f64 to f32
+// CHECK-NEXT:   linalg.yield %[[VAL6]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+
+func @generalize_fill_rng_2d_i32(%O: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.fill_rng_2d outs(%O : tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// CHECK-LABEL: @generalize_fill_rng_2d_i32
+// CHECK-SAME: (%[[O:.+]]: tensor<16x32xi32>)
+// Verifies floating point to integer cast.
+// CHECK:        %[[VAL6:.+]] = fptosi %{{.+}} : f64 to i32
+// CHECK-NEXT:   linalg.yield %[[VAL6]] : i32
+// CHECK-NEXT: -> tensor<16x32xi32>
+
+// -----
 // Verifies floating point to integer cast.
 func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
   %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
index ad46220..44f2ff1 100644 (file)
@@ -21,7 +21,7 @@ config.name = 'MLIR'
 config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
 
 # suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.test']
+config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test']
 
 # test_source_root: The root path where tests are located.
 config.test_source_root = os.path.dirname(__file__)
@@ -64,6 +64,7 @@ tools = [
     'mlir-edsc-builder-api-test',
     'mlir-cpu-runner',
     'mlir-linalg-ods-gen',
+    'mlir-linalg-ods-yaml-gen',
     'mlir-reduce',
     'mlir-sdbm-api-test',
 ]
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
new file mode 100644 (file)
index 0000000..72b7f6f
--- /dev/null
@@ -0,0 +1,137 @@
+# RUN: mlir-linalg-ods-yaml-gen %s --o-ods-decl=- | FileCheck %s --check-prefix=ODS
+# RUN: mlir-linalg-ods-yaml-gen %s --o-impl=- | FileCheck %s --check-prefix=IMPL
+
+# @linalg_structured_op
+# def test1(O=TensorDef(T, S.M, S.N, output=True)):
+#   """Title.
+
+#   Detailed description.
+#   """
+#   O[D.m, D.n] = cast(T, const(42)) + cast(T, index(D.n))
+
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: test1
+  cpp_class_name: Test1Op
+  doc: |-
+    Title.
+
+    Detailed description.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: O
+    usage: output
+    shape: affine_map<()[s0, s1] -> (s0, s1)>
+    element_type_var: T
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          symbolic_cast:
+            type_var: T
+            operands:
+            - !ScalarExpression
+              scalar_const: '42 : i64'
+        - !ScalarExpression
+          symbolic_cast:
+            type_var: T
+            operands:
+            - !ScalarExpression
+              scalar_index: 1
+
+# ODS-LABEL:  def Test1Op : LinalgStructuredBase_Op<"test1"
+
+#       ODS:  let summary = [{ Title. }];
+#  ODS-NEXT:  let description = [{
+#  ODS-NEXT:    Detailed description.
+#  ODS-NEXT:  }];
+
+#       ODS:  let arguments =
+#  ODS-NEXT:    Variadic<AnyShaped>:$inputs,
+#  ODS-NEXT:    Variadic<AnyShaped>:$outputs
+
+#       ODS:  let builders =
+#       ODS:    $_state.addOperands(inputs);
+#  ODS-NEXT:    $_state.addOperands(outputs);
+#  ODS-NEXT:    $_state.addAttribute(
+#  ODS-NEXT:      "operand_segment_sizes",
+#  ODS-NEXT:      $_builder.getI32VectorAttr({
+#  ODS-NEXT:        static_cast<int32_t>(inputs.size()),
+#  ODS-NEXT:        static_cast<int32_t>(outputs.size())}));
+#  ODS-NEXT:    createAndFillStructuredOpRegion<Test1Op>(
+#  ODS-NEXT:      $_builder,
+#  ODS-NEXT:      $_state,
+#  ODS-NEXT:      TypeRange(inputs),
+#  ODS-NEXT:      TypeRange(outputs)
+
+# IMPL-LABEL:  void Test1Op::regionBuilder
+#  IMPL-SAME:  (Block &block, ValueRange captures)
+#       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
+#   IMPL-DAG:  Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]]);
+#   IMPL-DAG:  Value [[VAL2:[a-z0-9]+]] = helper.index(1);
+#   IMPL-DAG:  Value [[VAL3:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL2]]);
+#   IMPL-DAG:  Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]);
+
+
+# @linalg_structured_op
+# def test2(I=TensorDef(T, S.M, S.N),
+#           O=TensorDef(T, S.M, S.N, output=True)):
+#   """Title.
+
+#   Detailed description.
+#   """
+#   O[D.m, D.n] = I[D.n, D.m]
+
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: test2
+  cpp_class_name: Test2Op
+  doc: |-
+    Title.
+
+    Detailed description.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: I
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s0, s1)>
+    element_type_var: T
+  - !<LinalgTensorDef>
+    name: O
+    usage: output
+    shape: affine_map<()[s0, s1] -> (s0, s1)>
+    element_type_var: T
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_arg: I
+
+# IMPL-LABEL:  Test2Op::iterator_types()
+#  IMPL-NEXT:  { getParallelIteratorTypeName(), getParallelIteratorTypeName() }
+
+#       IMPL:  Test2Op::indexing_maps()
+#       IMPL:  "affine_map<(d0, d1)[s0, s1] -> (d1, d0)>"
+#       IMPL:  "affine_map<(d0, d1)[s0, s1] -> (d0, d1)>"
+
+#       IMPL:  void Test2Op::regionBuilder(Block &block, ValueRange captures)
+#       IMPL:  yields.push_back(block.getArgument(0));
diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
new file mode 100644 (file)
index 0000000..ce11188
--- /dev/null
@@ -0,0 +1,37 @@
+# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib --file %s | FileCheck %s
+
+from mlir.dialects.linalg.opdsl.lang import *
+
+
+# CHECK: ---
+# CHECK-LABEL: matmul
+# CHECK: args:
+# CHECK:     name: A
+# CHECK:     usage: input
+# CHECK:     shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK:     element_type_var: T
+# CHECK:     name: B
+# CHECK:     usage: input
+# CHECK:     shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK:     element_type_var: T
+# CHECK:     name: C
+# CHECK:     usage: output
+# CHECK:     shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK:     element_type_var: U
+@linalg_structured_op
+def matmul(
+    A=TensorDef(T, S.M, S.K),
+    B=TensorDef(T, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True)):
+  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+
+
+# CHECK: ---
+# CHECK-LABEL: fill
+# CHECK: captures:
+# CHECK: - !<LinalgCaptureDef>
+# CHECK:   name: value
+# CHECK:   type_var: T
+@linalg_structured_op
+def fill(O=TensorDef(T, S.M, S.K, output=True), value=CaptureDef(T)):
+  O[D.m, D.n] = value
index e96bc0d..32c56d1 100644 (file)
@@ -2,6 +2,7 @@
 
 from mlir.dialects.linalg.opdsl.lang import *
 
+
 # CHECK: ---
 # CHECK-LABEL: matmul
 # CHECK: assignments:
@@ -23,7 +24,65 @@ from mlir.dialects.linalg.opdsl.lang import *
 # CHECK:                operands:
 # CHECK:                  scalar_arg: B
 @linalg_structured_op
-def matmul(A=TensorDef(T, S.M, S.K),
-           B=TensorDef(T, S.K, S.N),
-           C=TensorDef(U, S.M, S.N, output=True)):
+def matmul(
+    A=TensorDef(T, S.M, S.K),
+    B=TensorDef(T, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True)):
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+
+
+# CHECK: ---
+# CHECK-LABEL: constants
+# CHECK: assignments:
+# CHECK:  -
+# CHECK:    arg: O
+# CHECK:      scalar_apply:
+# CHECK:        fn_name: sub
+# CHECK:        operands:
+# CHECK:          scalar_apply:
+# CHECK:            fn_name: add
+# CHECK:            operands:
+# CHECK:              symbolic_cast:
+# CHECK:                type_var: T
+# CHECK:                operands:
+# CHECK:                  scalar_const: '3.1415926535897931 : f64'
+# CHECK:              symbolic_cast:
+# CHECK:                type_var: T
+# CHECK:                operands:
+# CHECK:                  scalar_const: '42 : i64'
+# CHECK:          symbolic_cast:
+# CHECK:            type_var: T
+# CHECK:            operands:
+# CHECK:              scalar_const: '1.{{[0]*}}e+03 : f64'
+@linalg_structured_op
+def constants(O=TensorDef(T, S.M, S.K, output=True)):
+  pi = cast(T, const(3.1415926535897931))
+  cst42 = cast(T, const(42))
+  cst1000 = cast(T, const(1e+3))
+  O[D.m, D.n] = pi + cst42 - cst1000
+
+
+# CHECK: ---
+# CHECK-LABEL: indices
+# CHECK: assignments:
+# CHECK:  -
+# CHECK:    arg: O
+# CHECK:      scalar_apply:
+# CHECK:        fn_name: add
+# CHECK:        operands:
+# CHECK:          scalar_index: 1
+# CHECK:          scalar_index: 0
+@linalg_structured_op
+def indices(O=TensorDef(T, S.M, S.K, output=True)):
+  O[D.m, D.n] = index(D.n) + index(D.m)
+
+
+# CHECK: ---
+# CHECK-LABEL: fill
+# CHECK: assignments:
+# CHECK:  -
+# CHECK:    arg: O
+# CHECK:      scalar_capture: value
+@linalg_structured_op
+def fill(O=TensorDef(T, S.M, S.K, output=True), value=CaptureDef(T)):
+  O[D.m, D.n] = value
index 91274dd..f84db9b 100644 (file)
@@ -1,7 +1,5 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from typing import Optional, Sequence
-
 from mlir.ir import *
 from mlir.dialects import builtin
 from mlir.dialects import linalg
@@ -11,30 +9,36 @@ from mlir.dialects.linalg.opdsl.lang import *
 
 
 @linalg_structured_op
-def matmul_mono(A=TensorDef(T, S.M, S.K),
-                B=TensorDef(T, S.K, S.N),
-                C=TensorDef(T, S.M, S.N, output=True)):
+def matmul_mono(
+    A=TensorDef(T, S.M, S.K),
+    B=TensorDef(T, S.K, S.N),
+    C=TensorDef(T, S.M, S.N, output=True)):
   C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
 
 
 @linalg_structured_op
-def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
-                B=TensorDef(TV.T2, S.K, S.N),
-                C=TensorDef(U, S.M, S.N, output=True)):
+def matmul_poly(
+    A=TensorDef(TV.T1, S.M, S.K),
+    B=TensorDef(TV.T2, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True)):
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
+
 @linalg_structured_op
-def fill_rng_2d(A=TensorDef(T, S.M, S.N, output=True),
-                min=CaptureDef(F64),
-                max=CaptureDef(F64),
-                seed=CaptureDef(I32)):
-  multiplier = const(I32, 1103515245)
-  increment = const(I32, 12345)
-  temp1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
-  temp2 = (cast(I32, index(D.n)) + temp1) * multiplier + increment
-  inv_randmax = const(F64, 2.3283064e-10)
-  scaling = (max - min) * inv_randmax
-  A[D.m, D.n] = cast(T, cast(F64, temp2) * scaling + min)
+def fill_rng(
+    O=TensorDef(T, S.M, S.N, output=True),
+    min=CaptureDef(F64),
+    max=CaptureDef(F64),
+    seed=CaptureDef(I32)):
+  multiplier = cast(I32, const(1103515245))
+  increment = cast(I32, const(12345))
+  rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
+  rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment
+  inv_range = cast(F64, const(2.3283064e-10))
+  offset = cast(F64, const(2147483647))
+  scaling = (max - min) * inv_range
+  O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
+
 
 with Context() as ctx, Location.unknown():
   module = Module.create()
@@ -64,8 +68,8 @@ with Context() as ctx, Location.unknown():
     # CHECK-SAME: ins(%[[A]], %[[B]]
     # CHECK-SAME: outs(%[[INITC]]
 
-    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
-                                 RankedTensorType.get((16, 8), f32))
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32))
     def test_matmul_mono(lhs, rhs):
       init_result = linalg.InitTensorOp([4, 8], f32)
       return matmul_mono(lhs, rhs, outs=[init_result.result])
@@ -78,9 +82,9 @@ with Context() as ctx, Location.unknown():
     # CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
     # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
     # CHECK-NEXT: -> tensor<4x8xi32>
-    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
-                                 RankedTensorType.get((16, 8), i8),
-                                 RankedTensorType.get((4, 8), i32))
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
+        RankedTensorType.get((4, 8), i32))
     def test_i8i8i32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
@@ -92,9 +96,9 @@ with Context() as ctx, Location.unknown():
     # CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
     # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
     # CHECK-NEXT: -> tensor<4x8xi32>
-    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
-                                 RankedTensorType.get((16, 8), i16),
-                                 RankedTensorType.get((4, 8), i32))
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i16),
+        RankedTensorType.get((4, 8), i32))
     def test_i8i16i32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
@@ -106,9 +110,9 @@ with Context() as ctx, Location.unknown():
     # CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
     # CHECK-NEXT:   linalg.yield %[[ADD]] : i16
     # CHECK-NEXT: -> tensor<4x8xi16>
-    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32),
-                                 RankedTensorType.get((16, 8), i32),
-                                 RankedTensorType.get((4, 8), i16))
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), i32), RankedTensorType.get((16, 8), i32),
+        RankedTensorType.get((4, 8), i16))
     def test_i32i32i16_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
@@ -120,9 +124,9 @@ with Context() as ctx, Location.unknown():
     # CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
     # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
     # CHECK-NEXT: -> tensor<4x8xf32>
-    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
-                                 RankedTensorType.get((16, 8), i8),
-                                 RankedTensorType.get((4, 8), f32))
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
+        RankedTensorType.get((4, 8), f32))
     def test_i8i8f32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
@@ -134,9 +138,9 @@ with Context() as ctx, Location.unknown():
     # CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
     # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
     # CHECK-NEXT: -> tensor<4x8xf32>
-    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f16),
-                                 RankedTensorType.get((16, 8), f16),
-                                 RankedTensorType.get((4, 8), f32))
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f16), RankedTensorType.get((16, 8), f16),
+        RankedTensorType.get((4, 8), f32))
     def test_f16f16f32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
@@ -148,33 +152,36 @@ with Context() as ctx, Location.unknown():
     # CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
     # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
     # CHECK-NEXT: -> tensor<4x8xf32>
-    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f64),
-                                 RankedTensorType.get((16, 8), f64),
-                                 RankedTensorType.get((4, 8), f32))
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f64), RankedTensorType.get((16, 8), f64),
+        RankedTensorType.get((4, 8), f32))
     def test_f64f64f32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
-    # CHECK-LABEL: @test_fill_rng_2d
+    # CHECK-LABEL: @test_fill_rng
     # CHECK-SAME:  %{{.*}} tensor<4x16xi32>, %[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32
     # CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
     # CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
     # CHECK-DAG:    %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
     # CHECK-DAG:    %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
     # CHECK-DAG:    %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
-    # CHECK-DAG:    %[[CST0:.+]] = constant 1103515245 : i32
-    # CHECK-DAG:    %[[CST1:.+]] = constant 12345 : i32
-    # CHECK-DAG:    %[[RND1:.+]] = muli %[[RND0]], %[[CST0]] : i32
-    # CHECK-DAG:    %[[RND2:.+]] = addi %[[RND1]], %[[CST1]] : i32
-    # CHECK:        %[[RND3:.+]] = sitofp %{{.*}} : i32 to f64
+    # CHECK-DAG:    %[[CST0:.+]] = constant 1103515245 : i64
+    # CHECK-DAG:    %[[CST0_CAST:.+]] = trunci %[[CST0]] : i64 to i32
+    # CHECK-DAG:    %[[CST1:.+]] = constant 12345 : i64
+    # CHECK-DAG:    %[[CST1_CAST:.+]] = trunci %[[CST1]] : i64 to i32
+    # CHECK-DAG:    %[[RND1:.+]] = muli %[[RND0]], %[[CST0_CAST]] : i32
+    # CHECK-DAG:    %[[RND2:.+]] = addi %[[RND1]], %[[CST1_CAST]] : i32
+    # Skip random number computation for the second index.
     # CHECK-DAG:    %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
-    # CHECK-DAG:    %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64
-    # CHECK-DAG:    %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64
-    # CHECK-DAG:    %[[RND4:.+]] = mulf %[[RND3]], %[[FACT]] : f64
+    # CHECK-DAG:    %[[CST3:.+]] = constant 2.3283063999999999E-10 : f64
+    # CHECK-DAG:    %[[FACT:.+]] = mulf %[[DIFF]], %[[CST3]] : f64
+    # CHECK-DAG:    %[[RND4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
     # CHECK-DAG:    %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64
     # CHECK-DAG:    %{{.*}} = fptosi %[[RND5]] : f64 to i32
-    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32),
-                                 f64, f64, i32)
-    def test_fill_rng_2d(init_result, min, max, seed):
-      return fill_rng_2d(outs=[init_result], captures=[min, max, seed])
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), i32), f64, f64, i32)
+    def test_fill_rng(init_result, min, max, seed):
+      return fill_rng(outs=[init_result], captures=[min, max, seed])
+
 
 print(module)
index c46863b..8d48f0a 100644 (file)
@@ -8,13 +8,15 @@ from mlir.dialects import std
 from mlir.passmanager import *
 from mlir.execution_engine import *
 
+
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
 def log(*args):
   print(*args, file=sys.stderr)
   sys.stderr.flush()
 
-boilerplate = """
+
+matmul_boiler = """
 func @main() -> f32 attributes {llvm.emit_c_interface} {
   %v0 = constant 0.0 : f32
   %v1 = constant 1.0 : f32
@@ -27,7 +29,7 @@ func @main() -> f32 attributes {llvm.emit_c_interface} {
   linalg.fill(%B, %v2) : memref<16x8xf32>, f32
   linalg.fill(%C, %v0) : memref<4x8xf32>, f32
 
-  call @matmul_on_buffers(%A, %B, %C) : 
+  call @matmul_on_buffers(%A, %B, %C) :
     (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()
 
   %c0 = constant 0 : index
@@ -38,7 +40,23 @@ func @main() -> f32 attributes {llvm.emit_c_interface} {
 }
 """
 
-def transform(module):
+fill_boiler = """
+func @main() -> i32 attributes {llvm.emit_c_interface} {
+  %O = memref.alloc() : memref<4x16xi32>
+
+  call @fill_on_buffers(%O) :
+    (memref<4x16xi32>) -> ()
+
+  %c0 = constant 0 : index
+  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>
+
+  // TODO: FFI-based solution to allow testing and printing with python code.
+  return %0 : i32
+}
+"""
+
+
+def transform(module, boilerplate):
   import mlir.conversions
   import mlir.dialects.linalg.passes
   import mlir.transforms
@@ -46,26 +64,27 @@ def transform(module):
   # TODO: Allow cloning functions from one module to another.
   # Atm we have to resort to string concatenation.
   mod = Module.parse(
-    str(module.operation.regions[0].blocks[0].operations[0].operation) +
-    boilerplate)
-  pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," + 
-                         "convert-vector-to-llvm," + 
-                         "convert-std-to-llvm")
+      str(module.operation.regions[0].blocks[0].operations[0].operation) +
+      boilerplate)
+  pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," +
+                         "convert-vector-to-llvm," + "convert-std-to-llvm")
   pm.run(mod)
   return mod
 
-def test_builtin():
+
+def test_matmul_builtin():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f32 = F32Type.get()
     with InsertionPoint(module.body):
-      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
-                                   MemRefType.get((16, 8), f32),
-                                   MemRefType.get((4, 8), f32))
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
+          MemRefType.get((4, 8), f32))
       def matmul_on_buffers(lhs, rhs, out):
         linalg.matmul(lhs, rhs, outs=[out])
-    
-    execution_engine = ExecutionEngine(transform(module))
+
+    execution_engine = ExecutionEngine(transform(module, matmul_boiler))
 
     # TODO: FFI-based solution to allow testing and printing with python code.
     # Prepare arguments: one result f32.
@@ -74,23 +93,26 @@ def test_builtin():
     res = c_float_p(-1.)
     execution_engine.invoke("main", res)
 
-    log('RESULT: ', res[0])
+    log("RESULT: ", res[0])
     # CHECK: RESULT: 32.0
 
-test_builtin()
 
-def test_generic():
+test_matmul_builtin()
+
+
+def test_matmul_generic():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f32 = F32Type.get()
     with InsertionPoint(module.body):
-      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
-                                   MemRefType.get((16, 8), f32),
-                                   MemRefType.get((4, 8), f32))
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
+          MemRefType.get((4, 8), f32))
       def matmul_on_buffers(lhs, rhs, out):
         linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)
-    
-    execution_engine = ExecutionEngine(transform(module))
+
+    execution_engine = ExecutionEngine(transform(module, matmul_boiler))
 
     # TODO: FFI-based solution to allow testing and printing with python code.
     # Prepare arguments: one result f32.
@@ -99,7 +121,62 @@ def test_generic():
     res = c_float_p(-1.)
     execution_engine.invoke("main", res)
 
-    log('RESULT: ', res[0])
+    log("RESULT: ", res[0])
     # CHECK: RESULT: 32.0
 
-test_generic()
+
+test_matmul_generic()
+
+
+def test_fill_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), i32))
+      def fill_on_buffers(out):
+        linalg.fill_rng_2d(outs=[out])
+
+    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: -480
+
+
+test_fill_builtin()
+
+
+def test_fill_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), i32))
+      def fill_on_buffers(out):
+        linalg.fill_rng_2d(outs=[out])
+
+    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: -480
+
+
+test_fill_generic()
index 53a5807..1c5a20b 100644 (file)
@@ -51,7 +51,7 @@ struct LinalgYAMLContext {
 
 struct LinalgOpMetadata {
   std::string name;
-  std::string cppOpName;
+  std::string cppClassName;
   Optional<std::string> doc;
   SmallVector<std::string> implements;
 };
@@ -102,6 +102,8 @@ struct ScalarSymbolicCast {
 
 struct ScalarExpression {
   Optional<std::string> arg;
+  Optional<std::string> constant;
+  Optional<int64_t> index;
   Optional<ScalarApply> apply;
   Optional<ScalarSymbolicCast> symbolicCast;
 };
@@ -208,7 +210,7 @@ template <>
 struct MappingTraits<LinalgOpMetadata> {
   static void mapping(IO &io, LinalgOpMetadata &info) {
     io.mapRequired("name", info.name);
-    io.mapRequired("cpp_op_name", info.cppOpName);
+    io.mapRequired("cpp_class_name", info.cppClassName);
     io.mapOptional("doc", info.doc);
     io.mapOptional("implements", info.implements);
   }
@@ -247,6 +249,8 @@ template <>
 struct MappingTraits<ScalarExpression> {
   static void mapping(IO &io, ScalarExpression &info) {
     io.mapOptional("scalar_arg", info.arg);
+    io.mapOptional("scalar_const", info.constant);
+    io.mapOptional("scalar_index", info.index);
     io.mapOptional("scalar_apply", info.apply);
     io.mapOptional("symbolic_cast", info.symbolicCast);
   }
@@ -370,12 +374,26 @@ findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgTensorDef> &args) {
   return None;
 }
 
-static Optional<int>
-findTypeVarArgIndex(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
+// Try to map the TypeVar to a predefined or an argument type.
+static Optional<std::string>
+findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
+  // Handle all predefined types.
+  if (typeVar == "I32")
+    return std::string("helper.getIntegerType(32)");
+  if (typeVar == "I64")
+    return std::string("helper.getIntegerType(64)");
+  if (typeVar == "F32")
+    return std::string("helper.getFloat32Type()");
+  if (typeVar == "F64")
+    return std::string("helper.getFloat64Type()");
+
+  // Search all argument types.
   for (auto it : llvm::enumerate(args)) {
     if (it.value().elementTypeVar == typeVar)
-      return it.index();
+      return llvm::formatv("block.getArgument({0}).getType()", it.index())
+          .str();
   }
+
   return None;
 }
 
@@ -563,10 +581,10 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
 
   interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
 
-  os << llvm::formatv(structuredOpOdsHeaderFormat, opConfig.metadata->cppOpName,
-                      opConfig.metadata->name, interfaceNameList, doc, attrList,
-                      opConfig.structuredOp->args.size(), attrBuilder,
-                      attrMethods);
+  os << llvm::formatv(
+      structuredOpOdsHeaderFormat, opConfig.metadata->cppClassName,
+      opConfig.metadata->name, interfaceNameList, doc, attrList,
+      opConfig.structuredOp->args.size(), attrBuilder, attrMethods);
 
   return success();
 }
@@ -578,7 +596,7 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
     return success();
 
   raw_ostream &os = genContext.defns();
-  StringRef className = opConfig.metadata->cppOpName;
+  StringRef className = opConfig.metadata->cppClassName;
 
   // Implementation banner.
   std::string bannerComment = llvm::formatv("Implementation of {0}", className);
@@ -734,12 +752,15 @@ std::string {0}::getLibraryCallName() {{
   {
     // Generates a regionBuilder method. Parameters.
     // {0}: Class name
-    // {1}: Statements
+    // {1}: Number of args
+    // {2}: Statements
     static const char structuredOpRegionBuilderFormat[] = R"FMT(
 void {0}::regionBuilder(Block &block, ValueRange captures) {{
-  RegionBuilderHelper helper(block);
+  assert({1} > 0 && block.getNumArguments() == {1} &&
+         "{0} regionBuilder expects {1} (>=0) args");
+  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
   SmallVector<Value> yields;
-  {1}
+  {2}
   helper.yieldOutputs(yields);
 }
 )FMT";
@@ -769,12 +790,27 @@ void {0}::regionBuilder(Block &block, ValueRange captures) {{
           Optional<int> argIndex = findTensorDefArgIndex(*expression.arg, args);
           if (!argIndex) {
             emitError(genContext.getLoc())
-                << "scalar argument not defined on the op: " << arg.name;
+                << "scalar argument not defined on the op: " << *expression.arg;
             return None;
           }
           return std::string(
               llvm::formatv("block.getArgument({0})", *argIndex));
-        } else if (expression.apply) {
+        }
+        if (expression.constant) {
+          std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
+          stmts.push_back(
+              llvm::formatv(R"FMT(Value {0} = helper.constant("{1}");)FMT",
+                            cppIdent, expression.constant));
+          return cppIdent;
+        }
+        if (expression.index) {
+          // Access an iteration index.
+          std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
+          stmts.push_back(llvm::formatv("Value {0} = helper.index({1});",
+                                        cppIdent, *expression.index));
+          return cppIdent;
+        }
+        if (expression.apply) {
           // Apply function.
           // Recursively generate operands.
           SmallVector<std::string> operandCppValues;
@@ -790,7 +826,8 @@ void {0}::regionBuilder(Block &block, ValueRange captures) {{
                             expression.apply->fnName,
                             interleaveToString(operandCppValues, ", ")));
           return cppIdent;
-        } else if (expression.symbolicCast) {
+        }
+        if (expression.symbolicCast) {
           // Symbolic cast.
           // Operands must be arity 1.
           if (expression.symbolicCast->operands.size() != 1) {
@@ -803,29 +840,23 @@ void {0}::regionBuilder(Block &block, ValueRange captures) {{
           if (!operandCppValue)
             return None;
 
-          // Try to map the TypeVar to an arg index (which map to block arg
-          // indices), since we can just get that type directly.
-          // TODO: Handle free type variables which do not map to an argument.
-          Optional<int> typeArgIndex =
-              findTypeVarArgIndex(expression.symbolicCast->typeVar, args);
-          if (!typeArgIndex) {
+          Optional<std::string> typeCppValue =
+              findTypeValue(expression.symbolicCast->typeVar, args);
+          if (!typeCppValue) {
             emitError(genContext.getLoc())
                 << "type variable " << expression.symbolicCast->typeVar
-                << ", used in a symbolic cast must map to an argument but it "
-                << "does not";
+                << ", used in a symbolic cast must map to a predefined or "
+                << "an argument type but it does not";
             return None;
           }
-          std::string typeCppValue =
-              llvm::formatv("block.getArgument({0}).getType()", *typeArgIndex);
           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
           stmts.push_back(llvm::formatv("Value {0} = helper.cast({1}, {2});",
-                                        cppIdent, typeCppValue,
+                                        cppIdent, typeCppValue.getValue(),
                                         *operandCppValue));
           return cppIdent;
-        } else {
-          emitError(genContext.getLoc()) << "unknown ScalarExpression type";
-          return None;
         }
+        emitError(genContext.getLoc()) << "unknown ScalarExpression type";
+        return None;
       };
       Optional<std::string> cppValue = generateExpression(assignment->value);
       if (!cppValue)
@@ -837,7 +868,8 @@ void {0}::regionBuilder(Block &block, ValueRange captures) {{
       return emitError(genContext.getLoc())
              << "mismatched number of assignments vs output arguments";
 
-    os << llvm::formatv(structuredOpRegionBuilderFormat, className,
+    int64_t numOfArgs = args.size();
+    os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs,
                         interleaveToString(stmts, "\n  "));
   }
 
@@ -937,7 +969,7 @@ int main(int argc, char **argv) {
     }
 
     genContext.setLoc(NameLoc::get(
-        Identifier::get(opConfig.metadata->cppOpName, &mlirContext)));
+        Identifier::get(opConfig.metadata->cppClassName, &mlirContext)));
     if (failed(generateOp(opConfig, genContext))) {
       return 1;
     }