Add LLVM lowering of std.subview
authorNicolas Vasilache <ntv@google.com>
Tue, 12 Nov 2019 15:22:51 +0000 (07:22 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 12 Nov 2019 15:23:18 +0000 (07:23 -0800)
A followup CL will replace usage of linalg.subview by std.subview.

PiperOrigin-RevId: 279961981

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

index 84eba82..791a237 100644 (file)
@@ -1412,6 +1412,118 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
   }
 };
 
+/// Conversion pattern that transforms a subview op into:
+///   1. An `llvm.mlir.undef` operation to create a memref descriptor
+///   2. Updates to the descriptor to introduce the data ptr, offset, size
+///      and stride.
+/// The subview op is replaced by the descriptor.
+struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
+  using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto viewOp = cast<SubViewOp>(op);
+    SubViewOpOperandAdaptor adaptor(operands);
+    auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
+    auto sourceElementTy =
+        lowering.convertType(sourceMemRefType.getElementType())
+            .dyn_cast_or_null<LLVM::LLVMType>();
+
+    auto viewMemRefType = viewOp.getType();
+    auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
+                               .dyn_cast<LLVM::LLVMType>();
+    auto targetDescTy =
+        lowering.convertType(viewMemRefType).dyn_cast_or_null<LLVM::LLVMType>();
+    if (!sourceElementTy || !targetDescTy)
+      return matchFailure();
+
+    // Early exit for 0-D and operands lesser than `rank` corner cases.
+    unsigned rank = sourceMemRefType.getRank();
+    if (viewMemRefType.getRank() == 0 || rank != adaptor.offsets().size() ||
+        rank != adaptor.sizes().size() || rank != adaptor.strides().size())
+      return matchFailure();
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+    if (failed(successStrides))
+      return matchFailure();
+
+    // Create the descriptor.
+    Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
+
+    // Copy the buffer pointer from the old descriptor to the new one.
+    Value *sourceDescriptor = adaptor.source();
+    Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
+        loc, sourceElementTy.getPointerTo(), sourceDescriptor,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+    Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, targetElementTy.getPointerTo(), extracted);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        loc, desc, bitcastPtr,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+    extracted = rewriter.create<LLVM::ExtractValueOp>(
+        loc, sourceElementTy.getPointerTo(), sourceDescriptor,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, targetElementTy.getPointerTo(), extracted);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        loc, desc, bitcastPtr,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+
+    // Extract strides needed to compute offset.
+    SmallVector<Value *, 4> strideValues;
+    strideValues.reserve(viewMemRefType.getRank());
+    for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
+      strideValues.push_back(rewriter.create<LLVM::ExtractValueOp>(
+          loc, getIndexType(), sourceDescriptor,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kStridePosInMemRefDescriptor, i})));
+    }
+
+    // Offset.
+    Value *baseOffset = rewriter.create<LLVM::ExtractValueOp>(
+        loc, getIndexType(), sourceDescriptor,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
+      Value *min = adaptor.offsets()[i];
+      baseOffset = rewriter.create<LLVM::AddOp>(
+          loc, baseOffset,
+          rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
+    }
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        loc, desc, baseOffset,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+
+    // Update sizes and strides.
+    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
+      // Update size.
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          loc, desc, adaptor.sizes()[i],
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
+      // Update stride.
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          loc, desc,
+          rewriter.create<LLVM::MulOp>(loc, adaptor.strides()[i],
+                                       strideValues[i]),
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
+    }
+
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
 /// Conversion pattern that transforms a op into:
 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
@@ -1647,6 +1759,7 @@ void mlir::populateStdToLLVMConversionPatterns(
       StoreOpLowering,
       SubFOpLowering,
       SubIOpLowering,
+      SubViewOpLowering,
       TruncateIOpLowering,
       ViewOpLowering,
       XOrOpLowering,
index e0669dc..b34b345 100644 (file)
@@ -708,3 +708,29 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 
   return
 }
+
+// CHECK-LABEL: func @subview(
+// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+func @subview(%0 : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>, %arg0 : index, %arg1 : index, %arg2 : index) {
+  // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+  // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+  // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i64
+  // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+  // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i64
+  // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+  // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+  // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] :
+    memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
+  return
+}