/// `memref<42x?x42xi32>` is converted to `{i32*, i64}` (only one size is
/// dynamic); `memref<2x3x4xf64>` is converted to `{double*}`.
llvm::StructType *convertMemRefType(MemRefType type);
+
+ /// Convert a 1D vector type to an LLVM vector type.
+ llvm::VectorType *convertVectorType(VectorType type);
/// \}
/// Convert a list of types to an LLVM type suitable for being returned from a
/// instruction) on success and nullptr on error.
llvm::Value *emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp);
+ /// Emit a constant splat operation, i.e. an operation that broadcasts a
+ /// single value to a vector. The `op` must have an attribute `value` of
+ /// SplatElementsAttr type. Return an LLVM SSA value of the constant vector;
+ /// return `nullptr` in case of errors.
+ llvm::Value *emitConstantSplat(const ConstantOp &op);
+
/// Create a single LLVM value of struct type that includes the list of
/// given MLIR values. The `values` list must contain at least 2 elements.
llvm::Value *packValues(ArrayRef<const SSAValue *> values);
return builder.getHalfTy();
case Type::Kind::BF16:
return context->emitError(UnknownLoc::get(context),
- "Unsupported type: BF16"),
+ "unsupported type: BF16"),
nullptr;
default:
llvm_unreachable("non-float type in convertFloatType");
return llvm::StructType::get(llvmContext, types);
}
+// Convert a 1D vector type to an LLVM vector type.
+llvm::VectorType *ModuleLowerer::convertVectorType(VectorType type) {
+ if (type.getRank() != 1) {
+ MLIRContext *context = type.getContext();
+ context->emitError(UnknownLoc::get(context),
+ "only 1D vectors are supported");
+ return nullptr;
+ }
+
+ llvm::Type *elementType = convertType(type.getElementType());
+ if (!elementType) {
+ return nullptr;
+ }
+
+ return llvm::VectorType::get(elementType, type.getShape().front());
+}
+
llvm::Type *ModuleLowerer::convertType(Type type) {
if (auto funcType = type.dyn_cast<FunctionType>())
return convertFunctionType(funcType);
return convertIndexType(indexType);
if (auto memRefType = type.dyn_cast<MemRefType>())
return convertMemRefType(memRefType);
+ if (auto vectorType = type.dyn_cast<VectorType>())
+ return convertVectorType(vectorType);
MLIRContext *context = type.getContext();
std::string message;
return builder.CreateCall(freeFunc, data);
}
+// Return an LLVM constant of the `float` type for the given APvalue.
+// This forcibly recreates the APFloat with IEEESingle semantics to make sure
+// LLVM constructs a `float` constant.
+static llvm::ConstantFP *getFloatConstant(APFloat APvalue,
+ const Operation &inst,
+ llvm::LLVMContext *context) {
+ bool unused;
+ APFloat::opStatus status = APvalue.convert(
+ llvm::APFloat::IEEEsingle(), llvm::APFloat::rmTowardZero, &unused);
+ if (status == APFloat::opInexact) {
+ inst.emitWarning("lossy conversion of a float constant to the float type");
+ // No return intended.
+ }
+ if (status != APFloat::opOK)
+ return inst.emitError("failed to convert a floating point constant"),
+ nullptr;
+ auto value = APvalue.convertToFloat();
+ return llvm::ConstantFP::get(*context, APFloat(value));
+}
+
+llvm::Value *ModuleLowerer::emitConstantSplat(const ConstantOp &op) {
+ auto splatAttr = op.getValue().dyn_cast<SplatElementsAttr>();
+ assert(splatAttr && "expected a splat constant");
+
+ auto floatAttr = splatAttr.getValue().dyn_cast<FloatAttr>();
+ if (!floatAttr)
+ return op.emitError("NYI: only float splats are currently supported"),
+ nullptr;
+
+ llvm::Constant *cst =
+ getFloatConstant(floatAttr.getValue(), *op.getOperation(), &llvmContext);
+ if (!cst)
+ return nullptr;
+
+ auto nElements = op.getType().cast<VectorType>().getShape()[0];
+ return llvm::ConstantVector::getSplat(nElements, cst);
+}
+
// Create an undef struct value and insert individual values into it.
llvm::Value *ModuleLowerer::packValues(ArrayRef<const SSAValue *> values) {
assert(values.size() > 1 && "cannot pack less than 2 values");
// type of the constant. This should be fixed at the parser level.
if (!type->isFloatTy())
return inst.emitError("NYI: only floats are currently supported");
- bool unused;
+
auto APvalue = constantOp->getValue();
- APFloat::opStatus status = APvalue.convert(
- llvm::APFloat::IEEEsingle(), llvm::APFloat::rmTowardZero, &unused);
- if (status == APFloat::opInexact) {
- inst.emitWarning(
- "Lossy conversion of a float constant to the float type");
- // No return intended.
- }
- if (status != APFloat::opOK)
- return inst.emitError("Failed to convert a floating point constant");
- auto value = APvalue.convertToFloat();
- valueMapping[constantOp->getResult()] =
- llvm::ConstantFP::get(type->getContext(), llvm::APFloat(value));
+ auto llvmValue = getFloatConstant(APvalue, inst, &type->getContext());
+ if (!llvmValue)
+ return true;
+
+ valueMapping[constantOp->getResult()] = llvmValue;
return false;
}
- if (auto constantOp = inst.dyn_cast<ConstantOp>()) {
+ if (auto constantOp = inst.dyn_cast<ConstantIntOp>()) {
llvm::Type *type = convertType(constantOp->getType());
if (!type)
return true;
- if (!isa<llvm::IntegerType>(type))
- return inst.emitError("only integer types are supported");
- auto attr = (constantOp->getValue()).cast<IntegerAttr>();
+
// Create a new APInt even if we can extract one from the attribute, because
// attributes are currently hardcoded to be 64-bit APInts and LLVM will
// create an i64 constant from those.
+ auto value = constantOp->getValue();
valueMapping[constantOp->getResult()] = llvm::Constant::getIntegerValue(
- type, llvm::APInt(type->getIntegerBitWidth(), attr.getInt()));
+ type, APInt(type->getIntegerBitWidth(), value));
+ return false;
+ }
+ if (auto constantOp = inst.dyn_cast<ConstantOp>()) {
+ llvm::Type *type = convertType(constantOp->getType());
+ if (!type)
+ return true;
+ if (!isa<llvm::VectorType>(type))
+ return inst.emitError("unsupported constant type");
+ auto constantValue = constantOp->getValue();
+ if (!constantValue.isa<SplatElementsAttr>())
+ return inst.emitError("NYI: non-splat vector constants");
+
+ llvm::Value *llvmValue = emitConstantSplat(*constantOp);
+ if (!llvmValue)
+ return true;
+ valueMapping[constantOp->getResult()] = llvmValue;
return false;
}