Extend conversion of SubViewOp to llvm to also support cases where size and stride
authorStephan Herhut <herhut@google.com>
Tue, 3 Dec 2019 13:11:20 +0000 (05:11 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Dec 2019 13:11:49 +0000 (05:11 -0800)
are constant (i.e., there are no size and stride operands).

We recently added canonicalization that rewrites constant size and stride operands to
SubViewOp into static information in the type, so these patterns now occur during code
generation.

PiperOrigin-RevId: 283524688

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

index 2db02db..0d93220 100644 (file)
@@ -1506,10 +1506,12 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     if (!sourceElementTy || !targetDescTy)
       return matchFailure();
 
-    // Early exit for 0-D and operands lesser than `rank` corner cases.
+    // Currently, only rank > 0 and full or no operands are supported. Fail to
+    // convert otherwise.
     unsigned rank = sourceMemRefType.getRank();
-    if (viewMemRefType.getRank() == 0 || rank != dynamicOffsets.size() ||
-        rank != dynamicSizes.size() || rank != dynamicStrides.size())
+    if (viewMemRefType.getRank() == 0 || (rank != dynamicOffsets.size()) ||
+        (!dynamicSizes.empty() && rank != dynamicSizes.size()) ||
+        (!dynamicStrides.empty() && rank != dynamicStrides.size()))
       return matchFailure();
 
     int64_t offset;
@@ -1539,6 +1541,17 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
 
+    // Fill in missing dynamic sizes.
+    auto llvmIndexType = lowering.convertType(rewriter.getIndexType());
+    if (dynamicSizes.empty()) {
+      dynamicSizes.reserve(viewMemRefType.getRank());
+      auto shape = viewMemRefType.getShape();
+      for (auto extent : shape) {
+        dynamicSizes.push_back(rewriter.create<LLVM::ConstantOp>(
+            loc, llvmIndexType, rewriter.getI64IntegerAttr(extent)));
+      }
+    }
+
     // Offset.
     Value *baseOffset = sourceMemRef.offset(rewriter, loc);
     for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
@@ -1552,9 +1565,14 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     // Update sizes and strides.
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
       targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]);
-      targetMemRef.setStride(rewriter, loc, i,
-                             rewriter.create<LLVM::MulOp>(
-                                 loc, dynamicStrides[i], strideValues[i]));
+      Value *newStride;
+      if (dynamicStrides.empty())
+        newStride = rewriter.create<LLVM::ConstantOp>(
+            loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
+      else
+        newStride = rewriter.create<LLVM::MulOp>(loc, dynamicStrides[i],
+                                                 strideValues[i]);
+      targetMemRef.setStride(rewriter, loc, i, newStride);
     }
 
     rewriter.replaceOp(op, {targetMemRef});
index 2960471..5c50ed8 100644 (file)
@@ -734,3 +734,57 @@ func @subview(%0 : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>, %arg0 : index, %
     memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
   return
 }
+
+// CHECK-LABEL: func @subview_const_size(
+// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+func @subview_const_size(%0 : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>, %arg0 : index, %arg1 : index, %arg2 : index) {
+  // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+  // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
+  // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+  // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i64
+  // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+  // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i64
+  // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+  // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+  // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %1 = subview %0[%arg0, %arg1][][%arg0, %arg1] :
+    memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to memref<4x2xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
+  return
+}
+
+// CHECK-LABEL: func @subview_const_stride(
+// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+func @subview_const_stride(%0 : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>, %arg0 : index, %arg1 : index, %arg2 : index) {
+  // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+  // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+  // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i64
+  // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+  // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i64
+  // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
+  // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST2]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+  // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %1 = subview %0[%arg0, %arg1][%arg0, %arg1][] :
+    memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 * 2 + s0)>
+  return
+}