[mlir:PDL] Add support for creating ranges in rewrites
authorRiver Riddle <riddleriver@gmail.com>
Fri, 9 Sep 2022 23:31:24 +0000 (16:31 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 8 Nov 2022 09:57:57 +0000 (01:57 -0800)
This commit adds support for building a concatenated range from
a given set of elements, either single element or other ranges, within a
rewrite. We could conceptually extend this to support constraining
input ranges, but the logic there is quite a bit more complex so it is
left for later work when a need arises.

Differential Revision: https://reviews.llvm.org/D133719

12 files changed:
mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
mlir/test/Dialect/PDL/invalid.mlir
mlir/test/Dialect/PDLInterp/invalid.mlir
mlir/test/Rewrite/pdl-bytecode.mlir

index fbe991a..c85687e 100644 (file)
@@ -437,6 +437,48 @@ def PDL_PatternOp : PDL_Op<"pattern", [
 }
 
 //===----------------------------------------------------------------------===//
+// pdl::RangeOp
+//===----------------------------------------------------------------------===//
+
+def PDL_RangeOp : PDL_Op<"range", [Pure, HasParent<"pdl::RewriteOp">]> {
+  let summary = "Construct a range of pdl entities";
+  let description = [{
+    `pdl.range` operations construct a range from a given set of PDL entities,
+    which all share the same underlying element type. For example, a
+    `!pdl.range<value>` may be constructed from a list of `!pdl.value`
+    or `!pdl.range<value>` entities.
+
+    Example:
+
+    ```mlir
+    // Construct a range of values.
+    %valueRange = pdl.range %inputValue, %inputRange : !pdl.value, !pdl.range<value>
+
+    // Construct a range of types.
+    %typeRange = pdl.range %inputType, %inputRange : !pdl.type, !pdl.range<type>
+    
+    // Construct an empty range of types.
+    %valueRange = pdl.range : !pdl.range<type>
+    ```
+
+    TODO: Range construction is currently limited to rewrites, but it could
+    be extended to constraints under certain circustances; i.e., if we can
+    determine how to extract the underlying elements. If we can't, e.g. if
+    there are multiple sub ranges used for construction, we won't be able
+    to determine their sizes during constraint time.
+  }];
+
+  let arguments = (ins Variadic<PDL_AnyType>:$arguments);
+  let results = (outs PDL_RangeOf<AnyTypeOf<[PDL_Type, PDL_Value]>>:$result);
+  let assemblyFormat = [{
+    ($arguments^ `:` type($arguments))?
+    custom<RangeType>(ref(type($arguments)), type($result))
+    attr-dict
+  }];
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
 // pdl::ReplaceOp
 //===----------------------------------------------------------------------===//
 
index 8cbe31f..a342dcc 100644 (file)
@@ -28,6 +28,11 @@ public:
 
   static bool classof(Type type);
 };
+
+/// If the given type is a range, return its element type, otherwise return
+/// the type itself.
+Type getRangeElementTypeOrSelf(Type type);
+
 } // namespace pdl
 } // namespace mlir
 
index 659bfbc..96d631b 100644 (file)
@@ -992,6 +992,43 @@ def PDLInterp_IsNotNullOp
   let assemblyFormat = "$value `:` type($value) attr-dict `->` successors";
 }
 
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateRangeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CreateRangeOp : PDLInterp_Op<"create_range", [Pure]> {
+  let summary = "Construct a range of PDL entities";
+  let description = [{
+    `pdl_interp.create_range` operations construct a range from a given set of PDL
+    entities, which all share the same underlying element type. For example, a
+    `!pdl.range<value>` may be constructed from a list of `!pdl.value`
+    or `!pdl.range<value>` entities.
+
+    Example:
+
+    ```mlir
+    // Construct a range of values.
+    %valueRange = pdl_interp.create_range %inputValue, %inputRange : !pdl.value, !pdl.range<value>
+
+    // Construct a range of types.
+    %typeRange = pdl_interp.create_range %inputType, %inputRange : !pdl.type, !pdl.range<type>
+    
+    // Construct an empty range of types.
+    %valueRange = pdl_interp.create_range : !pdl.range<type>
+    ```
+  }];
+
+  let arguments = (ins Variadic<PDL_AnyType>:$arguments);
+  let results = (outs PDL_RangeOf<AnyTypeOf<[PDL_Type, PDL_Value]>>:$result);
+  let assemblyFormat = [{
+    ($arguments^ `:` type($arguments))?
+    custom<RangeType>(ref(type($arguments)), type($result))
+    attr-dict
+  }];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // pdl_interp::RecordMatchOp
 //===----------------------------------------------------------------------===//
index 987e7a3..fdc95ab 100644 (file)
@@ -89,6 +89,9 @@ private:
   void generateRewriter(pdl::OperationOp operationOp,
                         DenseMap<Value, Value> &rewriteValues,
                         function_ref<Value(Value)> mapRewriteValue);
+  void generateRewriter(pdl::RangeOp rangeOp,
+                        DenseMap<Value, Value> &rewriteValues,
+                        function_ref<Value(Value)> mapRewriteValue);
   void generateRewriter(pdl::ReplaceOp replaceOp,
                         DenseMap<Value, Value> &rewriteValues,
                         function_ref<Value(Value)> mapRewriteValue);
@@ -668,8 +671,8 @@ SymbolRefAttr PatternLowering::generateRewriter(
     for (Operation &rewriteOp : *rewriter.getBody()) {
       llvm::TypeSwitch<Operation *>(&rewriteOp)
           .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
-                pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
-                pdl::TypeOp, pdl::TypesOp>([&](auto op) {
+                pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
+                pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
             this->generateRewriter(op, rewriteValues, mapRewriteValue);
           });
     }
@@ -776,6 +779,16 @@ void PatternLowering::generateRewriter(
 }
 
 void PatternLowering::generateRewriter(
+    pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
+    function_ref<Value(Value)> mapRewriteValue) {
+  SmallVector<Value, 4> replOperands;
+  for (Value operand : rangeOp.getArguments())
+    replOperands.push_back(mapRewriteValue(operand));
+  rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
+      rangeOp.getLoc(), rangeOp.getType(), replOperands);
+}
+
+void PatternLowering::generateRewriter(
     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
     function_ref<Value(Value)> mapRewriteValue) {
   SmallVector<Value, 4> replOperands;
index b96f34b..e33ba71 100644 (file)
@@ -398,6 +398,39 @@ StringRef PatternOp::getDefaultDialect() {
 }
 
 //===----------------------------------------------------------------------===//
+// pdl::RangeOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
+                                  Type &resultType) {
+  // If arguments were provided, infer the result type from the argument list.
+  if (!argumentTypes.empty()) {
+    resultType = RangeType::get(getRangeElementTypeOrSelf(argumentTypes[0]));
+    return success();
+  }
+  // Otherwise, parse the type as a trailing type.
+  return p.parseColonType(resultType);
+}
+
+static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes,
+                           Type resultType) {
+  if (argumentTypes.empty())
+    p << ": " << resultType;
+}
+
+LogicalResult RangeOp::verify() {
+  Type elementType = getType().getElementType();
+  for (Type operandType : getOperandTypes()) {
+    Type operandElementType = getRangeElementTypeOrSelf(operandType);
+    if (operandElementType != elementType) {
+      return emitOpError("expected operand to have element type ")
+             << elementType << ", but got " << operandElementType;
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // pdl::ReplaceOp
 //===----------------------------------------------------------------------===//
 
index f4dbbfb..49eee1a 100644 (file)
@@ -59,6 +59,12 @@ bool PDLType::classof(Type type) {
   return llvm::isa<PDLDialect>(type.getDialect());
 }
 
+Type pdl::getRangeElementTypeOrSelf(Type type) {
+  if (auto rangeType = type.dyn_cast<RangeType>())
+    return rangeType.getElementType();
+  return type;
+}
+
 //===----------------------------------------------------------------------===//
 // RangeType
 //===----------------------------------------------------------------------===//
index 01670e3..e8a61ef 100644 (file)
@@ -238,6 +238,40 @@ static Type getGetValueTypeOpValueType(Type type) {
 }
 
 //===----------------------------------------------------------------------===//
+// pdl::CreateRangeOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
+                                  Type &resultType) {
+  // If arguments were provided, infer the result type from the argument list.
+  if (!argumentTypes.empty()) {
+    resultType =
+        pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0]));
+    return success();
+  }
+  // Otherwise, parse the type as a trailing type.
+  return p.parseColonType(resultType);
+}
+
+static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
+                           TypeRange argumentTypes, Type resultType) {
+  if (argumentTypes.empty())
+    p << ": " << resultType;
+}
+
+LogicalResult CreateRangeOp::verify() {
+  Type elementType = getType().getElementType();
+  for (Type operandType : getOperandTypes()) {
+    Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType);
+    if (operandElementType != elementType) {
+      return emitOpError("expected operand to have element type ")
+             << elementType << ", but got " << operandElementType;
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // pdl_interp::SwitchAttributeOp
 //===----------------------------------------------------------------------===//
 
index 9cc51da..6b1dfb9 100644 (file)
@@ -99,10 +99,14 @@ enum OpCode : ByteCodeField {
   CheckTypes,
   /// Continue to the next iteration of a loop.
   Continue,
+  /// Create a type range from a list of constant types.
+  CreateConstantTypeRange,
   /// Create an operation.
   CreateOperation,
-  /// Create a range of types.
-  CreateTypes,
+  /// Create a type range from a list of dynamic types.
+  CreateDynamicTypeRange,
+  /// Create a value range.
+  CreateDynamicValueRange,
   /// Erase an operation.
   EraseOp,
   /// Extract the op from a range at the specified index.
@@ -265,6 +269,7 @@ private:
   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
@@ -742,9 +747,9 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
-            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
-            pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
-            pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
+            pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
+            pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
+            pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
@@ -863,12 +868,24 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
   else
     writer.appendPDLValueList(op.getInputResultTypes());
 }
+void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
+  // Append the correct opcode for the range type.
+  TypeSwitch<Type>(op.getType().getElementType())
+      .Case(
+          [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
+      .Case([&](pdl::ValueType) {
+        writer.append(OpCode::CreateDynamicValueRange);
+      });
+
+  writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
+  writer.appendPDLValueList(op->getOperands());
+}
 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
   // Simply repoint the memory index of the result to the constant.
   getMemIndex(op.getResult()) = getMemIndex(op.getValue());
 }
 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
-  writer.append(OpCode::CreateTypes, op.getResult(),
+  writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
                 getRangeStorageIndex(op.getResult()), op.getValue());
 }
 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
@@ -1103,9 +1120,11 @@ private:
   void executeCheckResultCount();
   void executeCheckTypes();
   void executeContinue();
+  void executeCreateConstantTypeRange();
   void executeCreateOperation(PatternRewriter &rewriter,
                               Location mainRewriteLoc);
-  void executeCreateTypes();
+  template <typename T>
+  void executeDynamicCreateRange(StringRef type);
   void executeEraseOp(PatternRewriter &rewriter);
   template <typename T, typename Range, PDLValue::Kind kind>
   void executeExtract();
@@ -1172,8 +1191,18 @@ private:
   }
 
   /// Read a list of values from the bytecode buffer. The values may be encoded
-  /// as either Value or ValueRange elements.
-  void readValueList(SmallVectorImpl<Value> &list) {
+  /// either as a single element or a range of elements.
+  void readList(SmallVectorImpl<Type> &list) {
+    for (unsigned i = 0, e = read(); i != e; ++i) {
+      if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
+        list.push_back(read<Type>());
+      } else {
+        TypeRange *values = read<TypeRange *>();
+        list.append(values->begin(), values->end());
+      }
+    }
+  }
+  void readList(SmallVectorImpl<Value> &list) {
     for (unsigned i = 0, e = read(); i != e; ++i) {
       if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
         list.push_back(read<Value>());
@@ -1292,6 +1321,39 @@ private:
     return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
   }
 
+  /// Assign the given range to the given memory index. This allocates a new
+  /// range object if necessary.
+  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
+  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
+                           unsigned rangeIndex) {
+    // Utility functor used to type-erase the assignment.
+    auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
+      // If the input range is empty, we don't need to allocate anything.
+      if (range.empty()) {
+        rangeMemory[rangeIndex] = {};
+      } else {
+        // Allocate a buffer for this type range.
+        llvm::OwningArrayRef<T> storage(llvm::size(range));
+        llvm::copy(range, storage.begin());
+
+        // Assign this to the range slot and use the range as the value for the
+        // memory index.
+        allocatedRangeMemory.emplace_back(std::move(storage));
+        rangeMemory[rangeIndex] = allocatedRangeMemory.back();
+      }
+      memory[memIndex] = &rangeMemory[rangeIndex];
+    };
+
+    // Dispatch based on the concrete range type.
+    if constexpr (std::is_same_v<T, Type>) {
+      return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
+    } else if constexpr (std::is_same_v<T, Value>) {
+      return assignRange(allocatedValueRangeMemory, valueRangeMemory);
+    } else {
+      llvm_unreachable("unhandled range type");
+    }
+  }
+
   /// The underlying bytecode buffer.
   const ByteCodeField *curCodeIt;
 
@@ -1514,23 +1576,15 @@ void ByteCodeExecutor::executeContinue() {
   popCodeIt();
 }
 
-void ByteCodeExecutor::executeCreateTypes() {
-  LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
+void ByteCodeExecutor::executeCreateConstantTypeRange() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
   unsigned memIndex = read();
   unsigned rangeIndex = read();
   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
 
   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
-
-  // Allocate a buffer for this type range.
-  llvm::OwningArrayRef<Type> storage(typesAttr.size());
-  llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
-  allocatedTypeRangeMemory.emplace_back(std::move(storage));
-
-  // Assign this to the range slot and use the range as the value for the
-  // memory index.
-  typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
-  memory[memIndex] = &typeRangeMemory[rangeIndex];
+  assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
+                      rangeIndex);
 }
 
 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
@@ -1539,7 +1593,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
 
   unsigned memIndex = read();
   OperationState state(mainRewriteLoc, read<OperationName>());
-  readValueList(state.operands);
+  readList(state.operands);
   for (unsigned i = 0, e = read(); i != e; ++i) {
     StringAttr name = read<StringAttr>();
     if (Attribute attr = read<Attribute>())
@@ -1587,6 +1641,23 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
   });
 }
 
+template <typename T>
+void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
+  LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
+  unsigned memIndex = read();
+  unsigned rangeIndex = read();
+  SmallVector<T> values;
+  readList(values);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "\n  * " << type << "s: ";
+    llvm::interleaveComma(values, llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+
+  assignRangeToMemory(values, memIndex, rangeIndex);
+}
+
 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
   Operation *op = read<Operation *>();
@@ -1949,7 +2020,7 @@ void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
   Operation *op = read<Operation *>();
   SmallVector<Value, 16> args;
-  readValueList(args);
+  readList(args);
 
   LLVM_DEBUG({
     llvm::dbgs() << "  * Operation: " << *op << "\n"
@@ -2076,11 +2147,17 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
     case Continue:
       executeContinue();
       break;
+    case CreateConstantTypeRange:
+      executeCreateConstantTypeRange();
+      break;
     case CreateOperation:
       executeCreateOperation(rewriter, *mainRewriteLoc);
       break;
-    case CreateTypes:
-      executeCreateTypes();
+    case CreateDynamicTypeRange:
+      executeDynamicCreateRange<Type>("Type");
+      break;
+    case CreateDynamicValueRange:
+      executeDynamicCreateRange<Value>("Value");
       break;
     case EraseOp:
       executeEraseOp(rewriter);
index d6e8f4a..e5a84d6 100644 (file)
@@ -243,3 +243,20 @@ module @unbound_rewrite_op {
 }
 
 // -----
+
+// CHECK-LABEL: module @range_op
+module @range_op {
+  // CHECK: module @rewriters
+  // CHECK:   func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value)
+  // CHECK:     %[[RANGE1:.*]] = pdl_interp.create_range : !pdl.range<value>
+  // CHECK:     %[[RANGE2:.*]] = pdl_interp.create_range %[[OPERAND]], %[[RANGE1]] : !pdl.value, !pdl.range<value>
+  // CHECK:     pdl_interp.finalize
+  pdl.pattern : benefit(1) {
+    %operand = pdl.operand
+    %root = operation "foo.op"(%operand : !pdl.value)
+    rewrite %root {
+      %emptyRange = pdl.range : !pdl.range<value>
+      %range = pdl.range %operand, %emptyRange : !pdl.value, !pdl.range<value>
+    }
+  }
+}
index 61c0aae..522e9fb 100644 (file)
@@ -238,6 +238,23 @@ pdl.pattern : benefit(1) {
 // -----
 
 //===----------------------------------------------------------------------===//
+// pdl::RangeOp
+//===----------------------------------------------------------------------===//
+
+pdl.pattern : benefit(1) {
+  %operand = pdl.operand
+  %resultType = pdl.type
+  %root = pdl.operation "baz.op"(%operand : !pdl.value) -> (%resultType : !pdl.type)
+  
+  rewrite %root {
+    // expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}}
+    %range = pdl.range %operand, %resultType : !pdl.value, !pdl.type
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // pdl::ResultsOp
 //===----------------------------------------------------------------------===//
 
index f194d32..0457a15 100644 (file)
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 //===----------------------------------------------------------------------===//
-// pdl::CreateOperationOp
+// pdl_interp::CreateOperationOp
 //===----------------------------------------------------------------------===//
 
 pdl_interp.func @rewriter() {
@@ -23,3 +23,15 @@ pdl_interp.func @rewriter() {
   } : (!pdl.type) -> (!pdl.operation)
   pdl_interp.finalize
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateRangeOp
+//===----------------------------------------------------------------------===//
+
+pdl_interp.func @rewriter(%value: !pdl.value, %type: !pdl.type) {
+  // expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}}
+  %range = pdl_interp.create_range %value, %type : !pdl.value, !pdl.type
+  pdl_interp.finalize
+}
index 20e2490..565874f 100644 (file)
@@ -569,6 +569,48 @@ module @ir attributes { test.create_op_infer_results } {
 // -----
 
 //===----------------------------------------------------------------------===//
+// pdl_interp::CreateRangeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  pdl_interp.func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operand_count of %root is 2 -> ^pat1, ^end
+
+  ^pat1:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    pdl_interp.func @success(%root: !pdl.operation) {
+      %rootOperand = pdl_interp.get_operand 0 of %root
+      %rootOperands = pdl_interp.get_operands of %root : !pdl.range<value>
+      %operandRange = pdl_interp.create_range %rootOperand, %rootOperands : !pdl.value, !pdl.range<value>
+      
+      %operandType = pdl_interp.get_value_type of %rootOperand : !pdl.type
+      %operandTypes = pdl_interp.get_value_type of %rootOperands : !pdl.range<type>
+      %typeRange = pdl_interp.create_range %operandType, %operandTypes : !pdl.type, !pdl.range<type>
+
+      %op = pdl_interp.create_operation "test.success"(%operandRange : !pdl.range<value>) -> (%typeRange : !pdl.range<type>)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.create_range_1
+// CHECK: %[[INPUTS:.*]]:2 = "test.input"()
+// CHECK: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#0, %[[INPUTS]]#1) : (i32, i32, i32) -> (i32, i32, i32)
+module @ir attributes { test.create_range_1 } {
+  %values:2 = "test.input"() : () -> (i32, i32)
+  "test.op"(%values#0, %values#1) : (i32, i32) -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // pdl_interp::CreateTypeOp
 //===----------------------------------------------------------------------===//