From f6df11568e83960ef698fc979965428c6b431344 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Markus=20B=C3=B6ck?= Date: Tue, 4 Jul 2023 16:48:17 +0200 Subject: [PATCH] [mlir][llvm] add basic type consistency pattern destructuring stores This is a common pattern produced by clang and similar. Essentially, it coalesces stores into adjacent integer fields into a single integer store. This violates our definition of type-consistency that the pass is supposed to enforce and also prevents SROA and mem2reg from eliminating `alloca`s. This patch fixes that by splitting these stores into multiple stores. It does so by simply using logical shift rights and truncating the produced value to the size of the field, optionally bitcasting before storing into the field. The implementation is currently very simple, only working on struct types of a single depth and adjacent fields in that struct, with no padding inbetween. Future work could improve on these once required. Differential Revision: https://reviews.llvm.org/D154449 --- .../Dialect/LLVMIR/Transforms/TypeConsistency.h | 12 ++ .../Dialect/LLVMIR/Transforms/TypeConsistency.cpp | 197 +++++++++++++++++++-- mlir/test/Dialect/LLVMIR/type-consistency.mlir | 152 ++++++++++++++++ 3 files changed, 343 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h index 469feed..7da8b7f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h @@ -53,6 +53,18 @@ public: PatternRewriter &rewriter) const override; }; +/// Splits stores of integers which write into multiple adjacent stores +/// of a pointer. The integer is then split and stores are generated for +/// every field being stored in a type-consistent manner. +/// This is currently done on a best-effort basis. +class SplitIntegerStores : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(StoreOp store, + PatternRewriter &rewrite) const override; +}; + } // namespace LLVM } // namespace mlir diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp index f696625..02eddb4 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp @@ -168,11 +168,20 @@ LogicalResult AddFieldGetterToStructDirectUse::matchAndRewrite( /// Returns the amount of bytes the provided GEP elements will offset the /// pointer by. Returns nullopt if the offset could not be computed. -static std::optional gepToByteOffset(DataLayout &layout, Type base, - ArrayRef indices) { - uint64_t offset = indices[0] * layout.getTypeSize(base); +static std::optional gepToByteOffset(DataLayout &layout, GEPOp gep) { - Type currentType = base; + SmallVector indices; + // Ensures all indices are static and fetches them. + for (auto index : gep.getIndices()) { + IntegerAttr indexInt = llvm::dyn_cast_if_present(index); + if (!indexInt) + return std::nullopt; + indices.push_back(indexInt.getInt()); + } + + uint64_t offset = indices[0] * layout.getTypeSize(gep.getSourceElementType()); + + Type currentType = gep.getSourceElementType(); for (uint32_t index : llvm::drop_begin(indices)) { bool shouldCancel = TypeSwitch(currentType) @@ -302,9 +311,9 @@ findIndicesForOffset(DataLayout &layout, Type base, uint64_t offset, return success(); } -LogicalResult -CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep, - PatternRewriter &rewriter) const { +/// Returns the consistent type for the GEP if the GEP is not type-consistent. +/// Returns failure if the GEP is already consistent. +static FailureOr getRequiredConsistentGEPType(GEPOp gep) { // GEP of typed pointers are not supported. if (!gep.getElemType()) return failure(); @@ -317,34 +326,185 @@ CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep, Type typeHint = isElementTypeInconsistent(gep.getBase(), baseType); if (!typeHint) return failure(); + return typeHint; +} - SmallVector indices; - // Ensures all indices are static and fetches them. - for (auto index : gep.getIndices()) { - IntegerAttr indexInt = llvm::dyn_cast_if_present(index); - if (!indexInt) - return failure(); - indices.push_back(indexInt.getInt()); +LogicalResult +CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep, + PatternRewriter &rewriter) const { + FailureOr typeHint = getRequiredConsistentGEPType(gep); + if (failed(typeHint)) { + // GEP is already canonical, nothing to do here. + return failure(); } DataLayout layout = DataLayout::closest(gep); - std::optional desiredOffset = - gepToByteOffset(layout, gep.getSourceElementType(), indices); + std::optional desiredOffset = gepToByteOffset(layout, gep); if (!desiredOffset) return failure(); SmallVector newIndices; if (failed( - findIndicesForOffset(layout, typeHint, *desiredOffset, newIndices))) + findIndicesForOffset(layout, *typeHint, *desiredOffset, newIndices))) return failure(); rewriter.replaceOpWithNewOp( - gep, LLVM::LLVMPointerType::get(getContext()), typeHint, gep.getBase(), + gep, LLVM::LLVMPointerType::get(getContext()), *typeHint, gep.getBase(), newIndices, gep.getInbounds()); return success(); } +/// Returns the list of fields of `structType` that are written to by a store +/// operation writing `storeSize` bytes at `storeOffset` within the struct. +/// `storeOffset` is required to cleanly point to an immediate field within +/// the struct. +/// If the write operation were to write to any padding, write beyond the +/// struct, partially write to a field, or contains currently unsupported +/// types, failure is returned. +static FailureOr> +getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType, + int storeSize, unsigned storeOffset) { + ArrayRef body = structType.getBody(); + unsigned currentOffset = 0; + body = body.drop_until([&](Type type) { + if (!structType.isPacked()) { + unsigned alignment = dataLayout.getTypeABIAlignment(type); + currentOffset = llvm::alignTo(currentOffset, alignment); + } + + // currentOffset is guaranteed to be equal to offset since offset is either + // 0 or stems from a type-consistent GEP indexing into just a single + // aggregate. + if (currentOffset == storeOffset) + return true; + + assert(currentOffset < storeOffset && + "storeOffset should cleanly point into an immediate field"); + + currentOffset += dataLayout.getTypeSize(type); + return false; + }); + + size_t exclusiveEnd = 0; + for (; exclusiveEnd < body.size() && storeSize > 0; exclusiveEnd++) { + // Not yet recursively handling aggregates, only primitives. + if (!isa(body[exclusiveEnd])) + return failure(); + + if (!structType.isPacked()) { + unsigned alignment = dataLayout.getTypeABIAlignment(body[exclusiveEnd]); + // No padding allowed inbetween fields at this point in time. + if (!llvm::isAligned(llvm::Align(alignment), currentOffset)) + return failure(); + } + + unsigned fieldSize = dataLayout.getTypeSize(body[exclusiveEnd]); + currentOffset += fieldSize; + storeSize -= fieldSize; + } + + // If the storeSize is not 0 at this point we are either partially writing + // into a field or writing past the aggregate as a whole. Abort. + if (storeSize != 0) + return failure(); + return body.take_front(exclusiveEnd); +} + +LogicalResult +SplitIntegerStores::matchAndRewrite(StoreOp store, + PatternRewriter &rewriter) const { + IntegerType sourceType = dyn_cast(store.getValue().getType()); + if (!sourceType) { + // We currently only support integer sources. + return failure(); + } + + Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType); + if (!typeHint) { + // Nothing to do, since it is already consistent. + return failure(); + } + + auto dataLayout = DataLayout::closest(store); + + unsigned offset = 0; + Value address = store.getAddr(); + if (auto gepOp = address.getDefiningOp()) { + // Currently only handle canonical GEPs with exactly two indices, + // indexing a single aggregate deep. + // Recursing into sub-structs is left as a future exercise. + // If the GEP is not canonical we have to fail, otherwise we would not + // create type-consistent IR. + if (gepOp.getIndices().size() != 2 || + succeeded(getRequiredConsistentGEPType(gepOp))) + return failure(); + + // A GEP might point somewhere into the middle of an aggregate with the + // store storing into multiple adjacent elements. Destructure into + // the base address with an offset. + std::optional byteOffset = gepToByteOffset(dataLayout, gepOp); + if (!byteOffset) + return failure(); + + offset = *byteOffset; + typeHint = gepOp.getSourceElementType(); + address = gepOp.getBase(); + } + + auto structType = typeHint.dyn_cast(); + if (!structType) { + // TODO: Handle array types in the future. + return failure(); + } + + FailureOr> writtenToFields = + getWrittenToFields(dataLayout, structType, + /*storeSize=*/dataLayout.getTypeSize(sourceType), + /*storeOffset=*/offset); + if (failed(writtenToFields)) + return failure(); + + unsigned currentOffset = offset; + for (Type type : *writtenToFields) { + unsigned fieldSize = dataLayout.getTypeSize(type); + + // Extract the data out of the integer by first shifting right and then + // truncating it. + auto pos = rewriter.create( + store.getLoc(), + rewriter.getIntegerAttr(sourceType, (currentOffset - offset) * 8)); + + auto shrOp = rewriter.create(store.getLoc(), store.getValue(), pos); + + IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8); + Value valueToStore = + rewriter.create(store.getLoc(), fieldIntType, shrOp); + if (fieldIntType != type) { + // Bitcast to the right type. `fieldIntType` was explicitly created + // to be of the same size as `type` and must currently be a primitive as + // well. + valueToStore = + rewriter.create(store.getLoc(), type, valueToStore); + } + + // We create an `i8` indexed GEP here as that is the easiest (offset is + // already known). Other patterns turn this into a type-consistent GEP. + auto gepOp = rewriter.create(store.getLoc(), address.getType(), + rewriter.getI8Type(), address, + ArrayRef{currentOffset}); + rewriter.create(store.getLoc(), valueToStore, gepOp); + + // No need to care about padding here since we already checked previously + // that no padding exists in this range. + currentOffset += fieldSize; + } + + rewriter.eraseOp(store); + + return success(); +} + //===----------------------------------------------------------------------===// // Type consistency pass //===----------------------------------------------------------------------===// @@ -358,6 +518,7 @@ struct LLVMTypeConsistencyPass rewritePatterns.add>( &getContext()); rewritePatterns.add(&getContext()); + rewritePatterns.add(&getContext()); FrozenRewritePatternSet frozen(std::move(rewritePatterns)); if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen))) diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir index f8cfca9..5cef22a 100644 --- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir +++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir @@ -148,3 +148,155 @@ llvm.func @index_to_struct(%arg: i32) { llvm.store %arg, %7 : i32, !llvm.ptr llvm.return } + +// ----- + +// CHECK-LABEL: llvm.func @coalesced_store_ints +// CHECK-SAME: %[[ARG:.*]]: i64 +llvm.func @coalesced_store_ints(%arg: i64) { + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64 + // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64 + + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr + + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)> + // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]] + // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32 + // CHECK: llvm.store %[[TRUNC]], %[[GEP]] + // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64 + // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)> + // CHECK: llvm.store %[[TRUNC]], %[[GEP]] + llvm.store %arg, %1 : i64, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @coalesced_store_ints_offset +// CHECK-SAME: %[[ARG:.*]]: i64 +llvm.func @coalesced_store_ints_offset(%arg: i64) { + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64 + // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64 + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32)> : (i32) -> !llvm.ptr + %3 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32)> + + // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]] + // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32)> + // CHECK: llvm.store %[[TRUNC]], %[[GEP]] + // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64 + // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32)> + // CHECK: llvm.store %[[TRUNC]], %[[GEP]] + llvm.store %arg, %3 : i64, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @coalesced_store_floats +// CHECK-SAME: %[[ARG:.*]]: i64 +llvm.func @coalesced_store_floats(%arg: i64) { + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64 + // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64 + %0 = llvm.mlir.constant(1 : i32) : i32 + + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (f32, f32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (f32, f32)> : (i32) -> !llvm.ptr + + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)> + // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]] + // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32 + // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32 + // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]] + // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64 + // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32 + // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)> + // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]] + llvm.store %arg, %1 : i64, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.return +} + +// ----- + +// Padding test purposefully not modified. + +// CHECK-LABEL: llvm.func @coalesced_store_padding_inbetween +// CHECK-SAME: %[[ARG:.*]]: i64 +llvm.func @coalesced_store_padding_inbetween(%arg: i64) { + %0 = llvm.mlir.constant(1 : i32) : i32 + + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, i32)> : (i32) -> !llvm.ptr + // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.store %arg, %1 : i64, !llvm.ptr + llvm.return +} + +// ----- + +// Padding test purposefully not modified. + +// CHECK-LABEL: llvm.func @coalesced_store_padding_end +// CHECK-SAME: %[[ARG:.*]]: i64 +llvm.func @coalesced_store_padding_end(%arg: i64) { + %0 = llvm.mlir.constant(1 : i32) : i32 + + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i16)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i16)> : (i32) -> !llvm.ptr + // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.store %arg, %1 : i64, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @coalesced_store_past_end +// CHECK-SAME: %[[ARG:.*]]: i64 +llvm.func @coalesced_store_past_end(%arg: i64) { + %0 = llvm.mlir.constant(1 : i32) : i32 + + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32)> : (i32) -> !llvm.ptr + // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.store %arg, %1 : i64, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @coalesced_store_packed_struct +// CHECK-SAME: %[[ARG:.*]]: i64 +llvm.func @coalesced_store_packed_struct(%arg: i64) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64 + // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64 + + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32, i16)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i16, i32, i16)> : (i32) -> !llvm.ptr + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32, i16)> + // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]] + // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16 + // CHECK: llvm.store %[[TRUNC]], %[[GEP]] + // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]] + // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32, i16)> + // CHECK: llvm.store %[[TRUNC]], %[[GEP]] + // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST48]] + // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32, i16)> + // CHECK: llvm.store %[[TRUNC]], %[[GEP]] + llvm.store %arg, %1 : i64, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.return +} -- 2.7.4