From 2125c0e3a8ec1e0abaadfac8cd9f0f1d8574d952 Mon Sep 17 00:00:00 2001
From: Stephan Herhut
Date: Tue, 3 Dec 2019 05:11:20 -0800
Subject: [PATCH] Extend conversion of SubViewOp to llvm to also support cases
 where size and stride are constant (i.e., there are no size and stride
 operands).

We recently added a canonicalization that rewrites constant size and stride
operands of SubViewOp into static information in the type, so these forms now
occur during code generation.
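To illustrate the pattern this enables (a minimal sketch mirroring the
subview_const_size test added below; the %i/%j operand names and the
pre-canonicalization form are illustrative, not taken from the patch):

    // Before canonicalization: the sizes [4, 2] are passed as constant
    // operands, so the result type has dynamic sizes.
    %c4 = constant 4 : index
    %c2 = constant 2 : index
    %1 = subview %0[%i, %j][%c4, %c2][%i, %j] :
      memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to
      memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>

    // After canonicalization: the sizes live in the result type and the
    // size operand list is empty; this is the form the lowering now handles.
    %1 = subview %0[%i, %j][][%i, %j] :
      memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to
      memref<4x2xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>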
PiperOrigin-RevId: 283524688
---
 .../StandardToLLVM/ConvertStandardToLLVM.cpp  | 30 +++++++++---
 .../StandardToLLVM/convert-to-llvmir.mlir     | 54 ++++++++++++++++++++++
 2 files changed, 78 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 2db02db..0d93220 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1506,10 +1506,12 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     if (!sourceElementTy || !targetDescTy)
       return matchFailure();
 
-    // Early exit for 0-D and operands lesser than `rank` corner cases.
+    // Currently, only rank > 0 and full or no operands are supported. Fail to
+    // convert otherwise.
     unsigned rank = sourceMemRefType.getRank();
-    if (viewMemRefType.getRank() == 0 || rank != dynamicOffsets.size() ||
-        rank != dynamicSizes.size() || rank != dynamicStrides.size())
+    if (viewMemRefType.getRank() == 0 || (rank != dynamicOffsets.size()) ||
+        (!dynamicSizes.empty() && rank != dynamicSizes.size()) ||
+        (!dynamicStrides.empty() && rank != dynamicStrides.size()))
       return matchFailure();
 
     int64_t offset;
@@ -1539,6 +1541,17 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
 
+    // Fill in missing dynamic sizes.
+    auto llvmIndexType = lowering.convertType(rewriter.getIndexType());
+    if (dynamicSizes.empty()) {
+      dynamicSizes.reserve(viewMemRefType.getRank());
+      auto shape = viewMemRefType.getShape();
+      for (auto extent : shape) {
+        dynamicSizes.push_back(rewriter.create<LLVM::ConstantOp>(
+            loc, llvmIndexType, rewriter.getI64IntegerAttr(extent)));
+      }
+    }
+
     // Offset.
     Value *baseOffset = sourceMemRef.offset(rewriter, loc);
     for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
@@ -1552,9 +1565,14 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     // Update sizes and strides.
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
       targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]);
-      targetMemRef.setStride(rewriter, loc, i,
-                             rewriter.create<LLVM::MulOp>(
-                                 loc, dynamicStrides[i], strideValues[i]));
+      Value *newStride;
+      if (dynamicStrides.empty())
+        newStride = rewriter.create<LLVM::ConstantOp>(
+            loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
+      else
+        newStride = rewriter.create<LLVM::MulOp>(loc, dynamicStrides[i],
+                                                 strideValues[i]);
+      targetMemRef.setStride(rewriter, loc, i, newStride);
     }
 
     rewriter.replaceOp(op, {targetMemRef});
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 2960471..5c50ed8 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -734,3 +734,57 @@ func @subview(%0 : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>, %arg0 : index, %
     memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
   return
 }
+
+// CHECK-LABEL: func @subview_const_size(
+// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+func @subview_const_size(%0 : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>, %arg0 : index, %arg1 : index, %arg2 : index) {
+  // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+  // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
+  // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+  // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i64
+  // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+  // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i64
+  // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+  // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+  // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %1 = subview %0[%arg0, %arg1][][%arg0, %arg1] :
+    memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to memref<4x2xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
+  return
+}
+
+// CHECK-LABEL: func @subview_const_stride(
+// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+func @subview_const_stride(%0 : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>, %arg0 : index, %arg1 : index, %arg2 : index) {
+  // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+  // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+  // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i64
+  // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+  // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i64
+  // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
+  // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST2]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+  // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %1 = subview %0[%arg0, %arg1][%arg0, %arg1][] :
+    memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 * 2 + s0)>
+  return
+}
-- 
2.7.4