From 3eb1647af036dc0e8370ed5a8b1ecbb5701f850b Mon Sep 17 00:00:00 2001 From: Stanislav Funiak Date: Fri, 26 Nov 2021 18:08:34 +0530 Subject: [PATCH] Introduced iterative bytecode execution. This is commit 2 of 4 for the multi-root matching in PDL, discussed in https://llvm.discourse.group/t/rfc-multi-root-pdl-patterns-for-kernel-matching/4148 (topic flagged for review). This commit implements the features needed for the execution of the new operations pdl_interp.get_accepting_ops, pdl_interp.choose_op: 1. The implementation of the generation and execution of the two ops. 2. The addition of Stack of bytecode positions within the ByteCodeExecutor. This is needed because in pdl_interp.choose_op, we iterate over the values returned by pdl_interp.get_accepting_ops until we reach finalize. When we reach finalize, we need to return back to the position marked in the stack. 3. The functionality to extend the lifetime of values that cross the nondeterministic choice. The existing bytecode generator allocates the values to memory positions by representing the liveness of values as a collection of disjoint intervals over the matcher positions. This is akin to register allocation, and substantially reduces the footprint of the bytecode executor. However, because with iterative operation pdl_interp.choose_op, execution "returns" back, so any values whose original liveness cross the nondeterminstic choice must have their lifetime executed until finalize. Testing: pdl-bytecode.mlir test Reviewed By: rriddle, Mogball Differential Revision: https://reviews.llvm.org/D108547 --- mlir/include/mlir/IR/PatternMatch.h | 8 + mlir/lib/IR/PatternMatch.cpp | 23 ++ mlir/lib/Rewrite/ByteCode.cpp | 412 ++++++++++++++++++++++++++++++------ mlir/lib/Rewrite/ByteCode.h | 16 ++ mlir/test/Rewrite/pdl-bytecode.mlir | 271 ++++++++++++++++++++++++ 5 files changed, 668 insertions(+), 62 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index ba8d145..d02bda7 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -446,6 +446,9 @@ public: /// Print this value to the provided output stream. void print(raw_ostream &os) const; + /// Print the specified value kind to an output stream. + static void print(raw_ostream &os, Kind kind); + private: /// Find the index of a given type in a range of other types. template @@ -491,6 +494,11 @@ inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) { return os; } +inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) { + PDLValue::print(os, kind); + return os; +} + //===----------------------------------------------------------------------===// // PDLResultList diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 4482b5c..39d8bad 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -126,6 +126,29 @@ void PDLValue::print(raw_ostream &os) const { } } +void PDLValue::print(raw_ostream &os, Kind kind) { + switch (kind) { + case Kind::Attribute: + os << "Attribute"; + break; + case Kind::Operation: + os << "Operation"; + break; + case Kind::Type: + os << "Type"; + break; + case Kind::TypeRange: + os << "TypeRange"; + break; + case Kind::Value: + os << "Value"; + break; + case Kind::ValueRange: + os << "ValueRange"; + break; + } +} + //===----------------------------------------------------------------------===// // PDLPatternModule //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 810bcf6..380f54d 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -95,14 +95,24 @@ enum OpCode : ByteCodeField { CheckResultCount, /// Compare a range of types to a constant range of types. CheckTypes, + /// Continue to the next iteration of a loop. + Continue, /// Create an operation. CreateOperation, /// Create a range of types. CreateTypes, /// Erase an operation. EraseOp, + /// Extract the op from a range at the specified index. + ExtractOp, + /// Extract the type from a range at the specified index. + ExtractType, + /// Extract the value from a range at the specified index. + ExtractValue, /// Terminate a matcher or rewrite sequence. Finalize, + /// Iterate over a range of values. + ForEach, /// Get a specific attribute of an operation. GetAttribute, /// Get the type of an attribute. @@ -125,6 +135,8 @@ enum OpCode : ByteCodeField { GetResultN, /// Get a specific result group of an operation. GetResults, + /// Get the users of a value or a range of values. + GetUsers, /// Get the type of a value. GetValueType, /// Get the types of a value range. @@ -158,8 +170,13 @@ enum OpCode : ByteCodeField { // Generator namespace { +struct ByteCodeLiveRange; struct ByteCodeWriter; +/// Check if the given class `T` can be converted to an opaque pointer. +template +using has_pointer_traits = decltype(std::declval().getAsOpaquePointer()); + /// This class represents the main generator for the pattern bytecode. class Generator { public: @@ -168,15 +185,19 @@ public: SmallVectorImpl &rewriterByteCode, SmallVectorImpl &patterns, ByteCodeField &maxValueMemoryIndex, + ByteCodeField &maxOpRangeMemoryIndex, ByteCodeField &maxTypeRangeMemoryIndex, ByteCodeField &maxValueRangeMemoryIndex, + ByteCodeField &maxLoopLevel, llvm::StringMap &constraintFns, llvm::StringMap &rewriteFns) : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex), + maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), - maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) { + maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), + maxLoopLevel(maxLoopLevel) { for (auto it : llvm::enumerate(constraintFns)) constraintToMemIndex.try_emplace(it.value().first(), it.index()); for (auto it : llvm::enumerate(rewriteFns)) @@ -221,6 +242,7 @@ private: void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); /// Generate the bytecode for the given operation. + void generate(Region *region, ByteCodeWriter &writer); void generate(Operation *op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); @@ -232,12 +254,15 @@ private: void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); @@ -245,6 +270,7 @@ private: void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); @@ -279,17 +305,25 @@ private: /// `uniquedData`. DenseMap uniquedDataToMemIndex; + /// The current level of the foreach loop. + ByteCodeField curLoopLevel = 0; + /// The current MLIR context. MLIRContext *ctx; + /// Mapping from block to its address. + DenseMap blockToAddr; + /// Data of the ByteCode class to be populated. std::vector &uniquedData; SmallVectorImpl &matcherByteCode; SmallVectorImpl &rewriterByteCode; SmallVectorImpl &patterns; ByteCodeField &maxValueMemoryIndex; + ByteCodeField &maxOpRangeMemoryIndex; ByteCodeField &maxTypeRangeMemoryIndex; ByteCodeField &maxValueRangeMemoryIndex; + ByteCodeField &maxLoopLevel; }; /// This class provides utilities for writing a bytecode stream. @@ -311,15 +345,20 @@ struct ByteCodeWriter { bytecode.append({fieldParts[0], fieldParts[1]}); } + /// Append a single successor to the bytecode, the exact address will need to + /// be resolved later. + void append(Block *successor) { + // Add back a reference to the successor so that the address can be resolved + // later. + unresolvedSuccessorRefs[successor].push_back(bytecode.size()); + append(ByteCodeAddr(0)); + } + /// Append a successor range to the bytecode, the exact address will need to /// be resolved later. void append(SuccessorRange successors) { - // Add back references to the any successors so that the address can be - // resolved later. - for (Block *successor : successors) { - unresolvedSuccessorRefs[successor].push_back(bytecode.size()); - append(ByteCodeAddr(0)); - } + for (Block *successor : successors) + append(successor); } /// Append a range of values that will be read as generic PDLValues. @@ -336,10 +375,12 @@ struct ByteCodeWriter { } /// Append the PDLValue::Kind of the given value. - void appendPDLValueKind(Value value) { - // Append the type of the value in addition to the value itself. + void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } + + /// Append the PDLValue::Kind of the given type. + void appendPDLValueKind(Type type) { PDLValue::Kind kind = - TypeSwitch(value.getType()) + TypeSwitch(type) .Case( [](Type) { return PDLValue::Kind::Attribute; }) .Case( @@ -354,10 +395,6 @@ struct ByteCodeWriter { bytecode.push_back(static_cast(kind)); } - /// Check if the given class `T` has an iterator type. - template - using has_pointer_traits = decltype(std::declval().getAsOpaquePointer()); - /// Append a value that will be stored in a memory slot and not inline within /// the bytecode. template @@ -396,25 +433,34 @@ struct ByteCodeWriter { /// This class represents a live range of PDL Interpreter values, containing /// information about when values are live within a match/rewrite. struct ByteCodeLiveRange { - using Set = llvm::IntervalMap; + using Set = llvm::IntervalMap; using Allocator = Set::Allocator; - ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {} + ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} /// Union this live range with the one provided. void unionWith(const ByteCodeLiveRange &rhs) { - for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it) - liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0); + for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; + ++it) + liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); } /// Returns true if this range overlaps with the one provided. bool overlaps(const ByteCodeLiveRange &rhs) const { - return llvm::IntervalMapOverlaps(liveness, rhs.liveness).valid(); + return llvm::IntervalMapOverlaps(*liveness, *rhs.liveness) + .valid(); } /// A map representing the ranges of the match/rewrite that a value is live in /// the interpreter. - llvm::IntervalMap liveness; + /// + /// We use std::unique_ptr here, because IntervalMap does not provide a + /// correct copy or move constructor. We can eliminate the pointer once + /// https://reviews.llvm.org/D113240 lands. + std::unique_ptr> liveness; + + /// The operation range storage index for this range. + Optional opRangeIndex; /// The type range storage index for this range. Optional typeRangeIndex; @@ -446,15 +492,8 @@ void Generator::generate(ModuleOp module) { "unexpected branches in rewriter function"); // Generate code for the matcher function. - DenseMap blockToAddr; - llvm::ReversePostOrderTraversal rpot(&matcherFunc.getBody()); ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); - for (Block *block : rpot) { - // Keep track of where this block begins within the matcher function. - blockToAddr.try_emplace(block, matcherByteCode.size()); - for (Operation &op : *block) - generate(&op, matcherByteCodeWriter); - } + generate(&matcherFunc.getBody(), matcherByteCodeWriter); // Resolve successor references in the matcher. for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { @@ -501,7 +540,7 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, // finding the minimal number of overlapping live ranges. This is essentially // a simplified form of register allocation where we don't necessarily have a // limited number of registers, but we still want to minimize the number used. - DenseMap opToIndex; + DenseMap opToIndex; matcherFunc.getBody().walk([&](Operation *op) { opToIndex.insert(std::make_pair(op, opToIndex.size())); }); @@ -516,8 +555,8 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, // Walk each of the blocks, computing the def interval that the value is used. Liveness matcherLiveness(matcherFunc); - for (Block &block : matcherFunc.getBody()) { - const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); + matcherFunc->walk([&](Block *block) { + const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); assert(info && "expected liveness info for block"); auto processValue = [&](Value value, Operation *firstUseOrDef) { // We don't need to process the root op argument, this value is always @@ -527,7 +566,7 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, // Set indices for the range of this block that the value is used. auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; - defRangeIt->second.liveness.insert( + defRangeIt->second.liveness->insert( opToIndex[firstUseOrDef], opToIndex[info->getEndOperation(value, firstUseOrDef)], /*dummyValue*/ 0); @@ -535,7 +574,9 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, // Check to see if this value is a range type. if (auto rangeTy = value.getType().dyn_cast()) { Type eleType = rangeTy.getElementType(); - if (eleType.isa()) + if (eleType.isa()) + defRangeIt->second.opRangeIndex = 0; + else if (eleType.isa()) defRangeIt->second.typeRangeIndex = 0; else if (eleType.isa()) defRangeIt->second.valueRangeIndex = 0; @@ -543,18 +584,37 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, }; // Process the live-ins of this block. - for (Value liveIn : info->in()) - processValue(liveIn, &block.front()); + for (Value liveIn : info->in()) { + // Only process the value if it has been defined in the current region. + // Other values that span across pdl_interp.foreach will be added higher + // up. This ensures that the we keep them alive for the entire duration + // of the loop. + if (liveIn.getParentRegion() == block->getParent()) + processValue(liveIn, &block->front()); + } + + // Process the block arguments for the entry block (those are not live-in). + if (block->isEntryBlock()) { + for (Value argument : block->getArguments()) + processValue(argument, &block->front()); + } // Process any new defs within this block. - for (Operation &op : block) + for (Operation &op : *block) for (Value result : op.getResults()) processValue(result, &op); - } + }); // Greedily allocate memory slots using the computed def live ranges. std::vector allocatedIndices; - ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0; + + // The number of memory indices currently allocated (and its next value). + // Recall that the root gets allocated memory index 0. + ByteCodeField numIndices = 1; + + // The number of memory ranges of various types (and their next values). + ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; + for (auto &defIt : valueDefRanges) { ByteCodeField &memIndex = valueToMemIndex[defIt.first]; ByteCodeLiveRange &defRange = defIt.second; @@ -566,7 +626,11 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, existingRange.unionWith(defRange); memIndex = existingIndexIt.index() + 1; - if (defRange.typeRangeIndex) { + if (defRange.opRangeIndex) { + if (!existingRange.opRangeIndex) + existingRange.opRangeIndex = numOpRanges++; + valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; + } else if (defRange.typeRangeIndex) { if (!existingRange.typeRangeIndex) existingRange.typeRangeIndex = numTypeRanges++; valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; @@ -585,8 +649,11 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, ByteCodeLiveRange &newRange = allocatedIndices.back(); newRange.unionWith(defRange); - // Allocate an index for type/value ranges. - if (defRange.typeRangeIndex) { + // Allocate an index for op/type/value ranges. + if (defRange.opRangeIndex) { + newRange.opRangeIndex = numOpRanges; + valueToRangeIndex[defIt.first] = numOpRanges++; + } else if (defRange.typeRangeIndex) { newRange.typeRangeIndex = numTypeRanges; valueToRangeIndex[defIt.first] = numTypeRanges++; } else if (defRange.valueRangeIndex) { @@ -599,15 +666,35 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, } } + // Print the index usage and ensure that we did not run out of index space. + LLVM_DEBUG({ + llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " + << "(down from initial " << valueDefRanges.size() << ").\n"; + }); + assert(allocatedIndices.size() <= std::numeric_limits::max() && + "Ran out of memory for allocated indices"); + // Update the max number of indices. if (numIndices > maxValueMemoryIndex) maxValueMemoryIndex = numIndices; + if (numOpRanges > maxOpRangeMemoryIndex) + maxOpRangeMemoryIndex = numOpRanges; if (numTypeRanges > maxTypeRangeMemoryIndex) maxTypeRangeMemoryIndex = numTypeRanges; if (numValueRanges > maxValueRangeMemoryIndex) maxValueRangeMemoryIndex = numValueRanges; } +void Generator::generate(Region *region, ByteCodeWriter &writer) { + llvm::ReversePostOrderTraversal rpot(region); + for (Block *block : rpot) { + // Keep track of where this block begins within the matcher function. + blockToAddr.try_emplace(block, matcherByteCode.size()); + for (Operation &op : *block) + generate(&op, writer); + } +} + void Generator::generate(Operation *op, ByteCodeWriter &writer) { TypeSwitch(op) .Case 0 && "encountered pdl_interp.continue at top level"); + writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1)); +} void Generator::generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. @@ -736,9 +829,31 @@ void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { writer.append(OpCode::EraseOp, op.operation()); } +void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) { + OpCode opCode = + TypeSwitch(op.result().getType()) + .Case([](pdl::OperationType) { return OpCode::ExtractOp; }) + .Case([](pdl::ValueType) { return OpCode::ExtractValue; }) + .Case([](pdl::TypeType) { return OpCode::ExtractType; }) + .Default([](Type) -> OpCode { + llvm_unreachable("unsupported element type"); + }); + writer.append(opCode, op.range(), op.index(), op.result()); +} void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { writer.append(OpCode::Finalize); } +void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) { + BlockArgument arg = op.getLoopVariable(); + writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg); + writer.appendPDLValueKind(arg.getType()); + writer.append(curLoopLevel, op.successor()); + ++curLoopLevel; + if (curLoopLevel > maxLoopLevel) + maxLoopLevel = curLoopLevel; + generate(&op.region(), writer); + --curLoopLevel; +} void Generator::generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer) { writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), @@ -793,6 +908,12 @@ void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { writer.append(std::numeric_limits::max()); writer.append(result); } +void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { + Value operations = op.operations(); + ByteCodeField rangeIndex = getRangeStorageIndex(operations); + writer.append(OpCode::GetUsers, operations, rangeIndex); + writer.appendPDLValue(op.value()); +} void Generator::generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer) { if (op.getType().isa()) { @@ -865,8 +986,8 @@ PDLByteCode::PDLByteCode(ModuleOp module, llvm::StringMap rewriteFns) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, - maxTypeRangeCount, maxValueRangeCount, constraintFns, - rewriteFns); + maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, + maxLoopLevel, constraintFns, rewriteFns); generator.generate(module); // Initialize the external functions. @@ -880,8 +1001,10 @@ PDLByteCode::PDLByteCode(ModuleOp module, /// bytecode. void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { state.memory.resize(maxValueMemoryIndex, nullptr); + state.opRangeMemory.resize(maxOpRangeCount); state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); + state.loopIndex.resize(maxLoopLevel, 0); state.currentPatternBenefits.reserve(patterns.size()); for (const PDLByteCodePattern &pattern : patterns) state.currentPatternBenefits.push_back(pattern.getBenefit()); @@ -896,20 +1019,23 @@ class ByteCodeExecutor { public: ByteCodeExecutor( const ByteCodeField *curCodeIt, MutableArrayRef memory, + MutableArrayRef> opRangeMemory, MutableArrayRef typeRangeMemory, std::vector> &allocatedTypeRangeMemory, MutableArrayRef valueRangeMemory, std::vector> &allocatedValueRangeMemory, - ArrayRef uniquedMemory, ArrayRef code, + MutableArrayRef loopIndex, ArrayRef uniquedMemory, + ArrayRef code, ArrayRef currentPatternBenefits, ArrayRef patterns, ArrayRef constraintFunctions, ArrayRef rewriteFunctions) - : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory), + : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), + typeRangeMemory(typeRangeMemory), allocatedTypeRangeMemory(allocatedTypeRangeMemory), valueRangeMemory(valueRangeMemory), allocatedValueRangeMemory(allocatedValueRangeMemory), - uniquedMemory(uniquedMemory), code(code), + loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), currentPatternBenefits(currentPatternBenefits), patterns(patterns), constraintFunctions(constraintFunctions), rewriteFunctions(rewriteFunctions) {} @@ -932,10 +1058,15 @@ private: void executeCheckOperationName(); void executeCheckResultCount(); void executeCheckTypes(); + void executeContinue(); void executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); void executeCreateTypes(); void executeEraseOp(PatternRewriter &rewriter); + template + void executeExtract(); + void executeFinalize(); + void executeForEach(); void executeGetAttribute(); void executeGetAttributeType(); void executeGetDefiningOp(); @@ -943,6 +1074,7 @@ private: void executeGetOperands(); void executeGetResult(unsigned index); void executeGetResults(); + void executeGetUsers(); void executeGetValueType(); void executeGetValueRangeTypes(); void executeIsNotNull(); @@ -956,6 +1088,16 @@ private: void executeSwitchType(); void executeSwitchTypes(); + /// Pushes a code iterator to the stack. + void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } + + /// Pops a code iterator from the stack, returning true on success. + void popCodeIt() { + assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); + curCodeIt = resumeCodeIt.back(); + resumeCodeIt.pop_back(); + } + /// Read a value from the bytecode buffer, optionally skipping a certain /// number of prefix values. These methods always update the buffer to point /// to the next field after the read data. @@ -1012,6 +1154,18 @@ private: selectJump(size_t(0)); } + /// Store a pointer to memory. + void storeToMemory(unsigned index, const void *value) { + memory[index] = value; + } + + /// Store a value to memory as an opaque pointer. + template + std::enable_if_t::value> + storeToMemory(unsigned index, T value) { + memory[index] = value.getAsOpaquePointer(); + } + /// Internal implementation of reading various data types from the bytecode /// stream. template @@ -1076,13 +1230,20 @@ private: /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; + /// The stack of bytecode positions at which to resume operation. + SmallVector resumeCodeIt; + /// The current execution memory. MutableArrayRef memory; + MutableArrayRef opRangeMemory; MutableArrayRef typeRangeMemory; std::vector> &allocatedTypeRangeMemory; MutableArrayRef valueRangeMemory; std::vector> &allocatedValueRangeMemory; + /// The current loop indices. + MutableArrayRef loopIndex; + /// References to ByteCode data necessary for execution. ArrayRef uniquedMemory; ArrayRef code; @@ -1277,6 +1438,14 @@ void ByteCodeExecutor::executeCheckTypes() { selectJump(*lhs == rhs.cast().getAsValueRange()); } +void ByteCodeExecutor::executeContinue() { + ByteCodeField level = read(); + LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" + << " * Level: " << level << "\n"); + ++loopIndex[level]; + popCodeIt(); +} + void ByteCodeExecutor::executeCreateTypes() { LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); unsigned memIndex = read(); @@ -1357,6 +1526,65 @@ void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { rewriter.eraseOp(op); } +template +void ByteCodeExecutor::executeExtract() { + LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); + Range *range = read(); + unsigned index = read(); + unsigned memIndex = read(); + + if (!range) { + memory[memIndex] = nullptr; + return; + } + + T result = index < range->size() ? (*range)[index] : T(); + LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" + << " * Index: " << index << "\n" + << " * Result: " << result << "\n"); + storeToMemory(memIndex, result); +} + +void ByteCodeExecutor::executeFinalize() { + LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); +} + +void ByteCodeExecutor::executeForEach() { + LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); + // Subtract 1 for the op code. + const ByteCodeField *it = curCodeIt - 1; + unsigned rangeIndex = read(); + unsigned memIndex = read(); + const void *value = nullptr; + + switch (read()) { + case PDLValue::Kind::Operation: { + unsigned &index = loopIndex[read()]; + ArrayRef array = opRangeMemory[rangeIndex]; + assert(index <= array.size() && "iterated past the end"); + if (index < array.size()) { + LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); + value = array[index]; + break; + } + + LLVM_DEBUG(llvm::dbgs() << " * Done\n"); + index = 0; + selectJump(size_t(0)); + return; + } + default: + llvm_unreachable("unexpected `ForEach` value kind"); + } + + // Store the iterate value and the stack address. + memory[memIndex] = value; + pushCodeIt(it); + + // Skip over the successor (we will enter the body of the loop). + read(); +} + void ByteCodeExecutor::executeGetAttribute() { LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); unsigned memIndex = read(); @@ -1421,7 +1649,7 @@ template