[mlir][Standard] Extend n-D vector lowering to LLVM to [s|z]exti ops.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 2 Feb 2021 07:41:07 +0000 (07:41 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 2 Feb 2021 07:45:50 +0000 (07:45 +0000)
[s|z]exti ops do not have the same operand and result type.
As a consequence, the lowering of the n-D vector form needs to be relaxed a bit.
This revision additionally performs a few NFC renamings of variables to make them more intuitive.

Differential Revision: https://reviews.llvm.org/D95760

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-nd-vector-to-llvmir.mlir [new file with mode: 0644]
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

index 357bd2f..90ebd94 100644 (file)
@@ -656,9 +656,6 @@ public:
     static_assert(
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
-    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
-                                  SourceOp>::value,
-                  "expected same operands and result type");
     return LLVM::detail::vectorOneToOneRewrite(
         op, TargetOp::getOperationName(), operands, *this->getTypeConverter(),
         rewriter);
index c59cbda..bb6376c 100644 (file)
@@ -1472,10 +1472,10 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
 // 1-D LLVM vectors.
 struct NDVectorTypeInfo {
   // LLVM array struct which encodes n-D vectors.
-  Type llvmArrayTy;
+  Type llvmNDVectorTy;
   // LLVM vector type which encodes the inner 1-D vector type.
-  Type llvmVectorTy;
-  // Multiplicity of llvmArrayTy to llvmVectorTy.
+  Type llvm1DVectorTy;
+  // Multiplicity of llvmNDVectorTy to llvm1DVectorTy.
   SmallVector<int64_t, 4> arraySizes;
 };
 } // namespace
@@ -1488,13 +1488,13 @@ static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
                                                 LLVMTypeConverter &converter) {
   assert(vectorType.getRank() > 1 && "expected >1D vector type");
   NDVectorTypeInfo info;
-  info.llvmArrayTy = converter.convertType(vectorType);
-  if (!info.llvmArrayTy || !LLVM::isCompatibleType(info.llvmArrayTy)) {
-    info.llvmArrayTy = nullptr;
+  info.llvmNDVectorTy = converter.convertType(vectorType);
+  if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) {
+    info.llvmNDVectorTy = nullptr;
     return info;
   }
   info.arraySizes.reserve(vectorType.getRank() - 1);
-  auto llvmTy = info.llvmArrayTy;
+  auto llvmTy = info.llvmNDVectorTy;
   while (llvmTy.isa<LLVM::LLVMArrayType>()) {
     info.arraySizes.push_back(
         llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
@@ -1502,7 +1502,7 @@ static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
   }
   if (!LLVM::isCompatibleVectorType(llvmTy))
     return info;
-  info.llvmVectorTy = llvmTy;
+  info.llvm1DVectorTy = llvmTy;
   return info;
 }
 
@@ -1591,27 +1591,29 @@ static LogicalResult handleMultidimensionalVectors(
     Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
     std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter) {
-  auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
-  if (!vectorType)
-    return failure();
-  auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
-  auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
-  auto llvmArrayTy = operands[0].getType();
-  if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
-    return failure();
-
+  auto operandNDVectorType = op->getOperand(0).getType().dyn_cast<VectorType>();
+  auto resultNDVectorType = op->getResult(0).getType().dyn_cast<VectorType>();
+  assert(operandNDVectorType && resultNDVectorType && "expected vector types");
+
+  auto resultTypeInfo =
+      extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
+  auto operandTypeInfo =
+      extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
+  auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
+  auto operand1DVectorTy = operandTypeInfo.llvm1DVectorTy;
+  auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
   auto loc = op->getLoc();
-  Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
-  nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+  Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
+  nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) {
     // For this unrolled `position` corresponding to the `linearIndex`^th
     // element, extract operand vectors
     SmallVector<Value, 4> extractedOperands;
     for (auto operand : operands)
       extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
-          loc, llvmVectorTy, operand, position));
-    Value newVal = createOperand(llvmVectorTy, extractedOperands);
-    desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal,
-                                                position);
+          loc, operand1DVectorTy, operand, position));
+    Value newVal = createOperand(result1DVectorTy, extractedOperands);
+    desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
+                                                newVal, position);
   });
   rewriter.replaceOp(op, desc);
   return success();
@@ -1627,14 +1629,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
                     [](Type t) { return isCompatibleType(t); }))
     return failure();
 
-  auto llvmArrayTy = operands[0].getType();
-  if (!llvmArrayTy.isa<LLVM::LLVMArrayType>())
+  auto llvmNDVectorTy = operands[0].getType();
+  if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
     return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
 
-  auto callback = [op, targetOp, &rewriter](Type llvmVectorTy,
+  auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
                                             ValueRange operands) {
     OperationState state(op->getLoc(), targetOp);
-    state.addTypes(llvmVectorTy);
+    state.addTypes(llvm1DVectorTy);
     state.addOperands(operands);
     state.addAttributes(op->getAttrs());
     return rewriter.createOperation(state)->getResult(0);
@@ -1668,6 +1670,8 @@ using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
 using PowFOpLowering = VectorConvertToLLVMPattern<PowFOp, LLVM::PowOp>;
 using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
 using SelectOpLowering = OneToOneConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
+using SignExtendIOpLowering =
+    VectorConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp>;
 using ShiftLeftOpLowering =
     OneToOneConvertToLLVMPattern<ShiftLeftOp, LLVM::ShlOp>;
 using SignedDivIOpLowering =
@@ -1687,6 +1691,8 @@ using UnsignedRemIOpLowering =
 using UnsignedShiftRightOpLowering =
     OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
 using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
+using ZeroExtendIOpLowering =
+    VectorConvertToLLVMPattern<ZeroExtendIOp, LLVM::ZExtOp>;
 
 /// Lower `std.assert`. The default lowering calls the `abort` function if the
 /// assertion is violated and has no effect otherwise. The failure message is
@@ -2366,17 +2372,17 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
 
     return handleMultidimensionalVectors(
         op.getOperation(), operands, *getTypeConverter(),
-        [&](Type llvmVectorTy, ValueRange operands) {
+        [&](Type llvm1DVectorTy, ValueRange operands) {
           auto splatAttr = SplatElementsAttr::get(
               mlir::VectorType::get(
-                  {LLVM::getVectorNumElements(llvmVectorTy).getFixedValue()},
+                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
                   floatType),
               floatOne);
           auto one =
-              rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
+              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
           auto sqrt =
-              rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
-          return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
+              rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
+          return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
         },
         rewriter);
   }
@@ -3050,21 +3056,11 @@ struct FPTruncLowering
   using Super::Super;
 };
 
-struct SignExtendIOpLowering
-    : public OneToOneConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp> {
-  using Super::Super;
-};
-
 struct TruncateIOpLowering
     : public OneToOneConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp> {
   using Super::Super;
 };
 
-struct ZeroExtendIOpLowering
-    : public OneToOneConvertToLLVMPattern<ZeroExtendIOp, LLVM::ZExtOp> {
-  using Super::Super;
-};
-
 // Base class for LLVM IR lowering terminator operations with successors.
 template <typename SourceOp, typename TargetOp>
 struct OneToOneLLVMTerminatorLowering
@@ -3211,21 +3207,21 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
     auto loc = splatOp.getLoc();
     auto vectorTypeInfo =
         extractNDVectorTypeInfo(resultType, *getTypeConverter());
-    auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
-    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
-    if (!llvmArrayTy || !llvmVectorTy)
+    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
+    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
+    if (!llvmNDVectorTy || !llvm1DVectorTy)
       return failure();
 
     // Construct returned value.
-    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
 
     // Construct a 1-D vector with the splatted value that we insert in all the
     // places within the returned descriptor.
-    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
+    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
     auto zero = rewriter.create<LLVM::ConstantOp>(
         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
-    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
+    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                      adaptor.input(), zero);
 
     // Shuffle the value across the desired number of elements.
@@ -3237,7 +3233,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
     // Iterate of linear index, convert to coords space and insert splatted 1-D
     // vector in each position.
     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
-      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
+      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
                                                   position);
     });
     rewriter.replaceOp(splatOp, desc);
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-nd-vector-to-llvmir.mlir
new file mode 100644 (file)
index 0000000..dce630d
--- /dev/null
@@ -0,0 +1,49 @@
+// RUN: mlir-opt -convert-std-to-llvm %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @vec_bin
+func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+  %0 = addf %arg0, %arg0 : vector<2x2x2xf32>
+  return %0 : vector<2x2x2xf32>
+
+//  CHECK-NEXT: llvm.mlir.undef : !llvm.array<2 x array<2 x vector<2xf32>>>
+
+// This block appears 2x2 times
+//  CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
+//  CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
+//  CHECK-NEXT: llvm.fadd %{{.*}} : vector<2xf32>
+//  CHECK-NEXT: llvm.insertvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
+
+// We check the proper indexing of extract/insert in the remaining 3 positions.
+//       CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
+//       CHECK: llvm.insertvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
+//       CHECK: llvm.extractvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
+//       CHECK: llvm.insertvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
+//       CHECK: llvm.extractvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
+//       CHECK: llvm.insertvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
+}
+
+// CHECK-LABEL: @sexti
+func @sexti_vector(%arg0 : vector<1x2x3xi32>, %arg1 : vector<1x2x3xi64>) {
+  // CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xi64>>>
+  // CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi32>>>
+  // CHECK: llvm.sext %{{.*}} : vector<3xi32> to vector<3xi64>
+  // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi64>>>
+  // CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi32>>>
+  // CHECK: llvm.sext %{{.*}} : vector<3xi32> to vector<3xi64>
+  // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi64>>>
+  %0 = sexti %arg0: vector<1x2x3xi32> to vector<1x2x3xi64>
+  return
+}
+
+// CHECK-LABEL: @zexti
+func @zexti_vector(%arg0 : vector<1x2x3xi32>, %arg1 : vector<1x2x3xi64>) {
+  // CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xi64>>>
+  // CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi32>>>
+  // CHECK: llvm.zext %{{.*}} : vector<3xi32> to vector<3xi64>
+  // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi64>>>
+  // CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi32>>>
+  // CHECK: llvm.zext %{{.*}} : vector<3xi32> to vector<3xi64>
+  // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi64>>>
+  %0 = zexti %arg0: vector<1x2x3xi32> to vector<1x2x3xi64>
+  return
+}
index 749d733..5081e7b 100644 (file)
@@ -766,31 +766,6 @@ func @fcmp(f32, f32) -> () {
   return
 }
 
-// CHECK-LABEL: @vec_bin
-func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
-  %0 = addf %arg0, %arg0 : vector<2x2x2xf32>
-  return %0 : vector<2x2x2xf32>
-
-//  CHECK-NEXT: llvm.mlir.undef : !llvm.array<2 x array<2 x vector<2xf32>>>
-
-// This block appears 2x2 times
-//  CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
-//  CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
-//  CHECK-NEXT: llvm.fadd %{{.*}} : vector<2xf32>
-//  CHECK-NEXT: llvm.insertvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
-
-// We check the proper indexing of extract/insert in the remaining 3 positions.
-//       CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
-//       CHECK: llvm.insertvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
-//       CHECK: llvm.extractvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
-//       CHECK: llvm.insertvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>>
-//       CHECK: llvm.extractvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
-//       CHECK: llvm.insertvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>>
-
-// And we're done
-//   CHECK-NEXT: return
-}
-
 // CHECK-LABEL: @splat
 // CHECK-SAME: %[[A:arg[0-9]+]]: vector<4xf32>
 // CHECK-SAME: %[[ELT:arg[0-9]+]]: f32