[mlir][PDL] Add a PDL Interpreter Dialect
authorRiver Riddle <riddleriver@gmail.com>
Wed, 26 Aug 2020 12:12:07 +0000 (05:12 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Wed, 26 Aug 2020 12:22:27 +0000 (05:22 -0700)
The PDL Interpreter dialect provides a lower level abstraction compared to the PDL dialect, and is targeted towards low level optimization and interpreter code generation. The dialect operations encapsulates low-level pattern match and rewrite "primitives", such as navigating the IR (Operation::getOperand), creating new operations (OpBuilder::create), etc. Many of the operations within this dialect also fuse branching control flow with some form of a predicate comparison operation. This type of fusion reduces the amount of work that an interpreter must do when executing.

An example of this representation is shown below:

```mlir
// The following high level PDL pattern:
pdl.pattern : benefit(1) {
  %resultType = pdl.type
  %inputOperand = pdl.input
  %root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType
  pdl.rewrite %root {
    pdl.replace %root with (%inputOperand)
  }
}

// May be represented in the interpreter dialect as follows:
module {
  func @matcher(%arg0: !pdl.operation) {
    pdl_interp.check_operation_name of %arg0 is "foo.op" -> ^bb2, ^bb1
  ^bb1:
    pdl_interp.return
  ^bb2:
    pdl_interp.check_operand_count of %arg0 is 1 -> ^bb3, ^bb1
  ^bb3:
    pdl_interp.check_result_count of %arg0 is 1 -> ^bb4, ^bb1
  ^bb4:
    %0 = pdl_interp.get_operand 0 of %arg0
    pdl_interp.is_not_null %0 : !pdl.value -> ^bb5, ^bb1
  ^bb5:
    %1 = pdl_interp.get_result 0 of %arg0
    pdl_interp.is_not_null %1 : !pdl.value -> ^bb6, ^bb1
  ^bb6:
    pdl_interp.record_match @rewriters::@rewriter(%0, %arg0 : !pdl.value, !pdl.operation) : benefit(1), loc([%arg0]), root("foo.op") -> ^bb1
  }
  module @rewriters {
    func @rewriter(%arg0: !pdl.value, %arg1: !pdl.operation) {
      pdl_interp.replace %arg1 with(%arg0)
      pdl_interp.return
    }
  }
}
```

Differential Revision: https://reviews.llvm.org/D84579

25 files changed:
mlir/include/mlir/Dialect/CMakeLists.txt
mlir/include/mlir/Dialect/PDL/IR/PDLBase.td
mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
mlir/include/mlir/Dialect/PDLInterp/CMakeLists.txt [new file with mode: 0644]
mlir/include/mlir/Dialect/PDLInterp/IR/CMakeLists.txt [new file with mode: 0644]
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h [new file with mode: 0644]
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td [new file with mode: 0644]
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/lib/Dialect/PDLInterp/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Dialect/PDLInterp/IR/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp [new file with mode: 0644]
mlir/lib/IR/Builders.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Parser.h
mlir/test/Dialect/PDL/invalid.mlir
mlir/test/Dialect/PDL/ops.mlir
mlir/test/Dialect/PDLInterp/ops.mlir [new file with mode: 0644]
mlir/tools/mlir-tblgen/OpFormatGen.cpp

index 118210c..6426fa8 100644 (file)
@@ -6,6 +6,7 @@ add_subdirectory(LLVMIR)
 add_subdirectory(OpenACC)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)
+add_subdirectory(PDLInterp)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
 add_subdirectory(Shape)
index a3392f2..9802bf9 100644 (file)
@@ -49,7 +49,7 @@ def PDL_Dialect : Dialect {
       %resultType = pdl.type
       %inputOperand = pdl.input
       %root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType
-      pdl.rewrite(%root) {
+      pdl.rewrite %root {
         pdl.replace %root with (%inputOperand)
       }
     }
index 1e865fb..73b4f26 100644 (file)
@@ -51,17 +51,18 @@ def PDL_ApplyConstraintOp
     ```
   }];
 
-  let arguments = (ins Variadic<PDL_PositionalValue>:$args,
-                       ArrayAttr:$params,
-                       StrAttr:$name);
-  let assemblyFormat = "$name $params `(` $args `:` type($args) `)` attr-dict";
+  let arguments = (ins StrAttr:$name,
+                       Variadic<PDL_PositionalValue>:$args,
+                       OptionalAttr<ArrayAttr>:$constParams);
+  let assemblyFormat = [{
+    $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict
+  }];
 
   let builders = [
-    OpBuilder<"OpBuilder &builder, OperationState &state, "
-              "ValueRange args, ArrayRef<Attribute> params, "
-              "StringRef name", [{
-      build(builder, state, args, builder.getArrayAttr(params),
-            builder.getStringAttr(name));
+    OpBuilder<"OpBuilder &builder, OperationState &state, StringRef name, "
+              "ValueRange args = {}, ArrayRef<Attribute> params = {}", [{
+      build(builder, state, builder.getStringAttr(name), args,
+            params.empty() ? ArrayAttr() : builder.getArrayAttr(params));
     }]>,
   ];
 }
@@ -135,12 +136,13 @@ def PDL_CreateNativeOp
     ```
   }];
 
-  let arguments = (ins StrAttr:$name, Variadic<PDL_PositionalValue>:$arguments,
-                       ArrayAttr:$constantParams);
+  let arguments = (ins StrAttr:$name,
+                       Variadic<PDL_PositionalValue>:$args,
+                       OptionalAttr<ArrayAttr>:$constParams);
   let results = (outs PDL_PositionalValue:$result);
   let assemblyFormat = [{
-    $name $constantParams (`(` $arguments^ `:` type($arguments) `)`)?
-    `:` type($result) attr-dict
+    $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
+    attr-dict
   }];
   let verifier = ?;
 }
@@ -222,7 +224,7 @@ def PDL_OperationOp
     `pdl.operation`s are composed of a name, and a set of attribute, operand,
     and result type values, that map to what those that would be on a
     constructed instance of that operation. The results of a `pdl.operation` are
-    a handle to the operation itself, and a handle to each of the operation 
+    a handle to the operation itself, and a handle to each of the operation
     result values.
 
     When used within a matching context, the name of the operation may be
@@ -380,16 +382,18 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
     rewrite is specified either via a string name (`name`) to an external
     rewrite function, or via the region body. The rewrite region, if specified,
     must contain a single block and terminate via the `pdl.rewrite_end`
-    operation.
+    operation. If the rewrite is external, it also takes a set of constant
+    parameters and a set of additional positional values defined within the
+    matcher as arguments.
 
     Example:
 
     ```mlir
     // Specify an external rewrite function:
-    pdl.rewrite "myExternalRewriter"(%root)
+    pdl.rewrite %root with "myExternalRewriter"(%value : !pdl.value)
 
     // Specify the rewrite inline using PDL:
-    pdl.rewrite(%root) {
+    pdl.rewrite %root {
       %op = pdl.operation "foo.op"(%arg0, %arg1)
       pdl.replace %root with %op
     }
@@ -397,7 +401,9 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
   }];
 
   let arguments = (ins PDL_Operation:$root,
-                       OptionalAttr<StrAttr>:$name);
+                       OptionalAttr<StrAttr>:$name,
+                       Variadic<PDL_PositionalValue>:$externalArgs,
+                       OptionalAttr<ArrayAttr>:$externalConstParams);
   let regions = (region AnyRegion:$body);
 }
 
diff --git a/mlir/include/mlir/Dialect/PDLInterp/CMakeLists.txt b/mlir/include/mlir/Dialect/PDLInterp/CMakeLists.txt
new file mode 100644 (file)
index 0000000..f33061b
--- /dev/null
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/PDLInterp/IR/CMakeLists.txt
new file mode 100644 (file)
index 0000000..8a70766
--- /dev/null
@@ -0,0 +1,2 @@
+add_mlir_dialect(PDLInterpOps pdl_interp)
+add_mlir_doc(PDLInterpOps -gen-op-doc PDLInterpOps Dialects/)
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h
new file mode 100644 (file)
index 0000000..6d89567
--- /dev/null
@@ -0,0 +1,39 @@
+//===- PDLInterp.h - PDL Interpreter dialect --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the interpreter dialect for the PDL pattern descriptor
+// language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_
+#define MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_
+
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+namespace mlir {
+namespace pdl_interp {
+//===----------------------------------------------------------------------===//
+// PDLInterp Dialect
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.h.inc"
+
+//===----------------------------------------------------------------------===//
+// PDLInterp Dialect Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.h.inc"
+
+} // end namespace pdl_interp
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
new file mode 100644 (file)
index 0000000..58a2032
--- /dev/null
@@ -0,0 +1,926 @@
+//===- PDLInterpOps.td - Pattern Interpreter Dialect -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the PDL interpreter dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS
+#define MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS
+
+include "mlir/Dialect/PDL/IR/PDLBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// PDLInterp Dialect
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_Dialect : Dialect {
+  let summary = "Interpreted pattern execution dialect";
+  let description = [{
+    The PDL Interpreter dialect provides a lower level abstraction compared to
+    the PDL dialect, and is targeted towards low level optimization and
+    interpreter code generation. The dialect operations encapsulates
+    low-level pattern match and rewrite "primitives", such as navigating the
+    IR (Operation::getOperand), creating new operations (OpBuilder::create),
+    etc. Many of the operations within this dialect also fuse branching control
+    flow with some form of a predicate comparison operation. This type of fusion
+    reduces the amount of work that an interpreter must do when executing.
+  }];
+
+  let name = "pdl_interp";
+  let cppNamespace = "mlir::pdl_interp";
+  let dependentDialects = ["pdl::PDLDialect"];
+}
+
+//===----------------------------------------------------------------------===//
+// PDLInterp Operations
+//===----------------------------------------------------------------------===//
+
+// Generic interpreter operation.
+class PDLInterp_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<PDLInterp_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// PDLInterp_PredicateOp
+
+// Check operations evaluate a predicate on a positional value and then
+// conditionally branch on the result.
+class PDLInterp_PredicateOp<string mnemonic, list<OpTrait> traits = []> :
+    PDLInterp_Op<mnemonic, !listconcat([Terminator], traits)> {
+  let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
+}
+
+//===----------------------------------------------------------------------===//
+// PDLInterp_SwitchOp
+
+// Switch operations evaluate a predicate on a positional value and then
+// conditionally branch on the result.
+class PDLInterp_SwitchOp<string mnemonic, list<OpTrait> traits = []> :
+    PDLInterp_Op<mnemonic, !listconcat([Terminator], traits)> {
+  let successors = (successor AnySuccessor:$defaultDest,
+                              VariadicSuccessor<AnySuccessor>:$cases);
+
+  let verifier = [{
+    // Verify that the number of case destinations matches the number of case
+    // values.
+    size_t numDests = cases().size();
+    size_t numValues = caseValues().size();
+    if (numDests != numValues) {
+      return emitOpError("expected number of cases to match the number of case "
+                         "values, got ")
+          << numDests << " but expected " << numValues;
+    }
+    return success();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ApplyConstraintOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
+  let summary = "Apply a constraint to a set of positional values";
+  let description = [{
+    `pdl_interp.apply_constraint` operations apply a generic constraint, that
+    has been registered with the interpreter, with a given set of positional
+    values. The constraint may have any number of constant parameters. On
+    success, this operation branches to the true destination, otherwise the
+    false destination is taken.
+
+    Example:
+
+    ```mlir
+    // Apply `myConstraint` to the entities defined by `input`, `attr`, and
+    // `op`.
+    pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
+    ```
+  }];
+
+  let arguments = (ins StrAttr:$name,
+                       Variadic<PDL_PositionalValue>:$args,
+                       OptionalAttr<ArrayAttr>:$constParams);
+  let assemblyFormat = [{
+    $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict `->`
+    successors
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ApplyRewriteOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> {
+  let summary = "Invoke and apply an externally registered rewrite method";
+  let description = [{
+    `pdl_interp.apply_rewrite` operations invoke an external rewriter that has
+    been registered with the interpreter to perform the rewrite after a
+    successful match. The rewrite is passed the root operation being matched, a
+    set of additional positional arguments generated within the matcher, and a
+    set of constant parameters.
+
+    Example:
+
+    ```mlir
+    // Rewriter operating solely on the root operation.
+    pdl_interp.apply_rewrite "rewriter" on %root
+
+    // Rewriter operating on the root operation along with additional arguments
+    // from the matcher.
+    pdl_interp.apply_rewrite "rewriter"(%value : !pdl.value) on %root
+
+    // Rewriter operating on the root operation along with additional arguments
+    // and constant parameters.
+    pdl_interp.apply_rewrite "rewriter"[42](%value : !pdl.value) on %root
+    ```
+  }];
+  let arguments = (ins StrAttr:$name,
+                       PDL_Operation:$root,
+                       Variadic<PDL_PositionalValue>:$args,
+                       OptionalAttr<ArrayAttr>:$constParams);
+  let assemblyFormat = [{
+    $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `on` $root
+    attr-dict
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::AreEqualOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_AreEqualOp
+    : PDLInterp_PredicateOp<"are_equal", [NoSideEffect, SameTypeOperands]> {
+  let summary = "Check if two positional values are equivalent";
+  let description = [{
+    `pdl_interp.are_equal` operations compare two positional values for
+    equality. On success, this operation branches to the true destination,
+    otherwise the false destination is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.are_equal %result1, %result2 : !pdl.value -> ^matchDest, ^failureDest
+    ```
+  }];
+
+  let arguments = (ins PDL_PositionalValue:$lhs,
+                       PDL_PositionalValue:$rhs);
+  let assemblyFormat = "operands `:` type($lhs) attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::BranchOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_BranchOp : PDLInterp_Op<"branch", [NoSideEffect, Terminator]> {
+  let summary = "General branch operation";
+  let description = [{
+    `pdl_interp.branch` operations expose general branch functionality to the
+    interpreter, and are generally used to branch from one pattern match
+    sequence to another.
+
+    Example:
+
+    ```mlir
+    pdl_interp.branch ^dest
+    ```
+  }];
+
+  let successors = (successor AnySuccessor:$dest);
+  let assemblyFormat = "$dest attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckAttributeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CheckAttributeOp
+    : PDLInterp_PredicateOp<"check_attribute", [NoSideEffect]> {
+  let summary = "Check the value of an `Attribute`";
+  let description = [{
+    `pdl_interp.check_attribute` operations compare the value of a given
+    attribute with a constant value. On success, this operation branches to the
+    true destination, otherwise the false destination is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.check_attribute %attr is 10 -> ^matchDest, ^failureDest
+    ```
+  }];
+
+  let arguments = (ins PDL_Attribute:$attribute, AnyAttr:$constantValue);
+  let assemblyFormat = [{
+    $attribute `is` $constantValue attr-dict `->` successors
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckOperandCountOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CheckOperandCountOp
+    : PDLInterp_PredicateOp<"check_operand_count", [NoSideEffect]> {
+  let summary = "Check the number of operands of an `Operation`";
+  let description = [{
+    `pdl_interp.check_operand_count` operations compare the number of operands
+    of a given operation value with a constant. On success, this operation
+    branches to the true destination, otherwise the false destination is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.check_operand_count of %op is 2 -> ^matchDest, ^failureDest
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$operation,
+                       Confined<I32Attr, [IntNonNegative]>:$count);
+  let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckOperationNameOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CheckOperationNameOp
+    : PDLInterp_PredicateOp<"check_operation_name", [NoSideEffect]> {
+  let summary = "Check the OperationName of an `Operation`";
+  let description = [{
+    `pdl_interp.check_operation_name` operations compare the name of a given
+    operation with a known name. On success, this operation branches to the true
+    destination, otherwise the false destination is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.check_operation_name of %op is "foo.op" -> ^matchDest, ^failureDest
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$operation, StrAttr:$name);
+  let assemblyFormat = "`of` $operation `is` $name attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckResultCountOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CheckResultCountOp
+    : PDLInterp_PredicateOp<"check_result_count", [NoSideEffect]> {
+  let summary = "Check the number of results of an `Operation`";
+  let description = [{
+    `pdl_interp.check_result_count` operations compare the number of results
+    of a given operation value with a constant. On success, this operation
+    branches to the true destination, otherwise the false destination is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.check_result_count of %op is 0 -> ^matchDest, ^failureDest
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$operation,
+                       Confined<I32Attr, [IntNonNegative]>:$count);
+  let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckTypeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CheckTypeOp
+    : PDLInterp_PredicateOp<"check_type", [NoSideEffect]> {
+  let summary = "Compare a type to a known value";
+  let description = [{
+    `pdl_interp.check_type` operations compare a type with a statically known
+    type. On success, this operation branches to the true destination, otherwise
+    the false destination is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.check_type %type is 0 -> ^matchDest, ^failureDest
+    ```
+  }];
+
+  let arguments = (ins PDL_Type:$value, TypeAttr:$type);
+  let assemblyFormat = "$value `is` $type attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateAttributeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CreateAttributeOp
+    : PDLInterp_Op<"create_attribute", [NoSideEffect]> {
+  let summary = "Create an interpreter handle to a constant `Attribute`";
+  let description = [{
+    `pdl_interp.create_attribute` operations generate a handle within the
+    interpreter for a specific constant attribute value.
+
+    Example:
+
+    ```mlir
+    pdl_interp.create_attribute 10 : i64
+    ```
+  }];
+
+  let arguments = (ins AnyAttr:$value);
+  let results = (outs PDL_Attribute:$attribute);
+  let assemblyFormat = "$value attr-dict";
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, "
+              "Attribute value", [{
+    build(builder, state, builder.getType<pdl::AttributeType>(), value);
+  }]>];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateNativeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CreateNativeOp : PDLInterp_Op<"create_native"> {
+  let summary = "Call a native creation method to construct an `Attribute`, "
+                "`Operation`, `Type`, or `Value`";
+  let description = [{
+    `pdl_interp.create_native` operations invoke a native C++ function, that has
+    been registered externally with the consumer of PDL, to create an
+    `Attribute`, `Operation`, `Type`, or `Value`. The native function must
+    produce a value of the specified return type, and may accept any number of
+    positional arguments and constant attribute parameters.
+
+    Example:
+
+    ```mlir
+    %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute
+    ```
+  }];
+
+  let arguments = (ins StrAttr:$name,
+                       Variadic<PDL_PositionalValue>:$args,
+                       OptionalAttr<ArrayAttr>:$constParams);
+  let results = (outs PDL_PositionalValue:$result);
+  let assemblyFormat = [{
+    $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
+    attr-dict
+  }];
+  let verifier = ?;
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateOperationOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CreateOperationOp
+    : PDLInterp_Op<"create_operation", [AttrSizedOperandSegments]> {
+  let summary = "Create an instance of a specific `Operation`";
+  let description = [{
+    `pdl_interp.create_operation` operations create an `Operation` instance with
+    the specified attributes, operands, and result types.
+
+    Example:
+
+    ```mlir
+    // Create an instance of a `foo.op` operation.
+    %op = pdl_interp.create_operation "foo.op"(%arg0) {"attrA" = %attr0} -> %type, %type
+    ```
+  }];
+
+  let arguments = (ins StrAttr:$name,
+                       Variadic<PDL_Value>:$operands,
+                       Variadic<PDL_Attribute>:$attributes,
+                       StrArrayAttr:$attributeNames,
+                       Variadic<PDL_Type>:$types);
+  let results = (outs PDL_Operation:$operation);
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, StringRef name, "
+              "ValueRange types, ValueRange operands, ValueRange attributes, "
+              "ArrayAttr attributeNames", [{
+    build(builder, state, builder.getType<pdl::OperationType>(), name,
+          operands, attributes, attributeNames, types);
+  }]>];
+  let parser = [{ return ::parseCreateOperationOp(parser, result); }];
+  let printer = [{ ::print(p, *this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateTypeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CreateTypeOp : PDLInterp_Op<"create_type", [NoSideEffect]> {
+  let summary = "Create an interpreter handle to a constant `Type`";
+  let description = [{
+    `pdl_interp.create_type` operations generate a handle within the interpreter
+    for a specific constant type value.
+
+    Example:
+
+    ```mlir
+    pdl_interp.create_type i64
+    ```
+  }];
+
+  let arguments = (ins TypeAttr:$value);
+  let results = (outs PDL_Type:$result);
+  let assemblyFormat = "$value attr-dict";
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, TypeAttr type", [{
+      build(builder, state, builder.getType<pdl::TypeType>(), type);
+    }]>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::EraseOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_EraseOp : PDLInterp_Op<"erase"> {
+  let summary = "Mark an operation as `erased`";
+  let description = [{
+    `pdl.erase` operations are used to specify that an operation should be
+    marked as erased. The semantics of this operation correspond with the
+    `eraseOp` method on a `PatternRewriter`.
+
+    Example:
+
+    ```mlir
+    pdl_interp.erase %root
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$operation);
+  let assemblyFormat = "$operation attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::FinalizeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_FinalizeOp
+    : PDLInterp_Op<"finalize", [NoSideEffect, Terminator]> {
+  let summary = "Finalize a pattern match or rewrite sequence";
+  let description = [{
+    `pdl_interp.finalize` is used to denote the termination of a match or
+    rewrite sequence.
+
+    Example:
+
+    ```mlir
+    pdl_interp.finalize
+    ```
+  }];
+  let assemblyFormat = "attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetAttributeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetAttributeOp : PDLInterp_Op<"get_attribute", [NoSideEffect]> {
+  let summary = "Get a specified attribute value from an `Operation`";
+  let description = [{
+    `pdl_interp.get_attribute` operations try to get a specific attribute from
+    an operation. If the operation does not have that attribute, a null value is
+    returned.
+
+    Example:
+
+    ```mlir
+    %attr = pdl_interp.get_attribute "attr" of %op
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$operation,
+                       StrAttr:$name);
+  let results = (outs PDL_Attribute:$attribute);
+  let assemblyFormat = "$name `of` $operation attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetAttributeTypeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetAttributeTypeOp
+    : PDLInterp_Op<"get_attribute_type", [NoSideEffect]> {
+  let summary = "Get the result type of a specified `Attribute`";
+  let description = [{
+    `pdl_interp.get_attribute_type` operations get the resulting type of a
+    specific attribute.
+
+    Example:
+
+    ```mlir
+    %type = pdl_interp.get_attribute_type of %attr
+    ```
+  }];
+
+  let arguments = (ins PDL_Attribute:$value);
+  let results = (outs PDL_Type:$result);
+  let assemblyFormat = "`of` $value attr-dict";
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value value", [{
+      build(builder, state, builder.getType<pdl::TypeType>(), value);
+    }]>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetDefiningOpOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetDefiningOpOp
+    : PDLInterp_Op<"get_defining_op", [NoSideEffect]> {
+  let summary = "Get the defining operation of a `Value`";
+  let description = [{
+    `pdl_interp.get_defining_op` operations try to get the defining operation
+    of a specific value. If the value is not an operation result, null is
+    returned.
+
+    Example:
+
+    ```mlir
+    %op = pdl_interp.get_defining_op of %value
+    ```
+  }];
+
+  let arguments = (ins PDL_Value:$value);
+  let results = (outs PDL_Operation:$operation);
+  let assemblyFormat = "`of` $value attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetOperandOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetOperandOp : PDLInterp_Op<"get_operand", [NoSideEffect]> {
+  let summary = "Get a specified operand from an `Operation`";
+  let description = [{
+    `pdl_interp.get_operand` operations try to get a specific operand from an
+    operation If the operation does not have an operand for the given index, a
+    null value is returned.
+
+    Example:
+
+    ```mlir
+    %operand = pdl_interp.get_operand 1 of %op
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$operation,
+                       Confined<I32Attr, [IntNonNegative]>:$index);
+  let results = (outs PDL_Value:$value);
+  let assemblyFormat = "$index `of` $operation attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetResultOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetResultOp : PDLInterp_Op<"get_result", [NoSideEffect]> {
+  let summary = "Get a specified result from an `Operation`";
+  let description = [{
+    `pdl_interp.get_result` operations try to get a specific result from an
+    operation. If the operation does not have a result for the given index, a
+    null value is returned.
+
+    Example:
+
+    ```mlir
+    %result = pdl_interp.get_result 1 of %op
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$operation,
+                       Confined<I32Attr, [IntNonNegative]>:$index);
+  let results = (outs PDL_Value:$value);
+  let assemblyFormat = "$index `of` $operation attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetValueTypeOp
+//===----------------------------------------------------------------------===//
+
+// Get a type from the root operation, held in the rewriter context.
+def PDLInterp_GetValueTypeOp : PDLInterp_Op<"get_value_type", [NoSideEffect]> {
+  let summary = "Get the result type of a specified `Value`";
+  let description = [{
+    `pdl_interp.get_value_type` operations get the resulting type of a specific
+    value.
+
+    Example:
+
+    ```mlir
+    %type = pdl_interp.get_value_type of %value
+    ```
+  }];
+
+  let arguments = (ins PDL_Value:$value);
+  let results = (outs PDL_Type:$result);
+  let assemblyFormat = "`of` $value attr-dict";
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value value", [{
+      build(builder, state, builder.getType<pdl::TypeType>(), value);
+    }]>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::InferredTypeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_InferredTypeOp : PDLInterp_Op<"inferred_type"> {
+  let summary = "Generate a handle to a Type that is \"inferred\"";
+  let description = [{
+    `pdl_interp.inferred_type` operations generate a handle to a type that
+    should be inferred. This signals to other operations, such as
+    `pdl_interp.create_operation`, that this type should be inferred.
+
+    Example:
+
+    ```mlir
+    pdl_interp.inferred_type
+    ```
+  }];
+  let results = (outs PDL_Type:$type);
+  let assemblyFormat = "attr-dict";
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state", [{
+      build(builder, state, builder.getType<pdl::TypeType>());
+    }]>,
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::IsNotNullOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_IsNotNullOp
+    : PDLInterp_PredicateOp<"is_not_null", [NoSideEffect]> {
+  let summary = "Check if a positional value is non-null";
+  let description = [{
+    `pdl_interp.is_not_null` operations check that a positional value exists. On
+    success, this operation branches to the true destination. Otherwise, the
+    false destination is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.is_not_null %value : !pdl.value -> ^matchDest, ^failureDest
+    ```
+  }];
+
+  let arguments = (ins PDL_PositionalValue:$value);
+  let assemblyFormat = "$value `:` type($value) attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::RecordMatchOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_RecordMatchOp
+    : PDLInterp_Op<"record_match", [AttrSizedOperandSegments, Terminator]> {
+  let summary = "Record the metadata for a successful pattern match";
+  let description = [{
+    `pdl_interp.record_match` operations record a successful pattern match with
+    the interpreter and branch to the next part of the matcher. The metadata
+    recorded by these operations correspond to a specific `pdl.pattern`, as well
+    as what values were used during that match that should be propagated to the
+    rewriter.
+
+    Example:
+
+    ```mlir
+    pdl_interp.record_match @rewriters::myRewriter(%root : !pdl.operation) : benefit(1), loc([%root, %op1]), root("foo.op") -> ^nextDest
+    ```
+  }];
+
+  let arguments = (ins Variadic<PDL_PositionalValue>:$inputs,
+                       Variadic<PDL_Operation>:$matchedOps,
+                       SymbolRefAttr:$rewriter,
+                       OptionalAttr<StrAttr>:$rootKind,
+                       OptionalAttr<StrArrayAttr>:$generatedOps,
+                       Confined<I16Attr, [IntNonNegative]>:$benefit);
+  let successors = (successor AnySuccessor:$dest);
+  let assemblyFormat = [{
+    $rewriter (`(` $inputs^ `:` type($inputs) `)`)? `:`
+    `benefit` `(` $benefit `)` `,`
+    (`generatedOps` `(` $generatedOps^ `)` `,`)?
+    `loc` `(` `[` $matchedOps `]` `)`
+    (`,` `root` `(` $rootKind^ `)`)? attr-dict `->` $dest
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ReplaceOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_ReplaceOp : PDLInterp_Op<"replace"> {
+  let summary = "Mark an operation as `replace`d";
+  let description = [{
+    `pdl_interp.replaced` operations are used to specify that an operation
+    should be marked as replaced. The semantics of this operation correspond
+    with the `replaceOp` method on a `PatternRewriter`. The set of replacement
+    values must match the number of results specified by the operation.
+
+    Example:
+
+    ```mlir
+    // Replace root node with 2 values:
+    pdl_interp.replace %root with (%val0, %val1)
+    ```
+  }];
+  let arguments = (ins PDL_Operation:$operation,
+                       Variadic<PDL_Value>:$replValues);
+  let assemblyFormat = "$operation `with` `(` $replValues `)` attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchAttributeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_SwitchAttributeOp
+    : PDLInterp_SwitchOp<"switch_attribute", [NoSideEffect]> {
+  let summary = "Switch on the value of an `Attribute`";
+  let description = [{
+    `pdl_interp.switch_attribute` operations compare the value of a given
+    attribute with a set of constant attributes. If the value matches one of the
+    provided case values the destination for that case value is taken, otherwise
+    the default destination is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.switch_attribute %attr to [10, true] -> ^10Dest, ^trueDest, ^defaultDest
+    ```
+  }];
+  let arguments = (ins PDL_Attribute:$attribute, ArrayAttr:$caseValues);
+  let assemblyFormat = [{
+    $attribute `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
+  }];
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value attribute,"
+              "ArrayRef<Attribute> caseValues,"
+              "Block *defaultDest, ArrayRef<Block *> dests", [{
+    build(builder, state, attribute, builder.getArrayAttr(caseValues),
+          defaultDest, dests);
+  }]>];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchOperandCountOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_SwitchOperandCountOp
+    : PDLInterp_SwitchOp<"switch_operand_count", [NoSideEffect]> {
+  let summary = "Switch on the operand count of an `Operation`";
+  let description = [{
+    `pdl_interp.switch_operand_count` operations compare the operand count of a
+    given operation with a set of potential counts. If the value matches one of
+    the provided case values the destination for that case value is taken,
+    otherwise the default destination is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.switch_operand_count of %op to [10, 2] -> ^10Dest, ^2Dest, ^defaultDest
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$operation, I32ElementsAttr:$caseValues);
+  let assemblyFormat = [{
+    `of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
+  }];
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, "
+              "ArrayRef<int32_t> counts, Block *defaultDest, "
+              "ArrayRef<Block *> dests", [{
+    build(builder, state, operation, builder.getI32VectorAttr(counts),
+          defaultDest, dests);
+  }]>];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchOperationNameOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_SwitchOperationNameOp
+    : PDLInterp_SwitchOp<"switch_operation_name", [NoSideEffect]> {
+  let summary = "Switch on the OperationName of an `Operation`";
+  let description = [{
+    `pdl_interp.switch_operation_name` operations compare the name of a given
+    operation with a set of known names. If the value matches one of the
+    provided case values the destination for that case value is taken, otherwise
+    the default destination is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"] -> ^fooDest, ^barDest, ^defaultDest
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$operation,
+                       StrArrayAttr:$caseValues);
+  let assemblyFormat = [{
+    `of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
+  }];
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, "
+              "ArrayRef<OperationName> names, "
+              "Block *defaultDest, ArrayRef<Block *> dests", [{
+      auto stringNames = llvm::to_vector<8>(llvm::map_range(names,
+          [](OperationName name) { return name.getStringRef(); }));
+      build(builder, state, operation, builder.getStrArrayAttr(stringNames),
+            defaultDest, dests);
+    }]>,
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchResultCountOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_SwitchResultCountOp
+    : PDLInterp_SwitchOp<"switch_result_count", [NoSideEffect]> {
+  let summary = "Switch on the result count of an `Operation`";
+  let description = [{
+    `pdl_interp.switch_result_count` operations compare the result count of a
+    given operation with a set of potential counts. If the value matches one of
+    the provided case values the destination for that case value is taken,
+    otherwise the default destination is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.switch_result_count of %op to [0, 2] -> ^0Dest, ^2Dest, ^defaultDest
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$operation, I32ElementsAttr:$caseValues);
+  let assemblyFormat = [{
+    `of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
+  }];
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, "
+              "ArrayRef<int32_t> counts, Block *defaultDest, "
+              "ArrayRef<Block *> dests", [{
+    build(builder, state, operation, builder.getI32VectorAttr(counts),
+          defaultDest, dests);
+  }]>];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchTypeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_SwitchTypeOp : PDLInterp_SwitchOp<"switch_type", [NoSideEffect]> {
+  let summary = "Switch on a `Type` value";
+  let description = [{
+    `pdl_interp.switch_type` operations compare a type with a set of statically
+    known types. If the value matches one of the provided case values the
+    destination for that case value is taken, otherwise the default destination
+    is taken.
+
+    Example:
+
+    ```mlir
+    pdl_interp.switch_type %type to [i32, i64] -> ^i32Dest, ^i64Dest, ^defaultDest
+    ```
+  }];
+
+  let arguments = (ins PDL_Type:$value, TypeArrayAttr:$caseValues);
+  let assemblyFormat = [{
+    $value `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
+  }];
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value edge, "
+              "TypeRange types, Block *defaultDest, ArrayRef<Block *> dests", [{
+      build(builder, state, edge, builder.getTypeArrayAttr(types), defaultDest,
+            dests);
+    }]>,
+  ];
+
+  let extraClassDeclaration = [{
+    auto getCaseTypes() { return caseValues().getAsValueRange<TypeAttr>(); }
+  }];
+}
+
+#endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS
index aa8f2ea..d1b25cd 100644 (file)
@@ -217,12 +217,12 @@ private:
 
 public:
   template <typename AttrTy>
-  llvm::iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
+  iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
     return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
                             attr_value_iterator<AttrTy>(end()));
   }
-  template <typename AttrTy, typename UnderlyingTy>
-  auto getAsRange() {
+  template <typename AttrTy, typename UnderlyingTy = typename AttrTy::ValueType>
+  auto getAsValueRange() {
     return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) {
       return static_cast<UnderlyingTy>(attr.getValue());
     });
@@ -589,6 +589,9 @@ public:
   /// Returns the number of elements held by this attribute.
   int64_t getNumElements() const;
 
+  /// Returns the number of elements held by this attribute.
+  int64_t size() const { return getNumElements(); }
+
   /// Generates a new ElementsAttr by mapping each int value to a new
   /// underlying APInt. The new values can represent either an integer or float.
   /// This ElementsAttr should contain integers.
index c27585a..aa1cc0a 100644 (file)
@@ -139,6 +139,7 @@ public:
   ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
   ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
   ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
+  ArrayAttr getTypeArrayAttr(TypeRange values);
 
   // Affine expressions and affine maps.
   AffineExpr getAffineDimExpr(unsigned position);
index 5bb4eff..e0726a9 100644 (file)
@@ -426,6 +426,12 @@ public:
     return parseOptionalAttribute(result, Type(), attrName, attrs);
   }
 
+  /// Specialized variants of `parseOptionalAttribute` that remove potential
+  /// ambiguities in syntax.
+  virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
+                                                     StringRef attrName,
+                                                     NamedAttrList &attrs) = 0;
+
   /// Parse an arbitrary attribute of a given type and return it in result. This
   /// also adds the attribute to the specified attribute list with the specified
   /// name.
index deb27bb..190486a 100644 (file)
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SDBM/SDBMDialect.h"
@@ -49,6 +50,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   scf::SCFDialect,
                   omp::OpenMPDialect,
                   pdl::PDLDialect,
+                  pdl_interp::PDLInterpDialect,
                   quant::QuantizationDialect,
                   spirv::SPIRVDialect,
                   StandardOpsDialect,
index 3681763..790264f 100644 (file)
@@ -6,6 +6,7 @@ add_subdirectory(LLVMIR)
 add_subdirectory(OpenACC)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)
+add_subdirectory(PDLInterp)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
 add_subdirectory(SDBM)
index 559d411..c8e20ce 100644 (file)
@@ -76,9 +76,7 @@ static LogicalResult isContraction(Operation *op) {
   if (!genericOp)
     return failure();
 
-  auto mapRange =
-      genericOp.indexing_maps().getAsRange<AffineMapAttr, AffineMap>();
-
+  auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr>();
   return success(
       genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
       llvm::all_of(mapRange,
index fc28e69..0146f0d 100644 (file)
@@ -446,20 +446,39 @@ static LogicalResult verify(ReplaceOp op) {
 //===----------------------------------------------------------------------===//
 
 static ParseResult parseRewriteOp(OpAsmParser &p, OperationState &state) {
-  // If the first token isn't a '(', this is an external rewrite.
-  StringAttr nameAttr;
-  if (failed(p.parseOptionalLParen())) {
-    if (p.parseAttribute(nameAttr, "name", state.attributes) || p.parseLParen())
-      return failure();
-  }
-
   // Parse the root operand.
   OpAsmParser::OperandType rootOperand;
-  if (p.parseOperand(rootOperand) || p.parseRParen() ||
+  if (p.parseOperand(rootOperand) ||
       p.resolveOperand(rootOperand, p.getBuilder().getType<OperationType>(),
                        state.operands))
     return failure();
 
+  // Parse an external rewrite.
+  StringAttr nameAttr;
+  if (succeeded(p.parseOptionalKeyword("with"))) {
+    if (p.parseAttribute(nameAttr, "name", state.attributes))
+      return failure();
+
+    // Parse the optional set of constant parameters.
+    ArrayAttr constantParams;
+    OptionalParseResult constantParamResult = p.parseOptionalAttribute(
+        constantParams, "externalConstParams", state.attributes);
+    if (constantParamResult.hasValue() && failed(*constantParamResult))
+      return failure();
+
+    // Parse the optional additional arguments.
+    if (succeeded(p.parseOptionalLParen())) {
+      SmallVector<OpAsmParser::OperandType, 4> arguments;
+      SmallVector<Type, 4> argumentTypes;
+      llvm::SMLoc argumentLoc = p.getCurrentLocation();
+      if (p.parseOperandList(arguments) ||
+          p.parseColonTypeList(argumentTypes) || p.parseRParen() ||
+          p.resolveOperands(arguments, argumentTypes, argumentLoc,
+                            state.operands))
+        return failure();
+    }
+  }
+
   // If this isn't an external rewrite, parse the region body.
   Region &rewriteRegion = *state.addRegion();
   if (!nameAttr) {
@@ -468,27 +487,58 @@ static ParseResult parseRewriteOp(OpAsmParser &p, OperationState &state) {
       return failure();
     RewriteOp::ensureTerminator(rewriteRegion, p.getBuilder(), state.location);
   }
-  return success();
+
+  return p.parseOptionalAttrDictWithKeyword(state.attributes);
 }
 
 static void print(OpAsmPrinter &p, RewriteOp op) {
-  p << "pdl.rewrite";
+  p << "pdl.rewrite " << op.root();
   if (Optional<StringRef> name = op.name()) {
-    p << " \"" << *name << "\"(" << op.root() << ")";
-    return;
+    p << " with \"" << *name << "\"";
+
+    if (ArrayAttr constantParams = op.externalConstParamsAttr())
+      p << constantParams;
+
+    OperandRange externalArgs = op.externalArgs();
+    if (!externalArgs.empty())
+      p << "(" << externalArgs << " : " << externalArgs.getTypes() << ")";
+  } else {
+    p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
+                  /*printBlockTerminators=*/false);
   }
 
-  p << "(" << op.root() << ")";
-  p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/false);
+  p.printOptionalAttrDictWithKeyword(op.getAttrs(),
+                                     {"name", "externalConstParams"});
 }
 
 static LogicalResult verify(RewriteOp op) {
   Region &rewriteRegion = op.body();
-  if (llvm::hasNItemsOrMore(rewriteRegion, 2)) {
-    return op.emitOpError()
-           << "expected rewrite region when specified to have a single block";
+
+  // Handle the case where the rewrite is external.
+  if (op.name()) {
+    if (!rewriteRegion.empty()) {
+      return op.emitOpError()
+             << "expected rewrite region to be empty when rewrite is external";
+    }
+    return success();
+  }
+
+  // Otherwise, check that the rewrite region only contains a single block.
+  if (rewriteRegion.empty()) {
+    return op.emitOpError() << "expected rewrite region to be non-empty if "
+                               "external name is not specified";
   }
+
+  // Check that no additional arguments were provided.
+  if (!op.externalArgs().empty()) {
+    return op.emitOpError() << "expected no external arguments when the "
+                               "rewrite is specified inline";
+  }
+  if (op.externalConstParams()) {
+    return op.emitOpError() << "expected no external constant parameters when "
+                               "the rewrite is specified inline";
+  }
+
   return success();
 }
 
diff --git a/mlir/lib/Dialect/PDLInterp/CMakeLists.txt b/mlir/lib/Dialect/PDLInterp/CMakeLists.txt
new file mode 100644 (file)
index 0000000..f33061b
--- /dev/null
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/PDLInterp/IR/CMakeLists.txt b/mlir/lib/Dialect/PDLInterp/IR/CMakeLists.txt
new file mode 100644 (file)
index 0000000..6e0ebc6
--- /dev/null
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRPDLInterp
+  PDLInterp.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/PDLInterp
+
+  DEPENDS
+  MLIRPDLInterpOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPDL
+  MLIRInferTypeOpInterface
+  MLIRSideEffectInterfaces
+  )
diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
new file mode 100644 (file)
index 0000000..2119d7a
--- /dev/null
@@ -0,0 +1,122 @@
+//===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::pdl_interp;
+
+//===----------------------------------------------------------------------===//
+// PDLInterp Dialect
+//===----------------------------------------------------------------------===//
+
+void PDLInterpDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
+      >();
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateOperationOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseCreateOperationOp(OpAsmParser &p,
+                                          OperationState &state) {
+  if (p.parseOptionalAttrDict(state.attributes))
+    return failure();
+  Builder &builder = p.getBuilder();
+
+  // Parse the operation name.
+  StringAttr opName;
+  if (p.parseAttribute(opName, "name", state.attributes))
+    return failure();
+
+  // Parse the operands.
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  if (p.parseLParen() || p.parseOperandList(operands) || p.parseRParen() ||
+      p.resolveOperands(operands, builder.getType<pdl::ValueType>(),
+                        state.operands))
+    return failure();
+
+  // Parse the attributes.
+  SmallVector<Attribute, 4> attrNames;
+  if (succeeded(p.parseOptionalLBrace())) {
+    SmallVector<OpAsmParser::OperandType, 4> attrOps;
+    do {
+      StringAttr nameAttr;
+      OpAsmParser::OperandType operand;
+      if (p.parseAttribute(nameAttr) || p.parseEqual() ||
+          p.parseOperand(operand))
+        return failure();
+      attrNames.push_back(nameAttr);
+      attrOps.push_back(operand);
+    } while (succeeded(p.parseOptionalComma()));
+
+    if (p.parseRBrace() ||
+        p.resolveOperands(attrOps, builder.getType<pdl::AttributeType>(),
+                          state.operands))
+      return failure();
+  }
+  state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
+  state.addTypes(builder.getType<pdl::OperationType>());
+
+  // Parse the result types.
+  SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
+  if (p.parseArrow())
+    return failure();
+  if (succeeded(p.parseOptionalLParen())) {
+    if (p.parseRParen())
+      return failure();
+  } else if (p.parseOperandList(opResultTypes) ||
+             p.resolveOperands(opResultTypes, builder.getType<pdl::TypeType>(),
+                               state.operands)) {
+    return failure();
+  }
+
+  int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
+                                   static_cast<int32_t>(attrNames.size()),
+                                   static_cast<int32_t>(opResultTypes.size())};
+  state.addAttribute("operand_segment_sizes",
+                     builder.getI32VectorAttr(operandSegmentSizes));
+  return success();
+}
+
+static void print(OpAsmPrinter &p, CreateOperationOp op) {
+  p << "pdl_interp.create_operation ";
+  p.printOptionalAttrDict(op.getAttrs(),
+                          {"attributeNames", "name", "operand_segment_sizes"});
+  p << '"' << op.name() << "\"(" << op.operands() << ')';
+
+  // Emit the optional attributes.
+  ArrayAttr attrNames = op.attributeNames();
+  if (!attrNames.empty()) {
+    Operation::operand_range attrArgs = op.attributes();
+    p << " {";
+    interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
+                    [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
+    p << '}';
+  }
+
+  // Print the result type constraints of the operation.
+  auto types = op.types();
+  if (types.empty())
+    p << " -> ()";
+  else
+    p << " -> " << op.types();
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen Auto-Generated Op and Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
index c45d031..dfd7b10 100644 (file)
@@ -261,6 +261,12 @@ ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
   return getArrayAttr(attrs);
 }
 
+ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
+  auto attrs = llvm::to_vector<8>(llvm::map_range(
+      values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
+  return getArrayAttr(attrs);
+}
+
 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
   auto attrs = llvm::to_vector<8>(llvm::map_range(
       values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
index 37ee938..b7cae27 100644 (file)
@@ -221,6 +221,9 @@ OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
     return result;
   }
 }
+OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute) {
+  return parseOptionalAttributeWithToken(Token::l_square, attribute);
+}
 
 /// Attribute dictionary.
 ///
index 0d3c659..32d11e5 100644 (file)
@@ -1045,15 +1045,37 @@ public:
   }
 
   /// Parse an optional attribute.
-  OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
-                                             StringRef attrName,
-                                             NamedAttrList &attrs) override {
+  /// Template utilities to simplify specifying multiple derived overloads.
+  template <typename AttrT>
+  OptionalParseResult
+  parseOptionalAttributeAndAddToList(AttrT &result, Type type,
+                                     StringRef attrName, NamedAttrList &attrs) {
     OptionalParseResult parseResult =
         parser.parseOptionalAttribute(result, type);
     if (parseResult.hasValue() && succeeded(*parseResult))
       attrs.push_back(parser.builder.getNamedAttr(attrName, result));
     return parseResult;
   }
+  template <typename AttrT>
+  OptionalParseResult parseOptionalAttributeAndAddToList(AttrT &result,
+                                                         StringRef attrName,
+                                                         NamedAttrList &attrs) {
+    OptionalParseResult parseResult = parser.parseOptionalAttribute(result);
+    if (parseResult.hasValue() && succeeded(*parseResult))
+      attrs.push_back(parser.builder.getNamedAttr(attrName, result));
+    return parseResult;
+  }
+
+  OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
+                                             StringRef attrName,
+                                             NamedAttrList &attrs) override {
+    return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
+  }
+  OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
+                                             StringRef attrName,
+                                             NamedAttrList &attrs) override {
+    return parseOptionalAttributeAndAddToList(result, attrName, attrs);
+  }
 
   /// Parse a named dictionary into 'result' if it is present.
   ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
index 3b2c6e8..61e54be 100644 (file)
@@ -187,6 +187,22 @@ public:
   /// Parse an optional attribute with the provided type.
   OptionalParseResult parseOptionalAttribute(Attribute &attribute,
                                              Type type = {});
+  OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute);
+
+  /// Parse an optional attribute that is demarcated by a specific token.
+  template <typename AttributeT>
+  OptionalParseResult parseOptionalAttributeWithToken(Token::Kind kind,
+                                                      AttributeT &attr,
+                                                      Type type = {}) {
+    if (getToken().isNot(kind))
+      return llvm::None;
+
+    if (Attribute parsedAttr = parseAttribute()) {
+      attr = parsedAttr.cast<ArrayAttr>();
+      return success();
+    }
+    return failure();
+  }
 
   /// Parse an attribute dictionary.
   ParseResult parseAttributeDict(NamedAttrList &attributes);
index 7058d8b..f5c6540 100644 (file)
@@ -9,7 +9,7 @@ pdl.pattern : benefit(1) {
 
   // expected-error@below {{expected at least one argument}}
   "pdl.apply_constraint"() {name = "foo", params = []} : () -> ()
-  pdl.rewrite "rewriter"(%op)
+  pdl.rewrite %op with "rewriter"
 }
 
 // -----
@@ -25,14 +25,14 @@ pdl.pattern : benefit(1) {
   %attr = pdl.attribute : %type 10
 
   %op, %result = pdl.operation "foo.op" {"attr" = %attr} -> %type
-  pdl.rewrite "rewriter"(%op)
+  pdl.rewrite %op with "rewriter"
 }
 
 // -----
 
 pdl.pattern : benefit(1) {
   %op = pdl.operation "foo.op"
-  pdl.rewrite(%op) {
+  pdl.rewrite %op {
     %type = pdl.type
 
     // expected-error@below {{expected constant value when specified within a `pdl.rewrite`}}
@@ -44,7 +44,7 @@ pdl.pattern : benefit(1) {
 
 pdl.pattern : benefit(1) {
   %op = pdl.operation "foo.op"
-  pdl.rewrite(%op) {
+  pdl.rewrite %op {
     // expected-error@below {{expected constant value when specified within a `pdl.rewrite`}}
     %attr = pdl.attribute
   }
@@ -57,7 +57,7 @@ pdl.pattern : benefit(1) {
   %unused = pdl.attribute
 
   %op = pdl.operation "foo.op"
-  pdl.rewrite "rewriter"(%op)
+  pdl.rewrite %op with "rewriter"
 }
 
 // -----
@@ -71,7 +71,7 @@ pdl.pattern : benefit(1) {
   %unused = pdl.input
 
   %op = pdl.operation "foo.op"
-  pdl.rewrite "rewriter"(%op)
+  pdl.rewrite %op with "rewriter"
 }
 
 // -----
@@ -82,7 +82,7 @@ pdl.pattern : benefit(1) {
 
 pdl.pattern : benefit(1) {
   %op = pdl.operation "foo.op"
-  pdl.rewrite(%op) {
+  pdl.rewrite %op {
     // expected-error@below {{must have an operation name when nested within a `pdl.rewrite`}}
     %newOp = pdl.operation
   }
@@ -96,14 +96,14 @@ pdl.pattern : benefit(1) {
     attributeNames = ["attr"],
     operand_segment_sizes = dense<0> : vector<3xi32>
   } : () -> (!pdl.operation)
-  pdl.rewrite "rewriter"(%op)
+  pdl.rewrite %op with "rewriter"
 }
 
 // -----
 
 pdl.pattern : benefit(1) {
   %op = pdl.operation "foo.op"()
-  pdl.rewrite (%op) {
+  pdl.rewrite %op {
     %type = pdl.type
 
     // expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}}
@@ -119,7 +119,7 @@ pdl.pattern : benefit(1) {
   %unused = pdl.operation "foo.op"
 
   %op = pdl.operation "foo.op"
-  pdl.rewrite "rewriter"(%op)
+  pdl.rewrite %op with "rewriter"
 }
 
 // -----
@@ -142,7 +142,7 @@ pdl.pattern : benefit(1) {
   "foo.other_op"() : () -> ()
 
   %root = pdl.operation "foo.op"
-  pdl.rewrite "foo"(%root)
+  pdl.rewrite %root with "foo"
 }
 
 // -----
@@ -153,7 +153,7 @@ pdl.pattern : benefit(1) {
 
 pdl.pattern : benefit(1) {
   %root = pdl.operation "foo.op"
-  pdl.rewrite (%root) {
+  pdl.rewrite %root {
     %type = pdl.type : i32
     %newOp, %newResult = pdl.operation "foo.op" -> %type
 
@@ -167,7 +167,7 @@ pdl.pattern : benefit(1) {
 pdl.pattern : benefit(1) {
   %type = pdl.type : i32
   %root, %oldResult = pdl.operation "foo.op" -> %type
-  pdl.rewrite (%root) {
+  pdl.rewrite %root {
     %newOp, %newResult = pdl.operation "foo.op" -> %type
 
     // expected-error@below {{expected no replacement values to be provided when the replacement operation is present}}
@@ -181,7 +181,7 @@ pdl.pattern : benefit(1) {
 
 pdl.pattern : benefit(1) {
   %root = pdl.operation "foo.op"
-  pdl.rewrite (%root) {
+  pdl.rewrite %root {
     %type = pdl.type : i32
     %newOp, %newResult = pdl.operation "foo.op" -> %type
 
@@ -193,6 +193,55 @@ pdl.pattern : benefit(1) {
 // -----
 
 //===----------------------------------------------------------------------===//
+// pdl::RewriteOp
+//===----------------------------------------------------------------------===//
+
+pdl.pattern : benefit(1) {
+  %op = pdl.operation "foo.op"
+
+  // expected-error@below {{expected rewrite region to be non-empty if external name is not specified}}
+  "pdl.rewrite"(%op) ({}) : (!pdl.operation) -> ()
+}
+
+// -----
+
+pdl.pattern : benefit(1) {
+  %op = pdl.operation "foo.op"
+
+  // expected-error@below {{expected no external arguments when the rewrite is specified inline}}
+  "pdl.rewrite"(%op, %op) ({
+    ^bb1:
+      pdl.rewrite_end
+  }) : (!pdl.operation, !pdl.operation) -> ()
+}
+
+// -----
+
+pdl.pattern : benefit(1) {
+  %op = pdl.operation "foo.op"
+
+  // expected-error@below {{expected no external constant parameters when the rewrite is specified inline}}
+  "pdl.rewrite"(%op) ({
+    ^bb1:
+      pdl.rewrite_end
+  }) {externalConstParams = []} : (!pdl.operation) -> ()
+}
+
+// -----
+
+pdl.pattern : benefit(1) {
+  %op = pdl.operation "foo.op"
+
+  // expected-error@below {{expected rewrite region to be empty when rewrite is external}}
+  "pdl.rewrite"(%op) ({
+    ^bb1:
+      pdl.rewrite_end
+  }) {name = "foo"} : (!pdl.operation) -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // pdl::TypeOp
 //===----------------------------------------------------------------------===//
 
@@ -201,5 +250,5 @@ pdl.pattern : benefit(1) {
   %unused = pdl.type
 
   %op = pdl.operation "foo.op"
-  pdl.rewrite "rewriter"(%op)
+  pdl.rewrite %op with "rewriter"
 }
index f7c425f..37db36e 100644 (file)
@@ -1,8 +1,6 @@
 // RUN: mlir-opt -split-input-file %s | mlir-opt
-// Verify the printed output can be parsed.
-// RUN: mlir-opt %s | mlir-opt
 // Verify the generic form can be parsed.
-// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt
+// RUN: mlir-opt -split-input-file -mlir-print-op-generic %s | mlir-opt
 
 // -----
 
@@ -15,7 +13,30 @@ pdl.pattern @operations : benefit(1) {
   // Operation with input.
   %input = pdl.input
   %root = pdl.operation(%op0_result, %input)
-  pdl.rewrite "rewriter"(%root)
+  pdl.rewrite %root with "rewriter"
+}
+
+// -----
+
+pdl.pattern @rewrite_with_args : benefit(1) {
+  %input = pdl.input
+  %root = pdl.operation(%input)
+  pdl.rewrite %root with "rewriter"(%input : !pdl.value)
+}
+
+// -----
+
+pdl.pattern @rewrite_with_params : benefit(1) {
+  %root = pdl.operation
+  pdl.rewrite %root with "rewriter"["I am param"]
+}
+
+// -----
+
+pdl.pattern @rewrite_with_args_and_params : benefit(1) {
+  %input = pdl.input
+  %root = pdl.operation(%input)
+  pdl.rewrite %root with "rewriter"["I am param"](%input : !pdl.value)
 }
 
 // -----
@@ -26,7 +47,7 @@ pdl.pattern @infer_type_from_operation_replace : benefit(1) {
   %type1 = pdl.type : i32
   %type2 = pdl.type
   %root, %results:2 = pdl.operation -> %type1, %type2
-  pdl.rewrite(%root) {
+  pdl.rewrite %root {
     %type3 = pdl.type
     %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
     pdl.replace %root with %newOp
@@ -41,7 +62,7 @@ pdl.pattern @infer_type_from_result_replace : benefit(1) {
   %type1 = pdl.type : i32
   %type2 = pdl.type
   %root, %results:2 = pdl.operation -> %type1, %type2
-  pdl.rewrite(%root) {
+  pdl.rewrite %root {
     %type3 = pdl.type
     %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
     pdl.replace %root with (%newResults#0, %newResults#1)
@@ -56,7 +77,7 @@ pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
   %type1 = pdl.type : i32
   %type2 = pdl.type
   %root, %results:2 = pdl.operation -> %type1, %type2
-  pdl.rewrite(%root) {
+  pdl.rewrite %root {
     %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type2
   }
 }
diff --git a/mlir/test/Dialect/PDLInterp/ops.mlir b/mlir/test/Dialect/PDLInterp/ops.mlir
new file mode 100644 (file)
index 0000000..d76b17c
--- /dev/null
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -split-input-file %s | mlir-opt
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt
+
+// -----
+
+func @operations(%attribute: !pdl.attribute,
+                 %input: !pdl.value,
+                 %type: !pdl.type) {
+  // attributes, operands, and results
+  %op0 = pdl_interp.create_operation "foo.op"(%input) {"attr" = %attribute} -> %type
+
+  // attributes, and results
+  %op1 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute} -> %type
+
+  // attributes
+  %op2 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute, "attr1" = %attribute} -> ()
+
+  // operands, and results
+  %op3 = pdl_interp.create_operation "foo.op"(%input) -> %type
+
+  pdl_interp.finalize
+}
index 82a7312..1f3ac49 100644 (file)
@@ -226,7 +226,7 @@ bool LiteralElement::isValidLiteral(StringRef value) {
   // If there is only one character, this must either be punctuation or a
   // single character bare identifier.
   if (value.size() == 1)
-    return isalpha(front) || StringRef("_:,=<>()[]?").contains(front);
+    return isalpha(front) || StringRef("_:,=<>()[]{}?").contains(front);
 
   // Check the punctuation that are larger than a single character.
   if (value == "->")
@@ -583,6 +583,8 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
               .Case("=", "Equal()")
               .Case("<", "Less()")
               .Case(">", "Greater()")
+              .Case("{", "LBrace()")
+              .Case("}", "RBrace()")
               .Case("(", "LParen()")
               .Case(")", "RParen()")
               .Case("[", "LSquare()")