Syntax:
``` {.ebnf}
-function-attribute ::= function-id `:` function-type
+function-attribute ::= function-id
```
-A function attribute is a literal attribute that represents a reference to the
-given function object.
+A function attribute is a literal attribute that represents a named reference to
+the given function.
#### String Attribute
friend StorageUniquer;
public:
- /// Returns if the derived attribute is or contains a function pointer.
- bool isOrContainsFunctionCache() const {
- return typeAndContainsFunctionAttrPair.getInt();
- }
-
/// Get the type of this attribute.
Type getType() const;
}
protected:
- /// Construct a new attribute storage instance with the given type and a
- /// boolean that signals if the derived attribute is or contains a function
- /// pointer.
+ /// Construct a new attribute storage instance with the given type.
/// Note: All attributes require a valid type. If no type is provided here,
/// the type of the attribute will automatically default to NoneType
/// upon initialization in the uniquer.
- AttributeStorage(Type type, bool isOrContainsFunctionCache = false);
- AttributeStorage(bool isOrContainsFunctionCache);
+ AttributeStorage(Type type);
AttributeStorage();
/// Set the type of this attribute.
/// The dialect for this attribute.
Dialect *dialect;
- /// This field is a pair of:
- /// - The type of the attribute value.
- /// - A boolean that is true if this is, or contains, a function attribute.
- llvm::PointerIntPair<const void *, 1, bool> typeAndContainsFunctionAttrPair;
+ /// The opaque type of the attribute value.
+ const void *type;
};
/// Default storage type for attributes that require no additional
getInitFn(ctx, T::getClassID()), kind, std::forward<Args>(args)...);
}
- /// Erase a uniqued instance of attribute T.
- template <typename T, typename... Args>
- static void erase(MLIRContext *ctx, unsigned kind, Args &&... args) {
- return ctx->getAttributeUniquer().erase<typename T::ImplType>(
- kind, std::forward<Args>(args)...);
- }
-
private:
/// Returns a functor used to initialize new attribute storage instances.
static std::function<void(AttributeStorage *)>
class AffineMap;
class Dialect;
class Function;
-class FunctionAttr;
class FunctionType;
class Identifier;
class IntegerSet;
struct AffineMapAttributeStorage;
struct IntegerSetAttributeStorage;
struct TypeAttributeStorage;
-struct FunctionAttributeStorage;
struct SplatElementsAttributeStorage;
struct DenseElementsAttributeStorage;
struct DenseIntElementsAttributeStorage;
/// Get the dialect this attribute is registered to.
Dialect &getDialect() const;
- /// Return true if this field is, or contains, a function attribute.
- bool isOrContainsFunction() const;
-
- /// Replace a function attribute or function attributes nested in an array
- /// attribute with another function attribute as defined by the provided
- /// remapping table. Return the original attribute if it (or any of nested
- /// attributes) is not present in the table.
- Attribute remapFunctionAttrs(
- const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable) const;
-
/// Print the attribute.
void print(raw_ostream &os) const;
void dump() const;
};
/// A function attribute represents a reference to a function object.
-///
-/// When working with IR, it is important to know that a function attribute can
-/// exist with a null Function inside of it, which occurs when a function object
-/// is deleted that had an attribute which referenced it. No references to this
-/// attribute should persist across the transformation, but that attribute will
-/// remain in MLIRContext.
class FunctionAttr
: public Attribute::AttrBase<FunctionAttr, Attribute,
- detail::FunctionAttributeStorage> {
+ detail::StringAttributeStorage> {
public:
using Base::Base;
using ValueType = Function *;
static FunctionAttr get(Function *value);
+ static FunctionAttr get(StringRef value, MLIRContext *ctx);
- Function *getValue() const;
-
- FunctionType getType() const;
+ /// Returns the name of the held function reference.
+ StringRef getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::Function;
}
- /// This function is used by the internals of the Function class to null out
- /// attributes referring to functions that are about to be deleted.
- static void dropFunctionReference(Function *value);
-
/// This function is used by the internals of the Function class to update the
/// type of the function attribute for 'value'.
static void resetType(Function *value);
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
TypeAttr getTypeAttr(Type type);
FunctionAttr getFunctionAttr(Function *value);
+ FunctionAttr getFunctionAttr(StringRef value);
ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef<char> data);
ElementsAttr getDenseElementsAttr(ShapedType type,
ArrayRef<NamedAttribute> attrs,
ArrayRef<NamedAttributeList> argAttrs);
- ~Function();
-
/// The source location the function was defined or derived from.
Location getLoc() { return location; }
void setType(FunctionType newType) {
type = newType;
argAttrs.resize(type.getNumInputs());
- FunctionAttr::resetType(this);
}
MLIRContext *getContext();
def FunctionAttr : Attr<CPred<"$_self.isa<FunctionAttr>()">,
"function attribute"> {
let storageType = [{ FunctionAttr }];
- let returnType = [{ Function * }];
+ let returnType = [{ StringRef }];
let constBuilderCall = "$_builder.getFunctionAttr($0)";
}
return success();
}
- /// Resolve a parse function name and a type into a function reference.
- virtual ParseResult resolveFunctionName(StringRef name, FunctionType type,
- llvm::SMLoc loc,
- Function *&result) = 0;
-
/// Emit a diagnostic at the specified location and return failure.
virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
const Twine &message = {}) = 0;
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addTypes(callee->getType().getResults());
+ }]>, OpBuilder<
+ "Builder *builder, OperationState *result, StringRef callee,"
+ "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
+ result->addOperands(operands);
+ result->addAttribute("callee", builder->getFunctionAttr(callee));
+ result->addTypes(results);
}]>];
let extraClassDeclaration = [{
- Function *getCallee() {
- return getAttrOfType<FunctionAttr>("callee").getValue();
- }
+ Function *getCallee();
/// Get the argument operands to the called function.
operand_range getArgOperands() {
std::unique_ptr<llvm::Module> llvmModule;
// Mappings between original and translated values, used for lookups.
- llvm::DenseMap<Function *, llvm::Function *> functionMapping;
+ llvm::StringMap<llvm::Function *> functionMapping;
llvm::DenseMap<Value *, llvm::Value *> valueMapping;
llvm::DenseMap<Block *, llvm::BasicBlock *> blockMapping;
};
void createAffineComputationSlice(Operation *opInst,
SmallVectorImpl<AffineApplyOp> *sliceOps);
-/// Replaces (potentially nested) function attributes in the operation "op"
-/// with those specified in "remappingTable".
-void remapFunctionAttrs(
- Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable);
-
-/// Replaces (potentially nested) function attributes all operations of the
-/// Function "fn" with those specified in "remappingTable".
-void remapFunctionAttrs(
- Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable);
-
-/// Replaces (potentially nested) function attributes in the entire module
-/// with those specified in "remappingTable". Ignores external functions.
-void remapFunctionAttrs(
- Module &module, const DenseMap<Attribute, FunctionAttr> &remappingTable);
-
} // end namespace mlir
#endif // MLIR_TRANSFORMS_UTILS_H
return fn.getContext()->getRegisteredDialect(dialectNamePair.first);
}
- template <typename ErrorContext>
- LogicalResult verifyAttribute(Attribute attr, ErrorContext &ctx) {
- if (!attr.isOrContainsFunction())
- return success();
-
- // If we have a function attribute, check that it is non-null and in the
- // same module as the operation that refers to it.
- if (auto fnAttr = attr.dyn_cast<FunctionAttr>()) {
- if (!fnAttr.getValue())
- return failure("attribute refers to deallocated function!", ctx);
-
- if (fnAttr.getValue()->getModule() != fn.getModule())
- return failure("attribute refers to function '" +
- Twine(fnAttr.getValue()->getName()) +
- "' defined in another module!",
- ctx);
- return success();
- }
-
- // Otherwise, we must have an array attribute, remap the elements.
- for (auto elt : attr.cast<ArrayAttr>().getValue())
- if (failed(verifyAttribute(elt, ctx)))
- return failure();
-
- return success();
- }
-
LogicalResult verify();
LogicalResult verifyBlock(Block &block, bool isTopLevel);
LogicalResult verifyOperation(Operation &op);
if (!identifierRegex.match(attr.first))
return failure("invalid attribute name '" + attr.first.strref() + "'",
fn);
- if (failed(verifyAttribute(attr.second, fn)))
- return failure();
/// Check that the attribute is a dialect attribute, i.e. contains a '.' for
/// the namespace.
llvm::formatv("invalid attribute name '{0}' on argument {1}",
attr.first.strref(), i),
fn);
- if (failed(verifyAttribute(attr.second, fn)))
- return failure();
/// Check that the attribute is a dialect attribute, i.e. contains a '.'
/// for the namespace.
if (!identifierRegex.match(attr.first))
return failure("invalid attribute name '" + attr.first.strref() + "'",
op);
- if (failed(verifyAttribute(attr.second, op)))
- return failure();
// Check for any optional dialect specific attributes.
if (!attr.first.strref().contains('.'))
#include "mlir/GPU/GPUDialect.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
}
Function *LaunchFuncOp::kernel() {
- return this->getAttr(getKernelAttrName()).cast<FunctionAttr>().getValue();
+ auto kernelAttr = getAttr(getKernelAttrName()).cast<FunctionAttr>();
+ return getOperation()->getFunction()->getModule()->getNamedFunction(
+ kernelAttr.getValue());
}
unsigned LaunchFuncOp::getNumKernelOperands() {
} else if (!kernelAttr.isa<FunctionAttr>()) {
return emitOpError("attribute 'kernel' must be a function");
}
+
Function *kernelFunc = this->kernel();
+ if (!kernelFunc)
+ return emitError() << "kernel function '" << kernelAttr << "' is undefined";
+
if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
GPUDialect::getKernelFuncAttrName())) {
return emitError("kernel function is missing the '")
printType(attr.cast<TypeAttr>().getValue());
break;
case StandardAttributes::Function: {
- auto *function = attr.cast<FunctionAttr>().getValue();
- if (!function) {
- os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
- } else {
- printFunctionReference(function);
- os << " : ";
- printType(function->getType());
- }
+ os << '@' << attr.cast<FunctionAttr>().getValue();
break;
}
case StandardAttributes::OpaqueElements: {
} else {
specialName << 'c' << intCst.getInt() << '_' << type;
}
- } else if (cst.isa<FunctionAttr>()) {
+ } else if (type.isa<FunctionType>()) {
specialName << 'f';
} else {
specialName << "cst";
struct ArrayAttributeStorage : public AttributeStorage {
using KeyTy = ArrayRef<Attribute>;
- ArrayAttributeStorage(bool hasFunctionAttr, ArrayRef<Attribute> value)
- : AttributeStorage(hasFunctionAttr), value(value) {}
+ ArrayAttributeStorage(ArrayRef<Attribute> value) : value(value) {}
/// Key equality function.
bool operator==(const KeyTy &key) const { return key == value; }
/// Construct a new storage instance.
static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator,
const KeyTy &key) {
- // Check to see if any of the elements have a function attr.
- bool hasFunctionAttr = llvm::any_of(
- key, [](Attribute elt) { return elt.isOrContainsFunction(); });
-
- // Initialize the memory using placement new.
return new (allocator.allocate<ArrayAttributeStorage>())
- ArrayAttributeStorage(hasFunctionAttr, allocator.copyInto(key));
+ ArrayAttributeStorage(allocator.copyInto(key));
}
ArrayRef<Attribute> value;
Type value;
};
-/// An attribute representing a reference to a function.
-struct FunctionAttributeStorage : public AttributeStorage {
- using KeyTy = Function *;
-
- FunctionAttributeStorage(Function *value)
- : AttributeStorage(value->getType(), /*isOrContainsFunctionCache=*/true),
- value(value) {}
-
- /// Key equality function.
- bool operator==(const KeyTy &key) const { return key == value; }
-
- /// Construct a new storage instance.
- static FunctionAttributeStorage *
- construct(AttributeStorageAllocator &allocator, KeyTy key) {
- return new (allocator.allocate<FunctionAttributeStorage>())
- FunctionAttributeStorage(key);
- }
-
- /// Storage cleanup function.
- void cleanup() {
- // Null out the function reference in the attribute to avoid dangling
- // pointers.
- value = nullptr;
- }
-
- /// Reset the type of this attribute to the type of the held function.
- void resetType() { setType(value->getType()); }
-
- Function *value;
-};
-
/// An attribute representing a reference to a vector or tensor constant,
/// inwhich all elements have the same value.
struct SplatElementsAttributeStorage : public AttributeStorage {
// AttributeStorage
//===----------------------------------------------------------------------===//
-AttributeStorage::AttributeStorage(Type type, bool isOrContainsFunctionCache)
- : typeAndContainsFunctionAttrPair(type.getAsOpaquePointer(),
- isOrContainsFunctionCache) {}
-AttributeStorage::AttributeStorage(bool isOrContainsFunctionCache)
- : AttributeStorage(/*type=*/nullptr, isOrContainsFunctionCache) {}
-AttributeStorage::AttributeStorage()
- : AttributeStorage(/*type=*/nullptr, /*isOrContainsFunctionCache=*/false) {}
+AttributeStorage::AttributeStorage(Type type)
+ : type(type.getAsOpaquePointer()) {}
+AttributeStorage::AttributeStorage() : type(nullptr) {}
Type AttributeStorage::getType() const {
- return Type::getFromOpaquePointer(
- typeAndContainsFunctionAttrPair.getPointer());
+ return Type::getFromOpaquePointer(type);
}
-void AttributeStorage::setType(Type type) {
- typeAndContainsFunctionAttrPair.setPointer(type.getAsOpaquePointer());
+void AttributeStorage::setType(Type newType) {
+ type = newType.getAsOpaquePointer();
}
//===----------------------------------------------------------------------===//
/// Get the dialect this attribute is registered to.
Dialect &Attribute::getDialect() const { return impl->getDialect(); }
-bool Attribute::isOrContainsFunction() const {
- return impl->isOrContainsFunctionCache();
-}
-
-// Given an attribute that could refer to a function attribute in the remapping
-// table, walk it and rewrite it to use the mapped function. If it doesn't
-// refer to anything in the table, then it is returned unmodified.
-Attribute Attribute::remapFunctionAttrs(
- const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable) const {
- // Most attributes are trivially unrelated to function attributes, skip them
- // rapidly.
- if (!isOrContainsFunction())
- return *this;
-
- // If we have a function attribute, remap it.
- if (auto fnAttr = this->dyn_cast<FunctionAttr>()) {
- auto it = remappingTable.find(fnAttr);
- return it != remappingTable.end() ? it->second : *this;
- }
-
- // Otherwise, we must have an array attribute, remap the elements.
- auto arrayAttr = this->cast<ArrayAttr>();
- SmallVector<Attribute, 8> remappedElts;
- bool anyChange = false;
- for (auto elt : arrayAttr.getValue()) {
- auto newElt = elt.remapFunctionAttrs(remappingTable);
- remappedElts.push_back(newElt);
- anyChange |= (elt != newElt);
- }
-
- if (!anyChange)
- return *this;
-
- return ArrayAttr::get(remappedElts, getContext());
-}
-
//===----------------------------------------------------------------------===//
// OpaqueAttr
//===----------------------------------------------------------------------===//
FunctionAttr FunctionAttr::get(Function *value) {
assert(value && "Cannot get FunctionAttr for a null function");
- return Base::get(value->getContext(), StandardAttributes::Function, value);
+ return get(value->getName(), value->getContext());
}
-/// This function is used by the internals of the Function class to null out
-/// attributes referring to functions that are about to be deleted.
-void FunctionAttr::dropFunctionReference(Function *value) {
- AttributeUniquer::erase<FunctionAttr>(value->getContext(),
- StandardAttributes::Function, value);
+FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) {
+ return Base::get(ctx, StandardAttributes::Function, value);
}
-/// This function is used by the internals of the Function class to update the
-/// type of the attribute for 'value'.
-void FunctionAttr::resetType(Function *value) {
- FunctionAttr::get(value).getImpl()->resetType();
-}
-
-Function *FunctionAttr::getValue() const { return getImpl()->value; }
-
-FunctionType FunctionAttr::getType() const {
- return Attribute::getType().cast<FunctionType>();
-}
+StringRef FunctionAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
// ElementsAttr
FunctionAttr Builder::getFunctionAttr(Function *value) {
return FunctionAttr::get(value);
}
+FunctionAttr Builder::getFunctionAttr(StringRef value) {
+ return FunctionAttr::get(value, getContext());
+}
ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) {
return SplatElementsAttr::get(type, elt);
: name(Identifier::get(name, type.getContext())), location(location),
type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
-Function::~Function() {
- // Clean up function attributes referring to this function.
- FunctionAttr::dropFunctionReference(this);
-}
-
MLIRContext *Function::getContext() { return getType().getContext(); }
/// Swap the name of the given function with this one.
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/AsmParser/Parser.h"
// callee (first operand) otherwise.
*p << op.getOperationName() << ' ';
if (isDirect)
- *p << '@' << callee.getValue()->getName().strref();
+ *p << '@' << callee.getValue();
else
*p << *op.getOperand(0);
*p << '(';
- p->printOperands(std::next(op.operand_begin(), callee.hasValue() ? 0 : 1),
- op.operand_end());
+ p->printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1));
*p << ')';
p->printOptionalAttrDict(op.getAttrs(), {"callee"});
- if (isDirect) {
- *p << " : " << callee.getValue()->getType();
- return;
- }
-
- // Reconstruct the function MLIR function type from LLVM function type,
- // and print it.
- auto operandType = op.getOperand(0)->getType().cast<LLVM::LLVMType>();
- auto *llvmPtrType =
- dyn_cast<llvm::PointerType>(operandType.getUnderlyingType());
- assert(llvmPtrType &&
- "operand #0 must have LLVM pointer type for indirect calls");
- auto *llvmType = dyn_cast<llvm::FunctionType>(llvmPtrType->getElementType());
- assert(llvmType &&
- "operand #0 must have LLVM Function pointer type for indirect calls");
-
- auto *llvmResultType = llvmType->getReturnType();
- SmallVector<Type, 1> resultTypes;
- if (!llvmResultType->isVoidTy())
- resultTypes.push_back(LLVM::LLVMType::get(op.getContext(), llvmResultType));
-
+ // Reconstruct the function MLIR function type from operand and result types.
+ SmallVector<Type, 1> resultTypes(op.getOperation()->getResultTypes());
SmallVector<Type, 8> argTypes;
- argTypes.reserve(llvmType->getNumParams());
- for (int i = 0, e = llvmType->getNumParams(); i < e; ++i)
- argTypes.push_back(
- LLVM::LLVMType::get(op.getContext(), llvmType->getParamType(i)));
+ argTypes.reserve(op.getNumOperands());
+ for (auto *operand : llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1))
+ argTypes.push_back(operand->getType());
*p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
}
return parser->emitError(trailingTypeLoc, "expected function type");
if (isDirect) {
// Add the direct callee as an Op attribute.
- Function *func;
- if (parser->resolveFunctionName(calleeName, funcType, calleeLoc, func))
- return failure();
- auto funcAttr = parser->getBuilder().getFunctionAttr(func);
+ auto funcAttr = parser->getBuilder().getFunctionAttr(calleeName);
attrs.push_back(parser->getBuilder().getNamedAttr("callee", funcAttr));
// Make sure types match.
: context(module->getContext()), module(module), lex(sourceMgr, context),
curToken(lex.lexToken()) {}
- ~ParserState() {
- // Destroy the forward references upon error.
- for (auto forwardRef : functionForwardRefs)
- delete forwardRef.second;
- functionForwardRefs.clear();
- }
-
// A map from attribute alias identifier to Attribute.
llvm::StringMap<Attribute> attributeAliasDefinitions;
// A map from type alias identifier to Type.
llvm::StringMap<Type> typeAliasDefinitions;
- // This keeps track of all forward references to functions along with the
- // temporary function used to represent them.
- llvm::DenseMap<Identifier, Function *> functionForwardRefs;
-
private:
ParserState(const ParserState &) = delete;
void operator=(const ParserState &) = delete;
ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
// Attribute parsing.
- Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
- FunctionType type);
Attribute parseExtendedAttribute(Type type);
Attribute parseAttribute(Type type = {});
return success();
}
-/// Given a parsed reference to a function name like @foo and a type that it
-/// corresponds to, resolve it to a concrete function object (possibly
-/// synthesizing a forward reference) or emit an error and return null on
-/// failure.
-Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
- FunctionType type) {
- Identifier name = builder.getIdentifier(nameStr.drop_front());
-
- // See if the function has already been defined in the module.
- Function *function = getModule()->getNamedFunction(name);
-
- // If not, get or create a forward reference to one.
- if (!function) {
- auto &entry = state.functionForwardRefs[name];
- if (!entry)
- entry = new Function(getEncodedSourceLocation(nameLoc), name, type,
- /*attrs=*/{});
- function = entry;
- }
-
- if (function->getType() != type)
- return (emitError(nameLoc, "reference to function with mismatched type"),
- nullptr);
- return function;
-}
-
/// Parse an extended attribute.
///
/// extended-attribute ::= (dialect-attribute | attribute-alias)
}
case Token::at_identifier: {
- auto nameLoc = getToken().getLoc();
auto nameStr = getTokenSpelling();
consumeToken(Token::at_identifier);
-
- if (parseToken(Token::colon, "expected ':' and function type"))
- return nullptr;
- auto typeLoc = getToken().getLoc();
- Type type = parseType();
- if (!type)
- return nullptr;
- auto fnType = type.dyn_cast<FunctionType>();
- if (!fnType)
- return (emitError(typeLoc, "expected function type"), nullptr);
-
- auto *function = resolveFunctionReference(nameStr, nameLoc, fnType);
- return function ? builder.getFunctionAttr(function) : nullptr;
+ return builder.getFunctionAttr(nameStr.drop_front());
}
case Token::kw_opaque: {
consumeToken(Token::kw_opaque);
if (parser.getToken().isNot(Token::at_identifier))
return failure();
- result = parser.getTokenSpelling();
+ result = parser.getTokenSpelling().drop_front();
parser.consumeToken(Token::at_identifier);
return success();
}
return success();
}
- /// Resolve a parse function name and a type into a function reference.
- ParseResult resolveFunctionName(StringRef name, FunctionType type,
- llvm::SMLoc loc, Function *&result) override {
- result = parser.resolveFunctionReference(name, loc, type);
- return failure(result == nullptr);
- }
-
/// Parse a region that takes `arguments` of `argTypes` types. This
/// effectively defines the SSA values of `arguments` and assignes their type.
ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments,
ParseResult parseModule();
private:
- ParseResult finalizeModule();
-
ParseResult parseAttributeAliasDef();
ParseResult parseTypeAliasDef();
return parser.parseFunctionBody(hadNamedArguments);
}
-/// Finish the end of module parsing - when the result is valid, do final
-/// checking.
-ParseResult ModuleParser::finalizeModule() {
- // Resolve all forward references, building a remapping table of attributes.
- DenseMap<Attribute, FunctionAttr> remappingTable;
- for (auto forwardRef : getState().functionForwardRefs) {
- auto name = forwardRef.first;
-
- // Resolve the reference.
- auto *resolvedFunction = getModule()->getNamedFunction(name);
- if (!resolvedFunction) {
- forwardRef.second->emitError("reference to undefined function '")
- << name << "'";
- return failure();
- }
-
- remappingTable[builder.getFunctionAttr(forwardRef.second)] =
- builder.getFunctionAttr(resolvedFunction);
- }
-
- // If there was nothing to remap, then we're done.
- if (remappingTable.empty())
- return success();
-
- // Otherwise, walk the entire module replacing uses of one attribute set
- // with the correct ones.
- remapFunctionAttrs(*getModule(), remappingTable);
-
- // Now that all references to the forward definition placeholders are
- // resolved, we can deallocate the placeholders.
- for (auto forwardRef : getState().functionForwardRefs)
- delete forwardRef.second;
- getState().functionForwardRefs.clear();
- return success();
-}
-
/// This is the top-level module parser.
ParseResult ModuleParser::parseModule() {
while (1) {
// If we got to the end of the file, then we're done.
case Token::eof:
- return finalizeModule();
+ return success();
// If we got an error token, then the lexer already emitted an error, just
// stop. Someday we could introduce error recovery if there was demand
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
llvm::SMLoc calleeLoc;
FunctionType calleeType;
SmallVector<OpAsmParser::OperandType, 4> operands;
- Function *callee = nullptr;
if (parser->parseFunctionName(calleeName, calleeLoc) ||
parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
- parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
parser->addTypesToList(calleeType.getResults(), result->types) ||
parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
result->operands))
return failure();
- result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
+ result->addAttribute("callee",
+ parser->getBuilder().getFunctionAttr(calleeName));
return success();
}
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' function attribute");
+ auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
+ fnAttr.getValue());
+ if (!fn)
+ return op.emitOpError() << "'" << fnAttr.getValue()
+ << "' does not reference a valid function";
// Verify that the operand and result types match the callee.
- auto fnType = fnAttr.getValue()->getType();
+ auto fnType = fn->getType();
if (fnType.getNumInputs() != op.getNumOperands())
return op.emitOpError("incorrect number of operands for callee");
return success();
}
+Function *CallOp::getCallee() {
+ auto name = getAttrOfType<FunctionAttr>("callee").getValue();
+ return getOperation()->getFunction()->getModule()->getNamedFunction(name);
+}
+
//===----------------------------------------------------------------------===//
// CallIndirectOp
//===----------------------------------------------------------------------===//
return matchFailure();
// Replace with a direct call.
+ SmallVector<Type, 8> callResults(op->getResultTypes());
SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
- rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callOperands);
+ rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callResults,
+ callOperands);
return matchSuccess();
}
};
if (op.getAttrs().size() > 1)
*p << ' ';
p->printAttributeAndType(op.getValue());
+
+ // If the value is a function, print a trailing type.
+ if (op.getValue().isa<FunctionAttr>()) {
+ *p << " : ";
+ p->printType(op.getType());
+ }
}
static ParseResult parseConstantOp(OpAsmParser *parser,
OperationState *result) {
Attribute valueAttr;
- Type type;
-
if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseAttribute(valueAttr, "value", result->attributes))
return failure();
+ // If the attribute is a function, then we expect a trailing type.
+ Type type;
+ if (!valueAttr.isa<FunctionAttr>())
+ type = valueAttr.getType();
+ else if (parser->parseColonType(type))
+ return failure();
+
// Add the attribute type to the list.
- return parser->addTypeToList(valueAttr.getType(), result->types);
+ return parser->addTypeToList(type, result->types);
}
/// The constant op requires an attribute, and furthermore requires that it
return op.emitOpError("requires a 'value' attribute");
auto type = op.getType();
- if (type != value.getType())
+ if (!value.getType().isa<NoneType>() && type != value.getType())
return op.emitOpError() << "requires attribute's type (" << value.getType()
<< ") to match op's return type (" << type << ")";
}
if (type.isa<FunctionType>()) {
- if (!value.isa<FunctionAttr>())
+ auto fnAttr = value.dyn_cast<FunctionAttr>();
+ if (!fnAttr)
return op.emitOpError("requires 'value' to be a function reference");
+
+ // Try to find the referenced function.
+ auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
+ fnAttr.getValue());
+ if (!fn)
+ return op.emitOpError("reference to undefined function 'bar'");
+
+ // Check that the referenced function has the correct type.
+ if (fn->getType() != type)
+ return op.emitOpError("reference to function with mismatched type");
+
return success();
}
// function.
blockMapping.clear();
valueMapping.clear();
- llvm::Function *llvmFunc = functionMapping.lookup(&func);
+ llvm::Function *llvmFunc = functionMapping.lookup(func.getName());
// Add function arguments to the value remapping table.
// If there was noalias info then we decorate each argument accordingly.
unsigned int argIdx = 0;
// Declare all functions first because there may be function calls that form a
// call graph with cycles.
for (Function &function : mlirModule) {
- Function *functionPtr = &function;
mlir::BoolAttr isVarArgsAttr =
function.getAttrOfType<BoolAttr>("std.varargs");
bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
llvm::FunctionCallee llvmFuncCst =
llvmModule->getOrInsertFunction(function.getName(), functionType);
assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
- functionMapping[functionPtr] =
+ functionMapping[function.getName()] =
cast<llvm::Function>(llvmFuncCst.getCallee());
}
opInst->setOperand(idx, newOperands[idx]);
}
}
-
-static void
-remapFunctionAttrs(NamedAttributeList &attrs,
- const DenseMap<Attribute, FunctionAttr> &remappingTable) {
- for (auto attr : attrs.getAttrs()) {
- // Do the remapping, if we got the same thing back, then it must contain
- // functions that aren't getting remapped.
- auto newVal = attr.second.remapFunctionAttrs(remappingTable);
- if (newVal == attr.second)
- continue;
-
- // Otherwise, replace the existing attribute with the new one. It is safe
- // to mutate the attribute list while we walk it because underlying
- // attribute lists are uniqued and immortal.
- attrs.set(attr.first, newVal);
- }
-}
-
-void mlir::remapFunctionAttrs(
- Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
- ::remapFunctionAttrs(op.getAttrList(), remappingTable);
-}
-
-void mlir::remapFunctionAttrs(
- Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
-
- // Remap the attributes of the function.
- ::remapFunctionAttrs(fn.getAttrList(), remappingTable);
-
- // Remap the attributes of the arguments of this function.
- for (auto &attrList : fn.getAllArgAttrs())
- ::remapFunctionAttrs(attrList, remappingTable);
-
- // Look at all operations in a Function.
- fn.walk([&](Operation *op) { remapFunctionAttrs(*op, remappingTable); });
-}
-
-void mlir::remapFunctionAttrs(
- Module &module, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
- for (auto &fn : module)
- remapFunctionAttrs(fn, remappingTable);
-}
func @launch_func_no_function_attribute(%sz : index) {
// expected-error@+1 {{attribute 'kernel' must be a function}}
- "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz) {kernel: "bar"}
+ "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz) {kernel: 10}
+ : (index, index, index, index, index, index) -> ()
+ return
+}
+
+// -----
+
+func @launch_func_undefined_function(%sz : index) {
+ // expected-error@+1 {{kernel function '@kernel_1' is undefined}}
+ "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz) { kernel: @kernel_1 }
: (index, index, index, index, index, index) -> ()
return
}
func @launch_func_missing_kernel_attr(%sz : index, %arg : !llvm<"float*">) {
// expected-error@+1 {{kernel function is missing the 'gpu.kernel' attribute}}
- "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
- {kernel: @kernel_1 : (!llvm<"float*">) -> ()}
+ "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg) {kernel: @kernel_1}
: (index, index, index, index, index, index, !llvm<"float*">) -> ()
return
}
func @launch_func_kernel_operand_size(%sz : index, %arg : !llvm<"float*">) {
// expected-error@+1 {{got 2 kernel operands but expected 1}}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg, %arg)
- {kernel: @kernel_1 : (!llvm<"float*">) -> ()}
+ {kernel: @kernel_1}
: (index, index, index, index, index, index, !llvm<"float*">,
!llvm<"float*">) -> ()
return
func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
// expected-error@+1 {{type of function argument 0 does not match}}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
- {kernel: @kernel_1 : (!llvm<"float*">) -> ()}
+ {kernel: @kernel_1}
: (index, index, index, index, index, index, f32) -> ()
return
}
// CHECK: %c8 = constant 8
%cst = constant 8 : index
- // CHECK: "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %0, %1) {kernel: @kernel_1 : (f32, memref<?xf32, 1>) -> ()} : (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
- "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1)
- {kernel: @kernel_1 : (f32, memref<?xf32, 1>) -> ()}
+ // CHECK: "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %0, %1) {kernel: @kernel_1} : (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
+ "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel: @kernel_1 }
: (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
return
}
// CHECK: %1 = call @return_op(%0) : (i32) -> i32
%y = call @return_op(%x) : (i32) -> i32
// CHECK: %2 = call @return_op(%0) : (i32) -> i32
- %z = "std.call"(%x) {callee: @return_op : (i32) -> i32} : (i32) -> i32
+ %z = "std.call"(%x) {callee: @return_op} : (i32) -> i32
// CHECK: %f = constant @affine_apply : () -> ()
%f = constant @affine_apply : () -> ()
func @constant() {
^bb:
- %x = "std.constant"(){value: "xyz"} : () -> i32 // expected-error {{requires attribute's type (none) to match op's return type (i32)}}
+ %x = "std.constant"(){value: "xyz"} : () -> i32 // expected-error {{requires a result type that aligns with the 'value' attribute}}
return
}
// -----
func @calls(%arg0: i32) {
- %x = call @calls() : () -> i32 // expected-error {{reference to function with mismatched type}}
+ %x = call @calls() : () -> i32 // expected-error {{incorrect number of operands for callee}}
return
}
// -----
-func @func() -> (() -> ())
-func @referer() {
- %f = constant @func : () -> () -> () // expected-error {{reference to function with mismatched type}}
- return
-}
-
-// -----
-
#map1 = (i)[j] -> (i+j)
func @bound_symbol_mismatch(%N : index) {
// CHECK: "foo"() {d: 1.000000e-09 : f64, func: [], i123: 7 : i64, if: "foo"} : () -> ()
"foo"() {if: "foo", func: [], i123: 7, d: 1.e-9} : () -> ()
- // CHECK: "foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> ()
- "foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> ()
+ // CHECK: "foo"() {fn: @attributes, if: @ifinst} : () -> ()
+ "foo"() {fn: @attributes, if: @ifinst} : () -> ()
// CHECK: "foo"() {int: 0 : i42} : () -> ()
"foo"() {int: 0 : i42} : () -> ()
%none_val = "foo.unknown_op"() : () -> none
return
}
-
-// CHECK-LABEL: func @fn_attr_remap
-// CHECK: {some_dialect.arg_attr: @fn_attr_ref : () -> ()}
-func @fn_attr_remap(%arg0: i1 {some_dialect.arg_attr: @fn_attr_ref : () -> ()}) -> ()
- // CHECK-NEXT: {some_dialect.fn_attr: @fn_attr_ref : () -> ()}
- attributes {some_dialect.fn_attr: @fn_attr_ref : () -> ()} {
- return
-}
-
-// CHECK-LABEL: func @fn_attr_ref
-func @fn_attr_ref() -> ()
-
// CHECK-LABEL: func @indirect_const_call(%arg0: !llvm.i32) {
func @indirect_const_call(%arg0: i32) {
-// CHECK-NEXT: %0 = llvm.constant(@body : (!llvm.i32) -> ()) : !llvm<"void (i32)*">
+// CHECK-NEXT: %0 = llvm.constant(@body) : !llvm<"void (i32)*">
%0 = constant @body : (i32) -> ()
// CHECK-NEXT: llvm.call %0(%arg0) : (!llvm.i32) -> ()
call_indirect %0(%arg0) : (i32) -> ()
// CHECK-NEXT: %17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
// CHECK-NEXT: %18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }">
// CHECK-NEXT: %19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }">
-// CHECK-NEXT: %20 = llvm.constant(@foo : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">) : !llvm<"{ i32, double, i32 } (i32)*">
+// CHECK-NEXT: %20 = llvm.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
// CHECK-NEXT: %21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
%17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
%18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }">
%19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }">
- %20 = llvm.constant(@foo : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">) : !llvm<"{ i32, double, i32 } (i32)*">
+ %20 = llvm.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
%21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
// CHECK-LABEL: define void @indirect_const_call(i64) {
func @indirect_const_call(%arg0: !llvm.i64) {
// CHECK-NEXT: call void @body(i64 %0)
- %0 = llvm.constant(@body : (!llvm.i64) -> ()) : !llvm<"void (i64)*">
+ %0 = llvm.constant(@body) : !llvm<"void (i64)*">
llvm.call %0(%arg0) : (!llvm.i64) -> ()
// CHECK-NEXT: ret void
llvm.return
// CHECK: APFloat BOp::f64_attr()
// CHECK: StringRef BOp::str_attr()
// CHECK: ElementsAttr BOp::elements_attr()
-// CHECK: Function *BOp::function_attr()
+// CHECK: StringRef BOp::function_attr()
// CHECK: SomeType BOp::type_attr()
// CHECK: ArrayAttr BOp::array_attr()
// CHECK: ArrayAttr BOp::some_attr_array()