LLVM IR lowering: support 1D vector operations
authorAlex Zinenko <zinenko@google.com>
Wed, 12 Dec 2018 14:11:33 +0000 (06:11 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:26:37 +0000 (14:26 -0700)
Introduce initial support for 1D vector operations.  LLVM does not support
higher-dimensional vectors so the caller must make sure they don't appear in
the input MLIR.  Handle the presence of higher-dimensional vectors by failing
gracefully.

Introduce the type conversion for 1D vector types and hook it up with the rest
of the type convresion system.  Support "splat" constants for vector types.  As
a side effect, this refactors constant operation emission by separating out
scalar integer constants into a separate case and by extracting out the helper
function for scalar float construction.  Existing binary operations apply to
vectors transparently.

PiperOrigin-RevId: 225172349

mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
mlir/test/Target/llvmir.mlir

index eb427c5c2b3fa746acfa0499539d69d8b8dead31..57c1d94d6159feeb2df254a875f7e1a978cfb349 100644 (file)
@@ -77,6 +77,9 @@ private:
   /// `memref<42x?x42xi32>` is converted to `{i32*, i64}` (only one size is
   /// dynamic); `memref<2x3x4xf64>` is converted to `{double*}`.
   llvm::StructType *convertMemRefType(MemRefType type);
+
+  /// Convert a 1D vector type to an LLVM vector type.
+  llvm::VectorType *convertVectorType(VectorType type);
   /// \}
 
   /// Convert a list of types to an LLVM type suitable for being returned from a
@@ -124,6 +127,12 @@ private:
   /// instruction) on success and nullptr on error.
   llvm::Value *emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp);
 
+  /// Emit a constant splat operation, i.e. an operation that broadcasts a
+  /// single value to a vector.  The `op` must have an attribute `value` of
+  /// SplatElementsAttr type.  Return an LLVM SSA value of the constant vector;
+  /// return `nullptr` in case of errors.
+  llvm::Value *emitConstantSplat(const ConstantOp &op);
+
   /// Create a single LLVM value of struct type that includes the list of
   /// given MLIR values.  The `values` list must contain at least 2 elements.
   llvm::Value *packValues(ArrayRef<const SSAValue *> values);
@@ -162,7 +171,7 @@ llvm::Type *ModuleLowerer::convertFloatType(FloatType type) {
     return builder.getHalfTy();
   case Type::Kind::BF16:
     return context->emitError(UnknownLoc::get(context),
-                              "Unsupported type: BF16"),
+                              "unsupported type: BF16"),
            nullptr;
   default:
     llvm_unreachable("non-float type in convertFloatType");
@@ -226,6 +235,23 @@ llvm::StructType *ModuleLowerer::convertMemRefType(MemRefType type) {
   return llvm::StructType::get(llvmContext, types);
 }
 
+// Convert a 1D vector type to an LLVM vector type.
+llvm::VectorType *ModuleLowerer::convertVectorType(VectorType type) {
+  if (type.getRank() != 1) {
+    MLIRContext *context = type.getContext();
+    context->emitError(UnknownLoc::get(context),
+                       "only 1D vectors are supported");
+    return nullptr;
+  }
+
+  llvm::Type *elementType = convertType(type.getElementType());
+  if (!elementType) {
+    return nullptr;
+  }
+
+  return llvm::VectorType::get(elementType, type.getShape().front());
+}
+
 llvm::Type *ModuleLowerer::convertType(Type type) {
   if (auto funcType = type.dyn_cast<FunctionType>())
     return convertFunctionType(funcType);
@@ -237,6 +263,8 @@ llvm::Type *ModuleLowerer::convertType(Type type) {
     return convertIndexType(indexType);
   if (auto memRefType = type.dyn_cast<MemRefType>())
     return convertMemRefType(memRefType);
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return convertVectorType(vectorType);
 
   MLIRContext *context = type.getContext();
   std::string message;
@@ -392,6 +420,44 @@ ModuleLowerer::emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp) {
   return builder.CreateCall(freeFunc, data);
 }
 
+// Return an LLVM constant of the `float` type for the given APvalue.
+// This forcibly recreates the APFloat with IEEESingle semantics to make sure
+// LLVM constructs a `float` constant.
+static llvm::ConstantFP *getFloatConstant(APFloat APvalue,
+                                          const Operation &inst,
+                                          llvm::LLVMContext *context) {
+  bool unused;
+  APFloat::opStatus status = APvalue.convert(
+      llvm::APFloat::IEEEsingle(), llvm::APFloat::rmTowardZero, &unused);
+  if (status == APFloat::opInexact) {
+    inst.emitWarning("lossy conversion of a float constant to the float type");
+    // No return intended.
+  }
+  if (status != APFloat::opOK)
+    return inst.emitError("failed to convert a floating point constant"),
+           nullptr;
+  auto value = APvalue.convertToFloat();
+  return llvm::ConstantFP::get(*context, APFloat(value));
+}
+
+llvm::Value *ModuleLowerer::emitConstantSplat(const ConstantOp &op) {
+  auto splatAttr = op.getValue().dyn_cast<SplatElementsAttr>();
+  assert(splatAttr && "expected a splat constant");
+
+  auto floatAttr = splatAttr.getValue().dyn_cast<FloatAttr>();
+  if (!floatAttr)
+    return op.emitError("NYI: only float splats are currently supported"),
+           nullptr;
+
+  llvm::Constant *cst =
+      getFloatConstant(floatAttr.getValue(), *op.getOperation(), &llvmContext);
+  if (!cst)
+    return nullptr;
+
+  auto nElements = op.getType().cast<VectorType>().getShape()[0];
+  return llvm::ConstantVector::getSplat(nElements, cst);
+}
+
 // Create an undef struct value and insert individual values into it.
 llvm::Value *ModuleLowerer::packValues(ArrayRef<const SSAValue *> values) {
   assert(values.size() > 1 && "cannot pack less than 2 values");
@@ -492,35 +558,43 @@ bool ModuleLowerer::convertInstruction(const Instruction &inst) {
     // type of the constant.  This should be fixed at the parser level.
     if (!type->isFloatTy())
       return inst.emitError("NYI: only floats are currently supported");
-    bool unused;
+
     auto APvalue = constantOp->getValue();
-    APFloat::opStatus status = APvalue.convert(
-        llvm::APFloat::IEEEsingle(), llvm::APFloat::rmTowardZero, &unused);
-    if (status == APFloat::opInexact) {
-      inst.emitWarning(
-          "Lossy conversion of a float constant to the float type");
-      // No return intended.
-    }
-    if (status != APFloat::opOK)
-      return inst.emitError("Failed to convert a floating point constant");
-    auto value = APvalue.convertToFloat();
-    valueMapping[constantOp->getResult()] =
-        llvm::ConstantFP::get(type->getContext(), llvm::APFloat(value));
+    auto llvmValue = getFloatConstant(APvalue, inst, &type->getContext());
+    if (!llvmValue)
+      return true;
+
+    valueMapping[constantOp->getResult()] = llvmValue;
     return false;
   }
-  if (auto constantOp = inst.dyn_cast<ConstantOp>()) {
+  if (auto constantOp = inst.dyn_cast<ConstantIntOp>()) {
     llvm::Type *type = convertType(constantOp->getType());
     if (!type)
       return true;
-    if (!isa<llvm::IntegerType>(type))
-      return inst.emitError("only integer types are supported");
-    auto attr = (constantOp->getValue()).cast<IntegerAttr>();
+
     // Create a new APInt even if we can extract one from the attribute, because
     // attributes are currently hardcoded to be 64-bit APInts and LLVM will
     // create an i64 constant from those.
+    auto value = constantOp->getValue();
     valueMapping[constantOp->getResult()] = llvm::Constant::getIntegerValue(
-        type, llvm::APInt(type->getIntegerBitWidth(), attr.getInt()));
+        type, APInt(type->getIntegerBitWidth(), value));
+    return false;
+  }
+  if (auto constantOp = inst.dyn_cast<ConstantOp>()) {
+    llvm::Type *type = convertType(constantOp->getType());
+    if (!type)
+      return true;
+    if (!isa<llvm::VectorType>(type))
+      return inst.emitError("unsupported constant type");
 
+    auto constantValue = constantOp->getValue();
+    if (!constantValue.isa<SplatElementsAttr>())
+      return inst.emitError("NYI: non-splat vector constants");
+
+    llvm::Value *llvmValue = emitConstantSplat(*constantOp);
+    if (!llvmValue)
+      return true;
+    valueMapping[constantOp->getResult()] = llvmValue;
     return false;
   }
 
index 7bd51e26cbc617892bf744e0b43837caf5c89615..99e6073ba9b7a4786cc756c609eb181951efe7c9 100644 (file)
@@ -597,3 +597,14 @@ bb0:
   %6 = load %0#2 [%5, %5, %5, %5] : memref<42x?x10x?xf32>
   return
 }
+
+// CHECK-LABEL: define <4 x float> @vector_ops(<4 x float>) {
+// CHECK-NEXT:    %2 = fadd <4 x float> %0, <float 4.200000e+01, float 4.200000e+01, float 4.200000e+01, float 4.200000e+01>
+// CHECK-NEXT:    ret <4 x float> %2
+// CHECK-NEXT:  }
+cfgfunc @vector_ops(vector<4xf32>) -> vector<4xf32> {
+bb0(%arg0 : vector<4xf32>):
+  %0 = constant splat<vector<4xf32>, 42.> : vector<4xf32>
+  %1 = addf %arg0, %0 : vector<4xf32>
+  return %1 : vector<4xf32>
+}