Implement lowering of VectorTypeCastOp to LLVM
authorNicolas Vasilache <ntv@google.com>
Fri, 18 Oct 2019 20:48:26 +0000 (13:48 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 18 Oct 2019 21:00:06 +0000 (14:00 -0700)
A VectorTypeCastOp can only be used to lower between statically sized contiguous memrefs of scalar and matching vector type. The sizes and strides are thus fully static and easy to determine.

A relevant test is added.

This is a step towards solving tensorflow/mlir#189.

PiperOrigin-RevId: 275538981

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

index a5d0d3c..bd21c35 100644 (file)
@@ -82,6 +82,11 @@ public:
   Value *promoteOneMemRefDescriptor(Location loc, Value *operand,
                                     OpBuilder &builder);
 
+  static constexpr unsigned kPtrPosInMemRefDescriptor = 0;
+  static constexpr unsigned kOffsetPosInMemRefDescriptor = 1;
+  static constexpr unsigned kSizePosInMemRefDescriptor = 2;
+  static constexpr unsigned kStridePosInMemRefDescriptor = 3;
+
 protected:
   /// LLVM IR module used to parse/create types.
   llvm::Module *module;
index 15f61ab..490b669 100644 (file)
@@ -156,10 +156,10 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
 //   int64_t sizes[Rank]; // omitted when rank == 0
 //   int64_t strides[Rank]; // omitted when rank == 0
 // };
-static unsigned kPtrPosInMemRefDescriptor = 0;
-static unsigned kOffsetPosInMemRefDescriptor = 1;
-static unsigned kSizePosInMemRefDescriptor = 2;
-static unsigned kStridePosInMemRefDescriptor = 3;
+constexpr unsigned LLVMTypeConverter::kPtrPosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kOffsetPosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kSizePosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kStridePosInMemRefDescriptor;
 Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
@@ -282,7 +282,8 @@ public:
                                         Type elementTypePtr) {
     return builder.create<LLVM::ExtractValueOp>(
         loc, elementTypePtr, memref,
-        builder.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
+        builder.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
   }
 
 protected:
@@ -763,11 +764,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
 
     memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
         op->getLoc(), structType, memRefDescriptor, allocated,
-        rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
     memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
         op->getLoc(), structType, memRefDescriptor,
         createIndexConstant(rewriter, op->getLoc(), offset),
-        rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor));
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
 
     if (type.getRank() == 0)
       // No size/stride descriptor in memref, return the descriptor value.
@@ -798,10 +801,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
       int64_t index = indexedSize.index();
       memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
           op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
-          rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
       memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
           op->getLoc(), structType, memRefDescriptor, strideValues[index],
-          rewriter.getI64ArrayAttr({kStridePosInMemRefDescriptor, index}));
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
     }
 
     // Return the final value of the descriptor.
@@ -896,7 +901,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
     }
 
     auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
-    Type elementPtrType = type.getStructElementType(kPtrPosInMemRefDescriptor);
+    Type elementPtrType =
+        type.getStructElementType(LLVMTypeConverter::kPtrPosInMemRefDescriptor);
     Value *bufferPtr = extractMemRefElementPtr(
         rewriter, op->getLoc(), transformed.memref(), elementPtrType);
     Value *casted = rewriter.create<LLVM::BitcastOp>(
@@ -952,7 +958,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
     if (ShapedType::isDynamic(shape[index]))
       rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
           op, getIndexType(), transformed.memrefOrTensor(),
-          rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
     else
       // Use constant for static size.
       rewriter.replaceOp(
@@ -1015,7 +1022,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
         offset == MemRefType::getDynamicStrideOrOffset()
             ? rewriter.create<LLVM::ExtractValueOp>(
                   loc, indexTy, memRefDescriptor,
-                  rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor))
+                  rewriter.getIndexArrayAttr(
+                      LLVMTypeConverter::kOffsetPosInMemRefDescriptor))
             : this->createIndexConstant(rewriter, loc, offset);
     for (int i = 0, e = indices.size(); i < e; ++i) {
       Value *stride;
@@ -1028,7 +1036,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
         // Use dynamic stride.
         stride = rewriter.create<LLVM::ExtractValueOp>(
             loc, indexTy, memRefDescriptor,
-            rewriter.getIndexArrayAttr({kStridePosInMemRefDescriptor, i}));
+            rewriter.getIndexArrayAttr(
+                {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
       }
       Value *additionalOffset =
           rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
index 3c3a18d..765c25a 100644 (file)
@@ -155,10 +155,112 @@ public:
   }
 };
 
+class VectorTypeCastOpConversion : public LLVMOpLowering {
+public:
+  explicit VectorTypeCastOpConversion(MLIRContext *context,
+                                      LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::VectorTypeCastOp::getOperationName(), context,
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    vector::VectorTypeCastOp castOp = cast<vector::VectorTypeCastOp>(op);
+    MemRefType sourceMemRefType =
+        castOp.getOperand()->getType().cast<MemRefType>();
+    MemRefType targetMemRefType =
+        castOp.getResult()->getType().cast<MemRefType>();
+
+    // Only static shape casts supported atm.
+    if (!sourceMemRefType.hasStaticShape() ||
+        !targetMemRefType.hasStaticShape())
+      return matchFailure();
+
+    Value *sourceMemRef = operands[0];
+    auto llvmSourceDescriptorTy =
+        sourceMemRef->getType().dyn_cast<LLVM::LLVMType>();
+    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
+      return matchFailure();
+
+    auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
+                                      .dyn_cast_or_null<LLVM::LLVMType>();
+    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+      return matchFailure();
+
+    Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType(
+        LLVMTypeConverter::kPtrPosInMemRefDescriptor);
+    Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType(
+        LLVMTypeConverter::kPtrPosInMemRefDescriptor);
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto successStrides =
+        getStridesAndOffset(targetMemRefType, strides, offset);
+    bool isContiguous = (strides.back() == 1);
+    if (isContiguous) {
+      auto sizes = targetMemRefType.getShape();
+      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
+        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
+          isContiguous = false;
+          break;
+        }
+      }
+    }
+    // Only contiguous tensors supported atm.
+    if (failed(successStrides) || !isContiguous)
+      return matchFailure();
+
+    auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+
+    // Create descriptor.
+    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy);
+    // Set ptr.
+    Value *ptr = rewriter.create<LLVM::ExtractValueOp>(
+        loc, llvmSourceElementTy, sourceMemRef,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        op->getLoc(), llvmTargetDescriptorTy, desc, ptr,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+    // Fill offset 0.
+    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
+    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        op->getLoc(), llvmTargetDescriptorTy, desc, zero,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    // Fill size and stride descriptors in memref.
+    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
+      int64_t index = indexedSize.index();
+      auto sizeAttr =
+          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
+      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          op->getLoc(), llvmTargetDescriptorTy, desc, size,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
+      auto strideAttr =
+          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
+      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          op->getLoc(), llvmTargetDescriptorTy, desc, stride,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
+    }
+
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>(
+  patterns.insert<ExtractElementOpConversion, OuterProductOpConversion,
+                  VectorTypeCastOpConversion>(
       converter.getDialect()->getContext(), converter);
 }
 
@@ -190,5 +292,5 @@ OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
 }
 
 static PassRegistration<LowerVectorToLLVMPass>
-    pass("vector-lower-to-llvm-dialect",
+    pass("convert-vector-to-llvm",
          "Lower the operations from the vector dialect into the LLVM dialect");
index fa2345e..aefeca4 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-llvm | FileCheck %s
 
 func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> {
   %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
@@ -46,4 +46,21 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
 //       CHECK:   llvm.extractvalue %{{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
 //       CHECK:   llvm.mlir.constant(0 : i32) : !llvm.i32
 //       CHECK:   llvm.extractelement %{{.*}}, %{{.*}} : !llvm<"<16 x float>">
-//       CHECK:   llvm.return %{{.*}} : !llvm.float
\ No newline at end of file
+//       CHECK:   llvm.return %{{.*}} : !llvm.float
+
+func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<1xvector<8x8x8xf32>> {
+  %0 = vector.type_cast %arg0: memref<8x8x8xf32>, memref<1xvector<8x8x8xf32>>
+  return %0 : memref<1xvector<8x8x8xf32>>
+}
+// CHECK-LABEL: vector_type_cast
+//       CHECK:   llvm.mlir.undef : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }">
+//       CHECK:   %[[ptr:.*]] = llvm.extractvalue {{.*}}[0 : index] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   %[[bit:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*">
+//       CHECK:   llvm.insertvalue %[[bit]], {{.*}}[0 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }">
+//       CHECK:   llvm.mlir.constant(0 : index
+//       CHECK:   llvm.insertvalue {{.*}}[1 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }">
+//       CHECK:   llvm.mlir.constant(1 : index
+//       CHECK:   llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }">
+//       CHECK:   llvm.mlir.constant(1 : index
+//       CHECK:   llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }">
+