# setAliasAnalysisMetadataCode;
}
-def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], [],
- /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
+def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
+ [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+ DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+ DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
+ /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile);
// Append the alias attributes defined by LLVM_IntrOpBase.
(ins "const ::mlir::MemorySlot &":$slot)
>,
InterfaceMethod<[{
+ Gets whether this operation stores to the specified slot.
+
+ No IR mutation is allowed in this method.
+ }],
+ "bool", "storesTo",
+ (ins "const ::mlir::MemorySlot &":$slot)
+ >,
+ InterfaceMethod<[{
Gets the value stored to the provided memory slot, or returns a null
value if this operation does not store to this slot. An operation
storing a value to a slot must always be able to provide the value it
- stores. This method is only called on operations that use the slot.
-
- No IR mutation is allowed in this method.
+ stores. This method is only called once per slot promotion, and only
+ on operations that store to the slot according to the `storesTo` method.
+ The returned value must dominate all operations dominated by the storing
+ operation.
+
+ If IR must be mutated to extract a concrete value being stored, mutation
+ must happen through the provided rewriter. The rewriter is located
+ immediately after the memory operation on call. No IR deletion is
+ allowed in this method. IR mutations must not introduce new uses of the
+ memory slot. Existing control flow must not be modified.
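+
+    For example, a promotable `llvm.store` can return its value operand
+    directly, while `llvm.intr.memset` may need to materialize the integer it
+    stores using the provided rewriter.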
}],
"::mlir::Value", "getStored",
- (ins "const ::mlir::MemorySlot &":$slot)
+ (ins "const ::mlir::MemorySlot &":$slot,
+ "::mlir::RewriterBase &":$rewriter)
>,
InterfaceMethod<[{
Checks that this operation can be promoted to no longer use the provided
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
return getAddr() == slot.ptr;
}
-Value LLVM::LoadOp::getStored(const MemorySlot &slot) { return {}; }
+bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
+
+Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
+ llvm_unreachable("getStored should not be called on LoadOp");
+}
bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
-Value LLVM::StoreOp::getStored(const MemorySlot &slot) {
- return getAddr() == slot.ptr ? getValue() : Value();
+bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
+ return getAddr() == slot.ptr;
+}
+
+Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
+ return getValue();
}
bool LLVM::LoadOp::canUsesBeRemoved(
for (IntegerAttr index : llvm::drop_begin(indices)) {
// Ensure the structure of the type being indexed can be reasoned about.
// This includes rejecting any potential typed pointer.
- auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
+ auto destructurable =
+ llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
if (!destructurable)
return {};
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter) {
- IntegerAttr firstLevelIndex = llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
+ IntegerAttr firstLevelIndex =
+ llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
const MemorySlot &newSlot = subslots.at(firstLevelIndex);
ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
}
//===----------------------------------------------------------------------===//
+// Interfaces for memset
+//===----------------------------------------------------------------------===//
+
+/// Returns the length of the given memset in bytes if it can be known at
+/// compile-time on a best-effort basis, or nothing otherwise.
+static std::optional<uint64_t> getStaticMemsetLen(LLVM::MemsetOp op) {
+ IntegerAttr::ValueType memsetLen;
+ if (!matchPattern(op.getLen(), m_ConstantInt(&memsetLen)))
+ return {};
+
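+  // Conservatively reject constants wider than 64 bits; their value may not
+  // fit in the uint64_t result, and such lengths are not worth promoting.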
+ if (memsetLen.getBitWidth() > 64)
+ return {};
+
+ return memsetLen.getZExtValue();
+}
+
+/// Returns whether one can be sure the memset does not write outside of the
+/// bounds of the given slot, on a best-effort basis.
+static bool definitelyWritesOnlyWithinSlot(LLVM::MemsetOp op,
+ const MemorySlot &slot,
+ DataLayout &dataLayout) {
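+  // The slot must be an LLVM pointer and the memset must write through it
+  // directly; writes through any other pointer cannot be proven in-bounds.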
+ if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
+ op.getDst() != slot.ptr)
+ return false;
+
+ std::optional<uint64_t> memsetLen = getStaticMemsetLen(op);
+ return memsetLen && *memsetLen <= dataLayout.getTypeSize(slot.elemType);
+}
+
+bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
+
+bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
+ return getDst() == slot.ptr;
+}
+
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
+ RewriterBase &rewriter) {
+ // TODO: Support non-integer types.
+ return TypeSwitch<Type, Value>(slot.elemType)
+ .Case([&](IntegerType intType) -> Value {
+ if (intType.getWidth() == 8)
+ return getVal();
+
+ assert(intType.getWidth() % 8 == 0);
+
+ // Build the memset integer by repeatedly shifting the value and or-ing
+ // it with the previous value.
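+      // For example, setting an i32 slot to 0x2A doubles the covered bits on
+      // each iteration:
+      //   zext          -> 0x0000002A  (8 bits covered)
+      //   or with << 8  -> 0x00002A2A  (16 bits covered)
+      //   or with << 16 -> 0x2A2A2A2A  (32 bits covered)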
+ uint64_t coveredBits = 8;
+ Value currentValue =
+ rewriter.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
+ while (coveredBits < intType.getWidth()) {
+ Value shiftBy =
+ rewriter.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
+ Value shifted =
+ rewriter.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
+ currentValue =
+ rewriter.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
+ coveredBits *= 2;
+ }
+
+ return currentValue;
+ })
+ .Default([](Type) -> Value {
+ llvm_unreachable(
+ "getStored should not be called on memset to unsupported type");
+ });
+}
+
+bool LLVM::MemsetOp::canUsesBeRemoved(
+ const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+ SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ // TODO: Support non-integer types.
+ bool canConvertType =
+ TypeSwitch<Type, bool>(slot.elemType)
+ .Case([](IntegerType intType) {
+ return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
+ })
+ .Default([](Type) { return false; });
+ if (!canConvertType)
+ return false;
+
+ if (getIsVolatile())
+ return false;
+
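+  // Only a memset covering the whole slot can be promoted; for example, an
+  // i32 slot requires a constant length of exactly 4 bytes.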
+ DataLayout layout = DataLayout::closest(*this);
+ return getStaticMemsetLen(*this) == layout.getTypeSize(slot.elemType);
+}
+
+DeletionKind LLVM::MemsetOp::removeBlockingUses(
+ const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+ RewriterBase &rewriter, Value reachingDefinition) {
+ return DeletionKind::Delete;
+}
+
+LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ DataLayout dataLayout = DataLayout::closest(*this);
+ return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
+}
+
+bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
+ SmallPtrSetImpl<Attribute> &usedIndices,
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ if (&slot.elemType.getDialect() != getOperation()->getDialect())
+ return false;
+
+ if (getIsVolatile())
+ return false;
+
+ if (!slot.elemType.cast<DestructurableTypeInterface>()
+ .getSubelementIndexMap())
+ return false;
+
+ DataLayout dataLayout = DataLayout::closest(*this);
+ return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
+}
+
+DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ RewriterBase &rewriter) {
+ std::optional<DenseMap<Attribute, Type>> types =
+ slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
+
+ IntegerAttr memsetLenAttr;
+ bool successfulMatch =
+ matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
+ (void)successfulMatch;
+ assert(successfulMatch);
+
+ bool packed = false;
+ if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
+ packed = structType.isPacked();
+
+ Type i32 = IntegerType::get(getContext(), 32);
+ DataLayout dataLayout = DataLayout::closest(*this);
+ uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
+ uint64_t covered = 0;
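+  // Visit the subelement types in index order, tracking each field's byte
+  // offset (`covered`) within the memset region, including ABI padding when
+  // the struct is not packed.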
+ for (size_t i = 0; i < types->size(); i++) {
+ // Create indices on the fly to get elements in the right order.
+ Attribute index = IntegerAttr::get(i32, i);
+ Type elemType = types->at(index);
+ uint64_t typeSize = dataLayout.getTypeSize(elemType);
+
+ if (!packed)
+ covered =
+ llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));
+
+ if (covered >= memsetLen)
+ break;
+
+ // If this subslot is used, apply a new memset to it.
+ // Otherwise, only compute its offset within the original memset.
+ if (subslots.contains(index)) {
+ uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
+
+ Value newMemsetSizeValue =
+ rewriter
+ .create<LLVM::ConstantOp>(
+ getLen().getLoc(),
+ IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
+ .getResult();
+
+ rewriter.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
+ getVal(), newMemsetSizeValue,
+ getIsVolatile());
+ }
+
+ covered += typeSize;
+ }
+
+ return DeletionKind::Delete;
+}
+
+//===----------------------------------------------------------------------===//
// Interfaces for destructurable types
//===----------------------------------------------------------------------===//
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
return getMemRef() == slot.ptr;
}
-Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; }
+bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
+
+Value memref::LoadOp::getStored(const MemorySlot &slot,
+ RewriterBase &rewriter) {
+ llvm_unreachable("getStored should not be called on LoadOp");
+}
bool memref::LoadOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
-Value memref::StoreOp::getStored(const MemorySlot &slot) {
- if (getMemRef() != slot.ptr)
- return {};
+bool memref::StoreOp::storesTo(const MemorySlot &slot) {
+ return getMemRef() == slot.ptr;
+}
+
+Value memref::StoreOp::getStored(const MemorySlot &slot,
+ RewriterBase &rewriter) {
return getValue();
}
/// Computes the reaching definition for all the operations that require
/// promotion. `reachingDef` is the value the slot should contain at the
/// beginning of the block. This method returns the reached definition at the
- /// end of the block.
+  /// end of the block. This method must be called at most once per block.
Value computeReachingDefInBlock(Block *block, Value reachingDef);
/// Computes the reaching definition for all the operations that require
/// promotion. `reachingDef` corresponds to the initial value the
/// slot will contain before any write, typically a poison value.
+  /// This method must be called at most once per region.
void computeReachingDefInRegion(Region *region, Value reachingDef);
/// Removes the blocking uses of the slot, in topological order.
// If we store to the slot, further loads will see that value.
// Because we did not meet any load before, the value is not live-in.
- if (memOp.getStored(slot))
+ if (memOp.storesTo(slot))
break;
}
}
SmallPtrSet<Block *, 16> definingBlocks;
for (Operation *user : slot.ptr.getUsers())
if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
- if (storeOp.getStored(slot))
+ if (storeOp.storesTo(slot))
definingBlocks.insert(user->getBlock());
idfCalculator.setDefiningBlocks(definingBlocks);
Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
Value reachingDef) {
- for (Operation &op : block->getOperations()) {
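+  // Iterate over a copy of the block's operations: `getStored` may insert new
+  // operations through the rewriter, which would invalidate iteration over
+  // the live operation list.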
+ SmallVector<Operation *> blockOps;
+ for (Operation &op : block->getOperations())
+ blockOps.push_back(&op);
+ for (Operation *op : blockOps) {
if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
if (info.userToBlockingUses.contains(memOp))
reachingDefs.insert({memOp, reachingDef});
- if (Value stored = memOp.getStored(slot))
+ if (memOp.storesTo(slot)) {
+ rewriter.setInsertionPointAfter(memOp);
+ Value stored = memOp.getStored(slot, rewriter);
+ assert(stored && "a memory operation storing to a slot must provide a "
+ "new definition of the slot");
reachingDef = stored;
+ }
}
}
--- /dev/null
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @basic_memset
+llvm.func @basic_memset() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ %memset_len = llvm.mlir.constant(4 : i32) : i32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ // CHECK-NOT: "llvm.intr.memset"
+ // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+ // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+ // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+ // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+ // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+ // CHECK-NOT: "llvm.intr.memset"
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[VALUE_32]] : i32
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @allow_dynamic_value_memset
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @allow_dynamic_value_memset(%memset_value: i8) -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_len = llvm.mlir.constant(4 : i32) : i32
+ // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ // CHECK-NOT: "llvm.intr.memset"
+ // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+ // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+ // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+ // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+ // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+ // CHECK-NOT: "llvm.intr.memset"
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[VALUE_32]] : i32
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @exotic_target_memset
+llvm.func @exotic_target_memset() -> i40 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ %memset_len = llvm.mlir.constant(5 : i32) : i32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
+ // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
+ // CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ // CHECK-NOT: "llvm.intr.memset"
+ // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i40
+ // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+ // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+ // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+ // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+ // CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[VALUE_32]], %[[C32]]
+ // CHECK: %[[VALUE_COMPL:.*]] = llvm.or %[[VALUE_32]], %[[SHIFTED_COMPL]]
+ // CHECK-NOT: "llvm.intr.memset"
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
+ // CHECK: llvm.return %[[VALUE_COMPL]] : i40
+ llvm.return %2 : i40
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_volatile_memset
+llvm.func @no_volatile_memset() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ %memset_len = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = true}>
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = true}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_partial_memset
+llvm.func @no_partial_memset() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ %memset_len = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_overflowing_memset
+llvm.func @no_overflowing_memset() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(6 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ %memset_len = llvm.mlir.constant(6 : i32) : i32
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @only_byte_aligned_integers_memset
+llvm.func @only_byte_aligned_integers_memset() -> i10 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i10
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x i10 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ %memset_len = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i10
+ llvm.return %2 : i10
+}
--- /dev/null
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(sroa))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @memset
+llvm.func @memset() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only one i32 will actually be used, so only 4 bytes will be set.
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ // 16 bytes means it will span over the first 4 i32 entries
+ %memset_len = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_partial
+llvm.func @memset_partial() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only the second i32 will actually be used. As the memset writes up
+ // to half of it, only 2 bytes will be set.
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ // 6 bytes means it will span over the first i32 and half of the second i32.
+ %memset_len = llvm.mlir.constant(6 : i32) : i32
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_full
+llvm.func @memset_full() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only one i32 will actually be used, so only 4 bytes will be set.
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ // 40 bytes means it will span over the entire array
+ %memset_len = llvm.mlir.constant(40 : i32) : i32
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_too_much
+llvm.func @memset_too_much() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<10 x i32>
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(41 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ // 41 bytes means it will span over the entire array, and then some
+ %memset_len = llvm.mlir.constant(41 : i32) : i32
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_no_volatile
+llvm.func @memset_no_volatile() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<10 x i32>
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(16 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ %memset_len = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = true}>
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = true}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @indirect_memset
+llvm.func @indirect_memset() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ // This memset will only cover the selected element.
+ %memset_len = llvm.mlir.constant(4 : i32) : i32
+ %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ "llvm.intr.memset"(%2, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @invalid_indirect_memset
+llvm.func @invalid_indirect_memset() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.struct<"foo", (i32, i32)>
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(6 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ // This memset will go slightly beyond one of the elements.
+ %memset_len = llvm.mlir.constant(6 : i32) : i32
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0]
+ %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+ // CHECK: "llvm.intr.memset"(%[[GEP]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ "llvm.intr.memset"(%2, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_double_use
+llvm.func @memset_double_use() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+ // CHECK-DAG: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only 32-bit values will actually be used, so each memset will set 4 bytes.
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 8 bytes means it will span over both the i32 and the f32 entry.
+ %memset_len = llvm.mlir.constant(8 : i32) : i32
+  // We expect two generated memsets, one for each field.
+ // CHECK-NOT: "llvm.intr.memset"
+ // CHECK-DAG: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ // CHECK-DAG: "llvm.intr.memset"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ // CHECK-NOT: "llvm.intr.memset"
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f32)>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ %4 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f32)>
+ %5 = llvm.load %4 : !llvm.ptr -> f32
+  // We use this exotic bitcast to consume the f32 easily. Semantics do not matter here.
+ %6 = llvm.bitcast %5 : f32 to i32
+ %7 = llvm.add %3, %6 : i32
+ llvm.return %7 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_considers_alignment
+llvm.func @memset_considers_alignment() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only 32-bit values will actually be used, so only 4 bytes will be set.
+ // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i8, i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ // 8 bytes means it will span over the i8 and the i32 entry.
+ // Because of padding, the f32 entry will not be touched.
+ %memset_len = llvm.mlir.constant(8 : i32) : i32
+  // Even though both the i32 and the f32 are used, only one memset should be
+  // generated, as the f32 is not touched by the initial memset.
+ // CHECK-NOT: "llvm.intr.memset"
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+ // CHECK-NOT: "llvm.intr.memset"
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i8, i32, f32)>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ %4 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i8, i32, f32)>
+ %5 = llvm.load %4 : !llvm.ptr -> f32
+  // We use this exotic bitcast to consume the f32 easily. Semantics do not matter here.
+ %6 = llvm.bitcast %5 : f32 to i32
+ %7 = llvm.add %3, %6 : i32
+ llvm.return %7 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_considers_packing
+llvm.func @memset_considers_packing() -> i32 {
+ // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+ // CHECK-DAG: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
+ // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only 32-bit values will actually be used. The i32 is fully
+  // covered by the memset, while the f32 is only partially covered.
+ // CHECK-DAG: %[[MEMSET_LEN_WHOLE:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK-DAG: %[[MEMSET_LEN_PARTIAL:.*]] = llvm.mlir.constant(3 : i32) : i32
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i8, i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 8 bytes means it will span over all the fields, because the packed struct has no padding.
+ %memset_len = llvm.mlir.constant(8 : i32) : i32
+ // Now all fields are touched by the memset.
+ // CHECK-NOT: "llvm.intr.memset"
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN_WHOLE]]) <{isVolatile = false}>
+ // CHECK: "llvm.intr.memset"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN_PARTIAL]]) <{isVolatile = false}>
+ // CHECK-NOT: "llvm.intr.memset"
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i8, i32, f32)>
+ %3 = llvm.load %2 : !llvm.ptr -> i32
+ %4 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i8, i32, f32)>
+ %5 = llvm.load %4 : !llvm.ptr -> f32
+  // We use this exotic bitcast to consume the f32 easily. Semantics do not matter here.
+ %6 = llvm.bitcast %5 : f32 to i32
+ %7 = llvm.add %3, %6 : i32
+ llvm.return %7 : i32
+}