void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
assert(operands.size() == 1 && "dealloc takes one operand");
+ OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present.
Function *freeFunc =
op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
}
- auto type = operands[0]->getType().cast<LLVM::LLVMType>();
+ auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
- Value *bufferPtr = extractMemRefElementPtr(
- rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape);
+ Value *bufferPtr =
+ extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
+ elementPtrType, hasStaticShape);
Value *casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(), bufferPtr);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
+ OperandAdaptor<MemRefCastOp> transformed(operands);
auto targetType = memRefCastOp.getType();
auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
// Copy the data buffer pointer.
auto elementTypePtr = getMemRefElementPtrType(targetType, lowering);
Value *buffer =
- extractMemRefElementPtr(rewriter, op->getLoc(), operands[0],
+ extractMemRefElementPtr(rewriter, op->getLoc(), transformed.source(),
elementTypePtr, sourceType.hasStaticShape());
// Account for static memrefs as target types
if (targetType.hasStaticShape())
sourceSize == -1
? rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), getIndexType(),
- operands[0], // NB: dynamic memref
+ transformed.source(), // NB: dynamic memref
getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++))
: createIndexConstant(rewriter, op->getLoc(), sourceSize);
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
- assert(operands.size() == 1 && "expected exactly one operand");
auto dimOp = cast<DimOp>(op);
+ OperandAdaptor<DimOp> transformed(operands);
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
auto shape = type.getShape();
++position;
}
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
- op, getIndexType(), operands[0],
+ op, getIndexType(), transformed.memrefOrTensor(),
getIntegerArrayAttr(rewriter, position));
} else {
rewriter.replaceOp(
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto loadOp = cast<LoadOp>(op);
+ OperandAdaptor<LoadOp> transformed(operands);
auto type = loadOp.getMemRefType();
- Value *dataPtr = getDataPtr(op->getLoc(), type, operands.front(),
- operands.drop_front(), rewriter, getModule());
+ Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter, getModule());
auto elementType = lowering.convertType(type.getElementType());
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elementType,
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto type = cast<StoreOp>(op).getMemRefType();
+ OperandAdaptor<StoreOp> transformed(operands);
- Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1],
- operands.drop_front(2), rewriter, getModule());
- rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, operands[0], dataPtr);
+ Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter, getModule());
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
+ dataPtr);
}
};
// CHECK-LABEL: OpA definitions
+// CHECK: OpAOperandAdaptor::OpAOperandAdaptor
+// CHECK-NEXT: tblgen_operands = values
+
+// CHECK: OpAOperandAdaptor::input
+// CHECK-NEXT: return tblgen_operands[0]
+
// CHECK: void OpA::build
// CHECK-SAME: Value *input
// CHECK: tblgen_state->operands.push_back(input);
let arguments = (ins Variadic<AnyTensor>:$input1, Variadic<AnyTensor>:$input2);
}
+// CHECK-LABEL: OpCOperandAdaptor::input1
+// CHECK-NEXT: variadicOperandSize = (tblgen_operands.size() - 0) / 2;
+// CHECK-NEXT: offset = 0 + variadicOperandSize * 0;
+// CHECK-NEXT: return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
+
+// CHECK-LABEL: OpCOperandAdaptor::input2
+// CHECK-NEXT: variadicOperandSize = (tblgen_operands.size() - 0) / 2;
+// CHECK-NEXT: offset = 0 + variadicOperandSize * 1;
+// CHECK-NEXT: return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
+
// CHECK-LABEL: Operation::operand_range OpC::input1()
// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 0) / 2;
// CHECK-NEXT: offset = 0 + variadicOperandSize * 0;
let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
}
+// CHECK-LABEL: OpDOperandAdaptor::input1() {
+// CHECK-NEXT: variadicOperandSize = (tblgen_operands.size() - 1) / 2;
+// CHECK-NEXT: offset = 0 + variadicOperandSize * 0;
+// CHECK-NEXT: return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
+
+// CHECK-LABEL: Value *OpDOperandAdaptor::input2() {
+// CHECK-NEXT: variadicOperandSize = (tblgen_operands.size() - 1) / 2;
+// CHECK-NEXT: offset = 0 + variadicOperandSize * 1;
+// CHECK-NEXT: return tblgen_operands[offset];
+
+// CHECK-LABEL: OpDOperandAdaptor::input3() {
+// CHECK-NEXT: variadicOperandSize = (tblgen_operands.size() - 1) / 2;
+// CHECK-NEXT: offset = 1 + variadicOperandSize * 1;
+// CHECK-NEXT: return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
+
// CHECK-LABEL: Operation::operand_range OpD::input1()
// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 1) / 2;
// CHECK-NEXT: offset = 0 + variadicOperandSize * 0;
// CHECK-LABEL: Value *OpD::input2()
// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 1) / 2;
// CHECK-NEXT: offset = 0 + variadicOperandSize * 1;
-// CHECK-NEXT: return this->getOperand(offset);
+// CHECK-NEXT: return this->getOperation()->getOperand(offset);
// CHECK-LABEL: Operation::operand_range OpD::input3()
// CHECK-NEXT: variadicOperandSize = (this->getNumOperands() - 1) / 2;
let arguments = (ins AnyTensor:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3, AnyTensor:$input4, AnyTensor:$input5);
}
+// CHECK-LABEL: Value *OpEOperandAdaptor::input1() {
+// CHECK-NEXT: return tblgen_operands[0];
+
+// CHECK-LABEL: Value *OpEOperandAdaptor::input2() {
+// CHECK-NEXT: return tblgen_operands[1];
+
+// CHECK-LABEL: OpEOperandAdaptor::input3() {
+// CHECK-NEXT: return {std::next(tblgen_operands.begin(), 2), std::next(tblgen_operands.begin(), 2 + tblgen_operands.size() - 4)};
+
+// CHECK-LABEL: Value *OpEOperandAdaptor::input4() {
+// CHECK-NEXT: return tblgen_operands[tblgen_operands.size() - 2];
+
+// CHECK-LABEL: Value *OpEOperandAdaptor::input5() {
+// CHECK-NEXT: return tblgen_operands[tblgen_operands.size() - 1];
+
// CHECK-LABEL: Value *OpE::input1()
// CHECK-NEXT: return this->getOperation()->getOperand(0);
void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
private:
- // Returns true if the given C++ `type` ends with '&' or '*'.
- static bool endsWithRefOrPtr(StringRef type);
+ // Returns true if the given C++ `type` ends with '&' or '*', or is empty.
+ static bool elideSpaceAfterType(StringRef type);
std::string returnType;
std::string methodName;
// querying properties.
enum Property {
MP_None = 0x0,
- MP_Static = 0x1, // Static method
+ MP_Static = 0x1, // Static method
+ MP_Constructor = 0x2, // Constructor
};
OpMethod(StringRef retType, StringRef name, StringRef params,
OpMethodBody methodBody;
};
-// Class for holding an op for C++ code emission
-class OpClass {
+// A class used to emit C++ classes from Tablegen. Contains a list of public
+// methods and a list of private fields to be emitted.
+class Class {
public:
- explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
-
- // Adds an op trait.
- void addTrait(Twine trait);
+ explicit Class(StringRef name);
- // Creates a new method in this op's class.
+ // Creates a new method in this class.
OpMethod &newMethod(StringRef retType, StringRef name, StringRef params = "",
OpMethod::Property = OpMethod::MP_None,
bool declOnly = false);
+ OpMethod &newConstructor(StringRef params = "", bool declOnly = false);
+
+ // Creates a new field in this class.
+ void newField(StringRef type, StringRef name, StringRef defaultValue = "");
+
// Writes this op's class as a declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const;
// Writes the method definitions in this op's class to the given `os`.
// Returns the C++ class name of the op.
StringRef getClassName() const { return className; }
+protected:
+ std::string className;
+ SmallVector<OpMethod, 8> methods;
+ SmallVector<std::string, 4> fields;
+};
+
+// Class for holding an op for C++ code emission
+class OpClass : public Class {
+public:
+ explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
+
+ // Adds an op trait.
+ void addTrait(Twine trait);
+
+ // Writes this op's class as a declaration to the given `os`. Redefines
+ // Class::writeDeclTo to also emit traits and extra class declarations.
+ void writeDeclTo(raw_ostream &os) const;
+
private:
- StringRef className;
StringRef extraClassDeclaration;
SmallVector<std::string, 4> traits;
- SmallVector<OpMethod, 8> methods;
};
} // end anonymous namespace
: returnType(retType), methodName(name), parameters(params) {}
void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
- os << returnType << (endsWithRefOrPtr(returnType) ? "" : " ") << methodName
+ os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName
<< "(" << parameters << ")";
}
return result;
};
- os << returnType << (endsWithRefOrPtr(returnType) ? "" : " ") << namePrefix
+ os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << namePrefix
<< (namePrefix.empty() ? "" : "::") << methodName << "("
<< removeParamDefaultValue(parameters) << ")";
}
-bool OpMethodSignature::endsWithRefOrPtr(StringRef type) {
- return type.endswith("&") || type.endswith("*");
+bool OpMethodSignature::elideSpaceAfterType(StringRef type) {
+ return type.empty() || type.endswith("&") || type.endswith("*");
}
OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
os << "}";
}
+Class::Class(StringRef name) : className(name) {}
+
+OpMethod &Class::newMethod(StringRef retType, StringRef name, StringRef params,
+ OpMethod::Property property, bool declOnly) {
+ methods.emplace_back(retType, name, params, property, declOnly);
+ return methods.back();
+}
+
+OpMethod &Class::newConstructor(StringRef params, bool declOnly) {
+ return newMethod("", getClassName(), params, OpMethod::MP_Constructor,
+ declOnly);
+}
+
+void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
+ std::string varName = formatv("{0} {1}", type, name).str();
+ std::string field = defaultValue.empty()
+ ? varName
+ : formatv("{0} = {1}", varName, defaultValue).str();
+ fields.push_back(std::move(field));
+}
+
+void Class::writeDeclTo(raw_ostream &os) const {
+ os << "class " << className << " {\n";
+ os << "public:\n";
+ for (const auto &method : methods) {
+ method.writeDeclTo(os);
+ os << '\n';
+ }
+ os << '\n';
+ os << "private:\n";
+ for (const auto &field : fields)
+ os.indent(2) << field << ";\n";
+ os << "};\n";
+}
+
+void Class::writeDefTo(raw_ostream &os) const {
+ for (const auto &method : methods) {
+ method.writeDefTo(os, className);
+ os << "\n\n";
+ }
+}
+
OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
- : className(name), extraClassDeclaration(extraClassDeclaration) {}
+ : Class(name), extraClassDeclaration(extraClassDeclaration) {}
// Adds the given trait to this op. Prefixes "OpTrait::" to `trait` implicitly.
void OpClass::addTrait(Twine trait) {
traits.push_back(("OpTrait::" + trait).str());
}
-OpMethod &OpClass::newMethod(StringRef retType, StringRef name,
- StringRef params, OpMethod::Property property,
- bool declOnly) {
- methods.emplace_back(retType, name, params, property, declOnly);
- return methods.back();
-}
-
void OpClass::writeDeclTo(raw_ostream &os) const {
os << "class " << className << " : public Op<" << className;
for (const auto &trait : traits)
os << ", " << trait;
os << "> {\npublic:\n";
os << " using Op::Op;\n";
+ os << " using OperandAdaptor = " << className << "OperandAdaptor;\n";
for (const auto &method : methods) {
method.writeDeclTo(os);
os << "\n";
}
// TODO: Add line control markers to make errors easier to debug.
- os << extraClassDeclaration << "\n";
- os << "};";
-}
-
-void OpClass::writeDefTo(raw_ostream &os) const {
- for (const auto &method : methods) {
- method.writeDefTo(os, className);
- os << "\n\n";
- }
+ if (!extraClassDeclaration.empty())
+ os << extraClassDeclaration << "\n";
+ os << "};\n";
}
//===----------------------------------------------------------------------===//
// Helper class to emit a record into the given output stream.
class OpEmitter {
public:
- static void emitDecl(const Record &def, raw_ostream &os);
- static void emitDef(const Record &def, raw_ostream &os);
+ static void emitDecl(const Operator &op, raw_ostream &os);
+ static void emitDef(const Operator &op, raw_ostream &os);
private:
- OpEmitter(const Record &def);
+ OpEmitter(const Operator &op);
void emitDecl(raw_ostream &os);
void emitDef(raw_ostream &os);
void genOpNameGetter();
// The TableGen record for this op.
+ // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
+ // it should rather go through the Operator for better abstraction.
const Record &def;
// The wrapper operator class for querying information from this op.
};
} // end anonymous namespace
-OpEmitter::OpEmitter(const Record &def)
- : def(def), op(def),
+OpEmitter::OpEmitter(const Operator &op)
+ : def(op.getDef()), op(op),
opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
verifyCtx.withOp("(*this->getOperation())");
genFolderDecls();
}
-void OpEmitter::emitDecl(const Record &def, raw_ostream &os) {
- OpEmitter(def).emitDecl(os);
+void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
+ OpEmitter(op).emitDecl(os);
}
-void OpEmitter::emitDef(const Record &def, raw_ostream &os) {
- OpEmitter(def).emitDef(os);
+void OpEmitter::emitDef(const Operator &op, raw_ostream &os) {
+ OpEmitter(op).emitDef(os);
}
void OpEmitter::emitDecl(raw_ostream &os) {
- os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
opClass.writeDeclTo(os);
}
void OpEmitter::emitDef(raw_ostream &os) {
- os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
-
opClass.writeDefTo(os);
}
}
}
-void OpEmitter::genNamedOperandGetters() {
+// Generates the named operand getter methods for the given Operator `op` and
+// puts them in `opClass`. Uses `rangeType` as the return type of getters that
+// return a range of operands (individual operands are `Value *` and each
+// element in the range must also be `Value *`); use `rangeBeginCall` to get an
+// iterator to the beginning of the operand range; use `rangeSizeCall` to obtain
+// the number of operands. `getOperandCallPattern` contains the code necessary
+// to obtain a single operand whose position will be substituted instead of
+// "{0}" marker in the pattern. Note that the pattern should work for any kind
+// of ops, in particular for one-operand ops that may not have the
+// `getOperand(unsigned)` method.
+static void generateNamedOperandGetters(const Operator &op, Class &opClass,
+ StringRef rangeType,
+ StringRef rangeBeginCall,
+ StringRef rangeSizeCall,
+ StringRef getOperandCallPattern) {
const int numOperands = op.getNumOperands();
const int numVariadicOperands = op.getNumVariadicOperands();
const int numNormalOperands = numOperands - numVariadicOperands;
continue;
if (operand.isVariadic()) {
- auto &m = opClass.newMethod("Operation::operand_range", operand.name);
+ auto &m = opClass.newMethod(rangeType, operand.name);
m.body() << formatv(
- " return {{std::next(operand_begin(), {0}), "
- "std::next(operand_begin(), {0} + this->getNumOperands() - {1})};",
- i, numNormalOperands);
+ " return {{std::next({2}, {0}), std::next({2}, {0} + {3} - {1})};",
+ i, numNormalOperands, rangeBeginCall, rangeSizeCall);
emittedVariadicOperand = true;
} else {
auto &m = opClass.newMethod("Value *", operand.name);
- m.body() << " return this->getOperation()->getOperand(";
- if (emittedVariadicOperand)
- m.body() << "this->getNumOperands() - " << numOperands - i;
- else
- m.body() << i;
- m.body() << ");\n";
+
+ auto operandIndex =
+ emittedVariadicOperand
+ ? formatv("{0} - {1}", rangeSizeCall, numOperands - i).str()
+ : std::to_string(i);
+
+ m.body() << " return "
+ << formatv(getOperandCallPattern.data(), operandIndex)
+ << ";\n";
}
}
return;
continue;
const char *code = R"(
- int variadicOperandSize = (this->getNumOperands() - {0}) / {1};
+ int variadicOperandSize = ({4} - {0}) / {1};
int offset = {2} + variadicOperandSize * {3};
return )";
auto sizeAndOffset =
formatv(code, numNormalOperands, numVariadicOperands,
- emittedNormalOperands, emittedVariadicOperands);
+ emittedNormalOperands, emittedVariadicOperands, rangeSizeCall);
if (operand.isVariadic()) {
- auto &m = opClass.newMethod("Operation::operand_range", operand.name);
- m.body() << sizeAndOffset
- << "{std::next(operand_begin(), offset), "
- "std::next(operand_begin(), offset + variadicOperandSize)};";
+ auto &m = opClass.newMethod(rangeType, operand.name);
+ m.body() << sizeAndOffset << "{std::next(" << rangeBeginCall
+ << ", offset), std::next(" << rangeBeginCall
+ << ", offset + variadicOperandSize)};";
++emittedVariadicOperands;
} else {
auto &m = opClass.newMethod("Value *", operand.name);
- m.body() << sizeAndOffset << "this->getOperand(offset);";
+ m.body() << sizeAndOffset
+ << formatv(getOperandCallPattern.data(), "offset") << ";";
++emittedNormalOperands;
}
}
}
+void OpEmitter::genNamedOperandGetters() {
+ generateNamedOperandGetters(
+ op, opClass, /*rangeType=*/"Operation::operand_range",
+ /*rangeBeginCall=*/"operand_begin()",
+ /*rangeSizeCall=*/"this->getNumOperands()",
+ /*getOperandCallPattern=*/"this->getOperation()->getOperand({0})");
+}
+
void OpEmitter::genNamedResultGetters() {
const int numResults = op.getNumResults();
const int numVariadicResults = op.getNumVariadicResults();
method.body() << " return \"" << op.getOperationName() << "\";\n";
}
+//===----------------------------------------------------------------------===//
+// OpOperandAdaptor emitter
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Helper class to emit Op operand adaptors to an output stream. Operand
+// adaptors are wrappers around ArrayRef<Value *> that provide named operand
+// getters identical to those defined in the Op.
+class OpOperandAdaptorEmitter {
+public:
+ static void emitDecl(const Operator &op, raw_ostream &os);
+ static void emitDef(const Operator &op, raw_ostream &os);
+
+private:
+ explicit OpOperandAdaptorEmitter(const Operator &op);
+
+ Class adapterClass;
+};
+} // end namespace
+
+OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
+ : adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
+ adapterClass.newField("ArrayRef<Value *>", "tblgen_operands");
+ auto &constructor = adapterClass.newConstructor("ArrayRef<Value *> values");
+ constructor.body() << " tblgen_operands = values;\n";
+
+ generateNamedOperandGetters(op, adapterClass,
+ /*rangeType=*/"ArrayRef<Value *>",
+ /*rangeBeginCall=*/"tblgen_operands.begin()",
+ /*rangeSizeCall=*/"tblgen_operands.size()",
+ /*getOperandCallPattern=*/"tblgen_operands[{0}]");
+}
+
+void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
+ OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os);
+}
+
+void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
+ OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os);
+}
+
// Emits the opcode enum and op classes.
static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
bool emitDecl) {
IfDefScope scope("GET_OP_CLASSES", os);
for (auto *def : defs) {
+ Operator op(*def);
if (emitDecl) {
- OpEmitter::emitDecl(*def, os);
+ os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
+ OpOperandAdaptorEmitter::emitDecl(op, os);
+ OpEmitter::emitDecl(op, os);
} else {
- OpEmitter::emitDef(*def, os);
+ os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
+ OpOperandAdaptorEmitter::emitDef(op, os);
+ OpEmitter::emitDef(op, os);
}
}
}