}
Type TypeConverter::convertFloatType(FloatType type) {
- MLIRContext *context = type.getContext();
switch (type.getKind()) {
case mlir::StandardTypes::F32:
return wrap(builder.getFloatTy());
case mlir::StandardTypes::F16:
return wrap(builder.getHalfTy());
case mlir::StandardTypes::BF16:
- return context->emitError(UnknownLoc::get(context),
- "unsupported type: BF16"),
+ return mlirContext->emitError(UnknownLoc::get(mlirContext),
+ "unsupported type: BF16"),
Type();
default:
llvm_unreachable("non-float type in convertFloatType");
// If function does not return anything, return immediately.
if (type.getNumResults() == 0)
- return FunctionType::get(argTypes, {}, type.getContext());
+ return FunctionType::get(argTypes, {}, mlirContext);
// Otherwise pack the result types into a struct.
if (auto result = getPackedResultType(type.getResults()))
- return FunctionType::get(argTypes, {result}, type.getContext());
+ return FunctionType::get(argTypes, {result}, mlirContext);
return {};
}
// Convert a 1D vector type to an LLVM vector type.
Type TypeConverter::convertVectorType(VectorType type) {
if (type.getRank() != 1) {
- MLIRContext *context = type.getContext();
- context->emitError(UnknownLoc::get(context),
- "only 1D vectors are supported");
+ mlirContext->emitError(UnknownLoc::get(mlirContext),
+ "only 1D vectors are supported");
return {};
}
if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
return llvmType;
- MLIRContext *context = type.getContext();
std::string message;
llvm::raw_string_ostream os(message);
os << "unsupported type: ";
type.print(os);
- context->emitError(UnknownLoc::get(context), os.str());
+ mlirContext->emitError(UnknownLoc::get(mlirContext), os.str());
return {};
}
// TODO: Instead of adding all known patterns from the whole system lazily add
// and cache the canonicalization patterns for ops we see in practice when
// building the worklist. For now, we just grab everything.
- auto *context = func.getContext();
+ auto *context = &getContext();
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(patterns, context);