CheckResultCount,
/// Compare a range of types to a constant range of types.
CheckTypes,
+ /// Continue to the next iteration of a loop.
+ Continue,
/// Create an operation.
CreateOperation,
/// Create a range of types.
CreateTypes,
/// Erase an operation.
EraseOp,
+ /// Extract the op from a range at the specified index.
+ ExtractOp,
+ /// Extract the type from a range at the specified index.
+ ExtractType,
+ /// Extract the value from a range at the specified index.
+ ExtractValue,
/// Terminate a matcher or rewrite sequence.
Finalize,
+ /// Iterate over a range of values.
+ ForEach,
/// Get a specific attribute of an operation.
GetAttribute,
/// Get the type of an attribute.
GetResultN,
/// Get a specific result group of an operation.
GetResults,
+ /// Get the users of a value or a range of values.
+ GetUsers,
/// Get the type of a value.
GetValueType,
/// Get the types of a value range.
// Generator
namespace {
+struct ByteCodeLiveRange;
struct ByteCodeWriter;
+/// Check if the given class `T` can be converted to an opaque pointer.
+template <typename T, typename... Args>
+using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
+
/// This class represents the main generator for the pattern bytecode.
class Generator {
public:
SmallVectorImpl<ByteCodeField> &rewriterByteCode,
SmallVectorImpl<PDLByteCodePattern> &patterns,
ByteCodeField &maxValueMemoryIndex,
+ ByteCodeField &maxOpRangeMemoryIndex,
ByteCodeField &maxTypeRangeMemoryIndex,
ByteCodeField &maxValueRangeMemoryIndex,
+ ByteCodeField &maxLoopLevel,
llvm::StringMap<PDLConstraintFunction> &constraintFns,
llvm::StringMap<PDLRewriteFunction> &rewriteFns)
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
rewriterByteCode(rewriterByteCode), patterns(patterns),
maxValueMemoryIndex(maxValueMemoryIndex),
+ maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
- maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) {
+ maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
+ maxLoopLevel(maxLoopLevel) {
for (auto it : llvm::enumerate(constraintFns))
constraintToMemIndex.try_emplace(it.value().first(), it.index());
for (auto it : llvm::enumerate(rewriteFns))
void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
/// Generate the bytecode for the given operation.
+ void generate(Region *region, ByteCodeWriter &writer);
void generate(Operation *op, ByteCodeWriter &writer);
void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
/// `uniquedData`.
DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
+ /// The current level of the foreach loop.
+ ByteCodeField curLoopLevel = 0;
+
/// The current MLIR context.
MLIRContext *ctx;
+ /// Mapping from block to its address.
+ DenseMap<Block *, ByteCodeAddr> blockToAddr;
+
/// Data of the ByteCode class to be populated.
std::vector<const void *> &uniquedData;
SmallVectorImpl<ByteCodeField> &matcherByteCode;
SmallVectorImpl<ByteCodeField> &rewriterByteCode;
SmallVectorImpl<PDLByteCodePattern> &patterns;
ByteCodeField &maxValueMemoryIndex;
+ ByteCodeField &maxOpRangeMemoryIndex;
ByteCodeField &maxTypeRangeMemoryIndex;
ByteCodeField &maxValueRangeMemoryIndex;
+ ByteCodeField &maxLoopLevel;
};
/// This class provides utilities for writing a bytecode stream.
bytecode.append({fieldParts[0], fieldParts[1]});
}
+ /// Append a single successor to the bytecode, the exact address will need to
+ /// be resolved later.
+ void append(Block *successor) {
+ // Add back a reference to the successor so that the address can be resolved
+ // later.
+ unresolvedSuccessorRefs[successor].push_back(bytecode.size());
+ append(ByteCodeAddr(0));
+ }
+
/// Append a successor range to the bytecode, the exact address will need to
/// be resolved later.
void append(SuccessorRange successors) {
- // Add back references to the any successors so that the address can be
- // resolved later.
- for (Block *successor : successors) {
- unresolvedSuccessorRefs[successor].push_back(bytecode.size());
- append(ByteCodeAddr(0));
- }
+ for (Block *successor : successors)
+ append(successor);
}
/// Append a range of values that will be read as generic PDLValues.
}
/// Append the PDLValue::Kind of the given value.
- void appendPDLValueKind(Value value) {
- // Append the type of the value in addition to the value itself.
+ void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
+
+ /// Append the PDLValue::Kind of the given type.
+ void appendPDLValueKind(Type type) {
PDLValue::Kind kind =
- TypeSwitch<Type, PDLValue::Kind>(value.getType())
+ TypeSwitch<Type, PDLValue::Kind>(type)
.Case<pdl::AttributeType>(
[](Type) { return PDLValue::Kind::Attribute; })
.Case<pdl::OperationType>(
bytecode.push_back(static_cast<ByteCodeField>(kind));
}
- /// Check if the given class `T` has an iterator type.
- template <typename T, typename... Args>
- using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
-
/// Append a value that will be stored in a memory slot and not inline within
/// the bytecode.
template <typename T>
/// This class represents a live range of PDL Interpreter values, containing
/// information about when values are live within a match/rewrite.
struct ByteCodeLiveRange {
- using Set = llvm::IntervalMap<ByteCodeField, char, 16>;
+ using Set = llvm::IntervalMap<uint64_t, char, 16>;
using Allocator = Set::Allocator;
- ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {}
+ ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
/// Union this live range with the one provided.
void unionWith(const ByteCodeLiveRange &rhs) {
- for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it)
- liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0);
+ for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
+ ++it)
+ liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
}
/// Returns true if this range overlaps with the one provided.
bool overlaps(const ByteCodeLiveRange &rhs) const {
- return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid();
+ return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
+ .valid();
}
/// A map representing the ranges of the match/rewrite that a value is live in
/// the interpreter.
- llvm::IntervalMap<ByteCodeField, char, 16> liveness;
+ ///
+ /// We use std::unique_ptr here, because IntervalMap does not provide a
+ /// correct copy or move constructor. We can eliminate the pointer once
+ /// https://reviews.llvm.org/D113240 lands.
+ std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
+
+ /// The operation range storage index for this range.
+ Optional<unsigned> opRangeIndex;
/// The type range storage index for this range.
Optional<unsigned> typeRangeIndex;
"unexpected branches in rewriter function");
// Generate code for the matcher function.
- DenseMap<Block *, ByteCodeAddr> blockToAddr;
- llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
- for (Block *block : rpot) {
- // Keep track of where this block begins within the matcher function.
- blockToAddr.try_emplace(block, matcherByteCode.size());
- for (Operation &op : *block)
- generate(&op, matcherByteCodeWriter);
- }
+ generate(&matcherFunc.getBody(), matcherByteCodeWriter);
// Resolve successor references in the matcher.
for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
// finding the minimal number of overlapping live ranges. This is essentially
// a simplified form of register allocation where we don't necessarily have a
// limited number of registers, but we still want to minimize the number used.
- DenseMap<Operation *, ByteCodeField> opToIndex;
+ DenseMap<Operation *, unsigned> opToIndex;
matcherFunc.getBody().walk([&](Operation *op) {
opToIndex.insert(std::make_pair(op, opToIndex.size()));
});
// Walk each of the blocks, computing the def interval that the value is used.
Liveness matcherLiveness(matcherFunc);
- for (Block &block : matcherFunc.getBody()) {
- const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
+ matcherFunc->walk([&](Block *block) {
+ const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
assert(info && "expected liveness info for block");
auto processValue = [&](Value value, Operation *firstUseOrDef) {
// We don't need to process the root op argument, this value is always
// Set indices for the range of this block that the value is used.
auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
- defRangeIt->second.liveness.insert(
+ defRangeIt->second.liveness->insert(
opToIndex[firstUseOrDef],
opToIndex[info->getEndOperation(value, firstUseOrDef)],
/*dummyValue*/ 0);
// Check to see if this value is a range type.
if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
Type eleType = rangeTy.getElementType();
- if (eleType.isa<pdl::TypeType>())
+ if (eleType.isa<pdl::OperationType>())
+ defRangeIt->second.opRangeIndex = 0;
+ else if (eleType.isa<pdl::TypeType>())
defRangeIt->second.typeRangeIndex = 0;
else if (eleType.isa<pdl::ValueType>())
defRangeIt->second.valueRangeIndex = 0;
};
// Process the live-ins of this block.
- for (Value liveIn : info->in())
- processValue(liveIn, &block.front());
+ for (Value liveIn : info->in()) {
+ // Only process the value if it has been defined in the current region.
+ // Other values that span across pdl_interp.foreach will be added higher
+ // up. This ensures that the we keep them alive for the entire duration
+ // of the loop.
+ if (liveIn.getParentRegion() == block->getParent())
+ processValue(liveIn, &block->front());
+ }
+
+ // Process the block arguments for the entry block (those are not live-in).
+ if (block->isEntryBlock()) {
+ for (Value argument : block->getArguments())
+ processValue(argument, &block->front());
+ }
// Process any new defs within this block.
- for (Operation &op : block)
+ for (Operation &op : *block)
for (Value result : op.getResults())
processValue(result, &op);
- }
+ });
// Greedily allocate memory slots using the computed def live ranges.
std::vector<ByteCodeLiveRange> allocatedIndices;
- ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0;
+
+ // The number of memory indices currently allocated (and its next value).
+ // Recall that the root gets allocated memory index 0.
+ ByteCodeField numIndices = 1;
+
+ // The number of memory ranges of various types (and their next values).
+ ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
+
for (auto &defIt : valueDefRanges) {
ByteCodeField &memIndex = valueToMemIndex[defIt.first];
ByteCodeLiveRange &defRange = defIt.second;
existingRange.unionWith(defRange);
memIndex = existingIndexIt.index() + 1;
- if (defRange.typeRangeIndex) {
+ if (defRange.opRangeIndex) {
+ if (!existingRange.opRangeIndex)
+ existingRange.opRangeIndex = numOpRanges++;
+ valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
+ } else if (defRange.typeRangeIndex) {
if (!existingRange.typeRangeIndex)
existingRange.typeRangeIndex = numTypeRanges++;
valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
ByteCodeLiveRange &newRange = allocatedIndices.back();
newRange.unionWith(defRange);
- // Allocate an index for type/value ranges.
- if (defRange.typeRangeIndex) {
+ // Allocate an index for op/type/value ranges.
+ if (defRange.opRangeIndex) {
+ newRange.opRangeIndex = numOpRanges;
+ valueToRangeIndex[defIt.first] = numOpRanges++;
+ } else if (defRange.typeRangeIndex) {
newRange.typeRangeIndex = numTypeRanges;
valueToRangeIndex[defIt.first] = numTypeRanges++;
} else if (defRange.valueRangeIndex) {
}
}
+ // Print the index usage and ensure that we did not run out of index space.
+ LLVM_DEBUG({
+ llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
+ << "(down from initial " << valueDefRanges.size() << ").\n";
+ });
+ assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
+ "Ran out of memory for allocated indices");
+
// Update the max number of indices.
if (numIndices > maxValueMemoryIndex)
maxValueMemoryIndex = numIndices;
+ if (numOpRanges > maxOpRangeMemoryIndex)
+ maxOpRangeMemoryIndex = numOpRanges;
if (numTypeRanges > maxTypeRangeMemoryIndex)
maxTypeRangeMemoryIndex = numTypeRanges;
if (numValueRanges > maxValueRangeMemoryIndex)
maxValueRangeMemoryIndex = numValueRanges;
}
+void Generator::generate(Region *region, ByteCodeWriter &writer) {
+ llvm::ReversePostOrderTraversal<Region *> rpot(region);
+ for (Block *block : rpot) {
+ // Keep track of where this block begins within the matcher function.
+ blockToAddr.try_emplace(block, matcherByteCode.size());
+ for (Operation &op : *block)
+ generate(&op, writer);
+ }
+}
+
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
TypeSwitch<Operation *>(op)
.Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
- pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
- pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
- pdl_interp::EraseOp, pdl_interp::FinalizeOp,
- pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
- pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
- pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
- pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
+ pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
+ pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
+ pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
+ pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
+ pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
+ pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
+ pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
+ pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
+ pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
}
+void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
+ assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
+ writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
+}
void Generator::generate(pdl_interp::CreateAttributeOp op,
ByteCodeWriter &writer) {
// Simply repoint the memory index of the result to the constant.
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
writer.append(OpCode::EraseOp, op.operation());
}
+void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
+ OpCode opCode =
+ TypeSwitch<Type, OpCode>(op.result().getType())
+ .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
+ .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
+ .Case([](pdl::TypeType) { return OpCode::ExtractType; })
+ .Default([](Type) -> OpCode {
+ llvm_unreachable("unsupported element type");
+ });
+ writer.append(opCode, op.range(), op.index(), op.result());
+}
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
writer.append(OpCode::Finalize);
}
+void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
+ BlockArgument arg = op.getLoopVariable();
+ writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg);
+ writer.appendPDLValueKind(arg.getType());
+ writer.append(curLoopLevel, op.successor());
+ ++curLoopLevel;
+ if (curLoopLevel > maxLoopLevel)
+ maxLoopLevel = curLoopLevel;
+ generate(&op.region(), writer);
+ --curLoopLevel;
+}
void Generator::generate(pdl_interp::GetAttributeOp op,
ByteCodeWriter &writer) {
writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
writer.append(std::numeric_limits<ByteCodeField>::max());
writer.append(result);
}
+void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
+ Value operations = op.operations();
+ ByteCodeField rangeIndex = getRangeStorageIndex(operations);
+ writer.append(OpCode::GetUsers, operations, rangeIndex);
+ writer.appendPDLValue(op.value());
+}
void Generator::generate(pdl_interp::GetValueTypeOp op,
ByteCodeWriter &writer) {
if (op.getType().isa<pdl::RangeType>()) {
llvm::StringMap<PDLRewriteFunction> rewriteFns) {
Generator generator(module.getContext(), uniquedData, matcherByteCode,
rewriterByteCode, patterns, maxValueMemoryIndex,
- maxTypeRangeCount, maxValueRangeCount, constraintFns,
- rewriteFns);
+ maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
+ maxLoopLevel, constraintFns, rewriteFns);
generator.generate(module);
// Initialize the external functions.
/// bytecode.
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
state.memory.resize(maxValueMemoryIndex, nullptr);
+ state.opRangeMemory.resize(maxOpRangeCount);
state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
+ state.loopIndex.resize(maxLoopLevel, 0);
state.currentPatternBenefits.reserve(patterns.size());
for (const PDLByteCodePattern &pattern : patterns)
state.currentPatternBenefits.push_back(pattern.getBenefit());
public:
ByteCodeExecutor(
const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
+ MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
MutableArrayRef<TypeRange> typeRangeMemory,
std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
MutableArrayRef<ValueRange> valueRangeMemory,
std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
- ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code,
+ MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
+ ArrayRef<ByteCodeField> code,
ArrayRef<PatternBenefit> currentPatternBenefits,
ArrayRef<PDLByteCodePattern> patterns,
ArrayRef<PDLConstraintFunction> constraintFunctions,
ArrayRef<PDLRewriteFunction> rewriteFunctions)
- : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory),
+ : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
+ typeRangeMemory(typeRangeMemory),
allocatedTypeRangeMemory(allocatedTypeRangeMemory),
valueRangeMemory(valueRangeMemory),
allocatedValueRangeMemory(allocatedValueRangeMemory),
- uniquedMemory(uniquedMemory), code(code),
+ loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
currentPatternBenefits(currentPatternBenefits), patterns(patterns),
constraintFunctions(constraintFunctions),
rewriteFunctions(rewriteFunctions) {}
void executeCheckOperationName();
void executeCheckResultCount();
void executeCheckTypes();
+ void executeContinue();
void executeCreateOperation(PatternRewriter &rewriter,
Location mainRewriteLoc);
void executeCreateTypes();
void executeEraseOp(PatternRewriter &rewriter);
+ template <typename T, typename Range, PDLValue::Kind kind>
+ void executeExtract();
+ void executeFinalize();
+ void executeForEach();
void executeGetAttribute();
void executeGetAttributeType();
void executeGetDefiningOp();
void executeGetOperands();
void executeGetResult(unsigned index);
void executeGetResults();
+ void executeGetUsers();
void executeGetValueType();
void executeGetValueRangeTypes();
void executeIsNotNull();
void executeSwitchType();
void executeSwitchTypes();
+ /// Pushes a code iterator to the stack.
+ void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
+
+ /// Pops a code iterator from the stack, returning true on success.
+ void popCodeIt() {
+ assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
+ curCodeIt = resumeCodeIt.back();
+ resumeCodeIt.pop_back();
+ }
+
/// Read a value from the bytecode buffer, optionally skipping a certain
/// number of prefix values. These methods always update the buffer to point
/// to the next field after the read data.
selectJump(size_t(0));
}
+ /// Store a pointer to memory.
+ void storeToMemory(unsigned index, const void *value) {
+ memory[index] = value;
+ }
+
+ /// Store a value to memory as an opaque pointer.
+ template <typename T>
+ std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
+ storeToMemory(unsigned index, T value) {
+ memory[index] = value.getAsOpaquePointer();
+ }
+
/// Internal implementation of reading various data types from the bytecode
/// stream.
template <typename T>
/// The underlying bytecode buffer.
const ByteCodeField *curCodeIt;
+ /// The stack of bytecode positions at which to resume operation.
+ SmallVector<const ByteCodeField *> resumeCodeIt;
+
/// The current execution memory.
MutableArrayRef<const void *> memory;
+ MutableArrayRef<OwningOpRange> opRangeMemory;
MutableArrayRef<TypeRange> typeRangeMemory;
std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
MutableArrayRef<ValueRange> valueRangeMemory;
std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
+ /// The current loop indices.
+ MutableArrayRef<unsigned> loopIndex;
+
/// References to ByteCode data necessary for execution.
ArrayRef<const void *> uniquedMemory;
ArrayRef<ByteCodeField> code;
selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
}
+void ByteCodeExecutor::executeContinue() {
+ ByteCodeField level = read();
+ LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
+ << " * Level: " << level << "\n");
+ ++loopIndex[level];
+ popCodeIt();
+}
+
void ByteCodeExecutor::executeCreateTypes() {
LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
unsigned memIndex = read();
rewriter.eraseOp(op);
}
+template <typename T, typename Range, PDLValue::Kind kind>
+void ByteCodeExecutor::executeExtract() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
+ Range *range = read<Range *>();
+ unsigned index = read<uint32_t>();
+ unsigned memIndex = read();
+
+ if (!range) {
+ memory[memIndex] = nullptr;
+ return;
+ }
+
+ T result = index < range->size() ? (*range)[index] : T();
+ LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
+ << " * Index: " << index << "\n"
+ << " * Result: " << result << "\n");
+ storeToMemory(memIndex, result);
+}
+
+void ByteCodeExecutor::executeFinalize() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
+}
+
+void ByteCodeExecutor::executeForEach() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
+ // Subtract 1 for the op code.
+ const ByteCodeField *it = curCodeIt - 1;
+ unsigned rangeIndex = read();
+ unsigned memIndex = read();
+ const void *value = nullptr;
+
+ switch (read<PDLValue::Kind>()) {
+ case PDLValue::Kind::Operation: {
+ unsigned &index = loopIndex[read()];
+ ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
+ assert(index <= array.size() && "iterated past the end");
+ if (index < array.size()) {
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
+ value = array[index];
+ break;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " * Done\n");
+ index = 0;
+ selectJump(size_t(0));
+ return;
+ }
+ default:
+ llvm_unreachable("unexpected `ForEach` value kind");
+ }
+
+ // Store the iterate value and the stack address.
+ memory[memIndex] = value;
+ pushCodeIt(it);
+
+ // Skip over the successor (we will enter the body of the loop).
+ read<ByteCodeAddr>();
+}
+
void ByteCodeExecutor::executeGetAttribute() {
LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
unsigned memIndex = read();
static void *
executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
ByteCodeField rangeIndex, StringRef attrSizedSegments,
- MutableArrayRef<ValueRange> &valueRangeMemory) {
+ MutableArrayRef<ValueRange> valueRangeMemory) {
// Check for the sentinel index that signals that all values should be
// returned.
if (index == std::numeric_limits<uint32_t>::max()) {
memory[read()] = result;
}
+void ByteCodeExecutor::executeGetUsers() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
+ unsigned memIndex = read();
+ unsigned rangeIndex = read();
+ OwningOpRange &range = opRangeMemory[rangeIndex];
+ memory[memIndex] = ⦥
+
+ range = OwningOpRange();
+ if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
+ // Read the value.
+ Value value = read<Value>();
+ if (!value)
+ return;
+ LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+
+ // Extract the users of a single value.
+ range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
+ llvm::copy(value.getUsers(), range.begin());
+ } else {
+ // Read a range of values.
+ ValueRange *values = read<ValueRange *>();
+ if (!values)
+ return;
+ LLVM_DEBUG({
+ llvm::dbgs() << " * Values (" << values->size() << "): ";
+ llvm::interleaveComma(*values, llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ // Extract all the users of a range of values.
+ SmallVector<Operation *> users;
+ for (Value value : *values)
+ users.append(value.user_begin(), value.user_end());
+ range = OwningOpRange(users.size());
+ llvm::copy(users, range.begin());
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
+}
+
void ByteCodeExecutor::executeGetValueType() {
LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
unsigned memIndex = read();
case CheckTypes:
executeCheckTypes();
break;
+ case Continue:
+ executeContinue();
+ break;
case CreateOperation:
executeCreateOperation(rewriter, *mainRewriteLoc);
break;
case EraseOp:
executeEraseOp(rewriter);
break;
+ case ExtractOp:
+ executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
+ break;
+ case ExtractType:
+ executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
+ break;
+ case ExtractValue:
+ executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
+ break;
case Finalize:
- LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
+ executeFinalize();
+ LLVM_DEBUG(llvm::dbgs() << "\n");
return;
+ case ForEach:
+ executeForEach();
+ break;
case GetAttribute:
executeGetAttribute();
break;
case GetResults:
executeGetResults();
break;
+ case GetUsers:
+ executeGetUsers();
+ break;
case GetValueType:
executeGetValueType();
break;
// The matcher function always starts at code address 0.
ByteCodeExecutor executor(
- matcherByteCode.data(), state.memory, state.typeRangeMemory,
- state.allocatedTypeRangeMemory, state.valueRangeMemory,
- state.allocatedValueRangeMemory, uniquedData, matcherByteCode,
- state.currentPatternBenefits, patterns, constraintFunctions,
- rewriteFunctions);
+ matcherByteCode.data(), state.memory, state.opRangeMemory,
+ state.typeRangeMemory, state.allocatedTypeRangeMemory,
+ state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
+ uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
+ constraintFunctions, rewriteFunctions);
executor.execute(rewriter, &matches);
// Order the found matches by benefit.
ByteCodeExecutor executor(
&rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
- state.typeRangeMemory, state.allocatedTypeRangeMemory,
- state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData,
+ state.opRangeMemory, state.typeRangeMemory,
+ state.allocatedTypeRangeMemory, state.valueRangeMemory,
+ state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
rewriterByteCode, state.currentPatternBenefits, patterns,
constraintFunctions, rewriteFunctions);
executor.execute(rewriter, /*matches=*/nullptr, match.location);
// -----
//===----------------------------------------------------------------------===//
+// pdl_interp::ContinueOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
// pdl_interp::CreateAttributeOp
//===----------------------------------------------------------------------===//
// Fully tested within the tests for other operations.
//===----------------------------------------------------------------------===//
+// pdl_interp::ExtractOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %val = pdl_interp.get_result 0 of %root
+ %ops = pdl_interp.get_users of %val : !pdl.value
+ %op1 = pdl_interp.extract 1 of %ops : !pdl.operation
+ pdl_interp.is_not_null %op1 : !pdl.operation -> ^success, ^end
+ ^success:
+ pdl_interp.record_match @rewriters::@success(%op1 : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.extract_op
+// CHECK: "test.success"
+// CHECK: %[[OPERAND:.*]] = "test.op"
+// CHECK: "test.op"(%[[OPERAND]])
+module @ir attributes { test.extract_op } {
+ %operand = "test.op"() : () -> i32
+ "test.op"(%operand) : (i32) -> (i32)
+ "test.op"(%operand, %operand) : (i32, i32) -> (i32)
+}
+
+// -----
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %vals = pdl_interp.get_results of %root : !pdl.range<value>
+ %types = pdl_interp.get_value_type of %vals : !pdl.range<type>
+ %type1 = pdl_interp.extract 1 of %types : !pdl.type
+ pdl_interp.is_not_null %type1 : !pdl.type -> ^success, ^end
+ ^success:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.extract_type
+// CHECK: %[[OPERAND:.*]] = "test.op"
+// CHECK: "test.success"
+// CHECK: "test.op"(%[[OPERAND]])
+module @ir attributes { test.extract_type } {
+ %operand = "test.op"() : () -> i32
+ "test.op"(%operand) : (i32) -> (i32, i32)
+ "test.op"(%operand) : (i32) -> (i32)
+}
+
+// -----
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %vals = pdl_interp.get_results of %root : !pdl.range<value>
+ %val1 = pdl_interp.extract 1 of %vals : !pdl.value
+ pdl_interp.is_not_null %val1 : !pdl.value -> ^success, ^end
+ ^success:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.extract_value
+// CHECK: %[[OPERAND:.*]] = "test.op"
+// CHECK: "test.success"
+// CHECK: "test.op"(%[[OPERAND]])
+module @ir attributes { test.extract_value } {
+ %operand = "test.op"() : () -> i32
+ "test.op"(%operand) : (i32) -> (i32, i32)
+ "test.op"(%operand) : (i32) -> (i32)
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// pdl_interp::FinalizeOp
//===----------------------------------------------------------------------===//
// Fully tested within the tests for other operations.
//===----------------------------------------------------------------------===//
+// pdl_interp::ForEachOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %val1 = pdl_interp.get_result 0 of %root
+ %ops1 = pdl_interp.get_users of %val1 : !pdl.value
+ pdl_interp.foreach %op1 : !pdl.operation in %ops1 {
+ %val2 = pdl_interp.get_result 0 of %op1
+ %ops2 = pdl_interp.get_users of %val2 : !pdl.value
+ pdl_interp.foreach %op2 : !pdl.operation in %ops2 {
+ pdl_interp.record_match @rewriters::@success(%op2 : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.foreach
+// CHECK: "test.success"
+// CHECK: "test.success"
+// CHECK: "test.success"
+// CHECK: "test.success"
+// CHECK: %[[ROOT:.*]] = "test.op"
+// CHECK: %[[VALA:.*]] = "test.op"(%[[ROOT]])
+// CHECK: %[[VALB:.*]] = "test.op"(%[[ROOT]])
+module @ir attributes { test.foreach } {
+ %root = "test.op"() : () -> i32
+ %valA = "test.op"(%root) : (i32) -> (i32)
+ "test.op"(%valA) : (i32) -> (i32)
+ "test.op"(%valA) : (i32) -> (i32)
+ %valB = "test.op"(%root) : (i32) -> (i32)
+ "test.op"(%valB) : (i32) -> (i32)
+ "test.op"(%valB) : (i32) -> (i32)
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetUsersOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %val = pdl_interp.get_result 0 of %root
+ %ops = pdl_interp.get_users of %val : !pdl.value
+ pdl_interp.foreach %op : !pdl.operation in %ops {
+ pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_users_of_value
+// CHECK: "test.success"
+// CHECK: "test.success"
+// CHECK: %[[OPERAND:.*]] = "test.op"
+module @ir attributes { test.get_users_of_value } {
+ %operand = "test.op"() : () -> i32
+ "test.op"(%operand) : (i32) -> (i32)
+ "test.op"(%operand, %operand) : (i32, i32) -> (i32)
+}
+
+// -----
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end
+ ^next:
+ %vals = pdl_interp.get_results of %root : !pdl.range<value>
+ %ops = pdl_interp.get_users of %vals : !pdl.range<value>
+ pdl_interp.foreach %op : !pdl.operation in %ops {
+ pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_all_users_of_range
+// CHECK: "test.success"
+// CHECK: "test.success"
+// CHECK: %[[OPERANDS:.*]]:2 = "test.op"
+module @ir attributes { test.get_all_users_of_range } {
+ %operands:2 = "test.op"() : () -> (i32, i32)
+ "test.op"(%operands#0) : (i32) -> (i32)
+ "test.op"(%operands#1) : (i32) -> (i32)
+}
+
+// -----
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end
+ ^next:
+ %vals = pdl_interp.get_results of %root : !pdl.range<value>
+ %val = pdl_interp.extract 0 of %vals : !pdl.value
+ %ops = pdl_interp.get_users of %val : !pdl.value
+ pdl_interp.foreach %op : !pdl.operation in %ops {
+ pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_first_users_of_range
+// CHECK: "test.success"
+// CHECK: %[[OPERANDS:.*]]:2 = "test.op"
+// CHECK: "test.op"
+module @ir attributes { test.get_first_users_of_range } {
+ %operands:2 = "test.op"() : () -> (i32, i32)
+ "test.op"(%operands#0) : (i32) -> (i32)
+ "test.op"(%operands#1) : (i32) -> (i32)
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// pdl_interp::GetAttributeOp
//===----------------------------------------------------------------------===//