[mlir] subview op lowering for target memrefs with const offset
authorTobias Gysi <tobias.gysi@inf.ethz.ch>
Mon, 10 Feb 2020 16:29:50 +0000 (17:29 +0100)
committerAlex Zinenko <zinenko@google.com>
Mon, 10 Feb 2020 16:35:17 +0000 (17:35 +0100)
The current standard to llvm conversion pass lowers subview ops only if
dynamic offsets are provided. This commit extends the lowering with a
code path that uses the constant offset of the target memref for the
subview op lowering (see Example 3 of the subview op definition for an
example) if no dynamic offsets are provided.

Differential Revision: https://reviews.llvm.org/D74280

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

index 63c8b75f197f63c6944113759803b838cb33b1b3..36c98d0e85b156c39f366631a249f1aa6f81d425 100644 (file)
@@ -2304,7 +2304,8 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     // Currently, only rank > 0 and full or no operands are supported. Fail to
     // convert otherwise.
     unsigned rank = sourceMemRefType.getRank();
-    if (viewMemRefType.getRank() == 0 || (rank != dynamicOffsets.size()) ||
+    if (viewMemRefType.getRank() == 0 ||
+        (!dynamicOffsets.empty() && rank != dynamicOffsets.size()) ||
         (!dynamicSizes.empty() && rank != dynamicSizes.size()) ||
         (!dynamicStrides.empty() && rank != dynamicStrides.size()))
       return matchFailure();
@@ -2315,6 +2316,11 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     if (failed(successStrides))
       return matchFailure();
 
+    // Fail to convert if neither a dynamic nor static offset is available.
+    if (dynamicOffsets.empty() &&
+        offset == MemRefType::getDynamicStrideOrOffset())
+      return matchFailure();
+
     // Create the descriptor.
     MemRefDescriptor sourceMemRef(operands.front());
     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
@@ -2348,14 +2354,18 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     }
 
     // Offset.
-    Value baseOffset = sourceMemRef.offset(rewriter, loc);
-    for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
-      Value min = dynamicOffsets[i];
-      baseOffset = rewriter.create<LLVM::AddOp>(
-          loc, baseOffset,
-          rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
+    if (dynamicOffsets.empty()) {
+      targetMemRef.setConstantOffset(rewriter, loc, offset);
+    } else {
+      Value baseOffset = sourceMemRef.offset(rewriter, loc);
+      for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
+        Value min = dynamicOffsets[i];
+        baseOffset = rewriter.create<LLVM::AddOp>(
+            loc, baseOffset,
+            rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
+      }
+      targetMemRef.setOffset(rewriter, loc, baseOffset);
     }
-    targetMemRef.setOffset(rewriter, loc, baseOffset);
 
     // Update sizes and strides.
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
index 302aa31e48e0a95409bd7c91a5d6eb8893f00aa7..4c5a4a078a11e7f54e8015a3e9eb32a1d041b22d 100644 (file)
@@ -815,6 +815,31 @@ func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4
   return
 }
 
+// CHECK-LABEL: func @subview_const_stride_and_offset(
+func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>) {
+  // The last "insertvalue" that populates the memref descriptor from the function arguments.
+  // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+
+  // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST62:.*]] = llvm.mlir.constant(62 : i64)
+  // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i64)
+  // CHECK: %[[CST8:.*]] = llvm.mlir.constant(8 : index)
+  // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[CST8]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST3]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i64)
+  // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+  // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %1 = subview %0[][][] :
+    memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>
+  return
+}
+
 // -----
 
 module {