[mlir][llvm] Add memset support for mem2reg/sroa

This revision introduces support for memset intrinsics in SROA and
mem2reg for the LLVM dialect. SROA handles this by breaking memsets of
aggregates into multiple memsets of scalars, and mem2reg by promoting
memsets of single-integer slots into the value the memset operation
would have stored.

The SROA logic supports breaking up memsets of static size that operate
at the start of a memory slot. The most common intended case is a memset
covering an entire struct, most often as a way to zero-initialize it.
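
As a minimal sketch of the intended SROA rewrite (hypothetical IR written
in the style of the tests added below; function names are illustrative),
a memset zero-initializing a two-field struct slot is split so that each
used scalar subslot receives its own memset:

// Before SROA: one memset zero-initializes the whole struct slot.
llvm.func @before_sroa() -> i32 {
  %c1  = llvm.mlir.constant(1 : i32) : i32
  %val = llvm.mlir.constant(0 : i8) : i8
  %len = llvm.mlir.constant(8 : i32) : i32
  %slot = llvm.alloca %c1 x !llvm.struct<(i32, i32)> : (i32) -> !llvm.ptr
  "llvm.intr.memset"(%slot, %val, %len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
  %gep = llvm.getelementptr %slot[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, i32)>
  %res = llvm.load %gep : !llvm.ptr -> i32
  llvm.return %res : i32
}

// After SROA (sketch): only the used i32 subslot survives, covered by a 4-byte memset.
llvm.func @after_sroa() -> i32 {
  %c1   = llvm.mlir.constant(1 : i32) : i32
  %val  = llvm.mlir.constant(0 : i8) : i8
  %len4 = llvm.mlir.constant(4 : i32) : i32
  %slot = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
  "llvm.intr.memset"(%slot, %val, %len4) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
  %res = llvm.load %slot : !llvm.ptr -> i32
  llvm.return %res : i32
}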

The mem2reg logic supports promotable memsets with a dynamic value
operand, but requires a static length. Promotion is achieved by lowering
a memset into `ceil(log_2(n))` left-shift (`llvm.shl`) operations,
`ceil(log_2(n))` `llvm.or` operations, and up to one `llvm.zext`
operation (where n is the byte width of the integer), computing in
registers the integer value the memset would have stored. Only integers
whose bit width is a multiple of 8 are supported; more types could
easily be added afterwards.
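
As a worked illustration of this expansion (a sketch assuming an i32
slot, a byte value %v, and a static length of 4; the function name is
illustrative), the promoted value is rebuilt with ceil(log_2(4)) = 2
shift/or rounds plus one zero-extension:

// The value a 4-byte "llvm.intr.memset" of %v into an i32 slot would have stored.
llvm.func @memset_value_i32(%v: i8) -> i32 {
  %c8  = llvm.mlir.constant(8 : i32) : i32
  %c16 = llvm.mlir.constant(16 : i32) : i32
  %v8  = llvm.zext %v : i8 to i32   // byte 0 in the low bits
  %s8  = llvm.shl %v8, %c8 : i32
  %v16 = llvm.or %v8, %s8 : i32     // bytes 0-1 filled with %v
  %s16 = llvm.shl %v16, %c16 : i32
  %v32 = llvm.or %v16, %s16 : i32   // all 4 bytes filled with %v
  llvm.return %v32 : i32
}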

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D152367
Théo Degioanni
2023-06-14 08:43:10 +00:00
committed by Christian Ulmann
parent 1e41a29d73
commit 8404b23acd
7 changed files with 628 additions and 22 deletions

@@ -282,8 +282,11 @@ def LLVM_MemcpyInlineOp :
# setAliasAnalysisMetadataCode;
}
def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], [],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
[DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile);
// Append the alias attributes defined by LLVM_IntrOpBase.

@@ -104,16 +104,32 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
(ins "const ::mlir::MemorySlot &":$slot)
>,
InterfaceMethod<[{
Gets the value stored to the provided memory slot, or returns a null
value if this operation does not store to this slot. An operation
storing a value to a slot must always be able to provide the value it
stores. This method is only called on operations that use the slot.
Gets whether this operation stores to the specified slot.
No IR mutation is allowed in this method.
}],
"::mlir::Value", "getStored",
"bool", "storesTo",
(ins "const ::mlir::MemorySlot &":$slot)
>,
InterfaceMethod<[{
Gets the value stored to the provided memory slot, or returns a null
value if this operation does not store to this slot. An operation
storing a value to a slot must always be able to provide the value it
stores. This method is only called once per slot promotion, and only
on operations that store to the slot according to the `storesTo` method.
The returned value must dominate all operations dominated by the storing
operation.
If IR must be mutated to extract a concrete value being stored, mutation
must happen through the provided rewriter. The rewriter is located
immediately after the memory operation on call. No IR deletion is
allowed in this method. IR mutations must not introduce new uses of the
memory slot. Existing control flow must not be modified.
}],
"::mlir::Value", "getStored",
(ins "const ::mlir::MemorySlot &":$slot,
"::mlir::RewriterBase &":$rewriter)
>,
InterfaceMethod<[{
Checks that this operation can be promoted to no longer use the provided
blocking uses, in the context of promoting `slot`.

@@ -13,10 +13,13 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -119,12 +122,20 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
return getAddr() == slot.ptr;
}
Value LLVM::LoadOp::getStored(const MemorySlot &slot) { return {}; }
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
llvm_unreachable("getStored should not be called on LoadOp");
}
bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
Value LLVM::StoreOp::getStored(const MemorySlot &slot) {
return getAddr() == slot.ptr ? getValue() : Value();
bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
return getAddr() == slot.ptr;
}
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
return getValue();
}
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -310,7 +321,8 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
for (IntegerAttr index : llvm::drop_begin(indices)) {
// Ensure the structure of the type being indexed can be reasoned about.
// This includes rejecting any potential typed pointer.
auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
auto destructurable =
llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
if (!destructurable)
return {};
@@ -369,7 +381,8 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter) {
IntegerAttr firstLevelIndex = llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
IntegerAttr firstLevelIndex =
llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
const MemorySlot &newSlot = subslots.at(firstLevelIndex);
ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
@@ -400,6 +413,180 @@ DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
return DeletionKind::Keep;
}
//===----------------------------------------------------------------------===//
// Interfaces for memset
//===----------------------------------------------------------------------===//
/// Returns the length of the given memset in bytes if it can be known at
/// compile-time on a best-effort basis, nothing otherwise.
static std::optional<uint64_t> getStaticMemsetLen(LLVM::MemsetOp op) {
IntegerAttr::ValueType memsetLen;
if (!matchPattern(op.getLen(), m_ConstantInt(&memsetLen)))
return {};
if (memsetLen.getBitWidth() > 64)
return {};
return memsetLen.getZExtValue();
}
/// Returns whether one can be sure the memset does not write outside of the
/// bounds of the given slot, on a best-effort basis.
static bool definitelyWritesOnlyWithinSlot(LLVM::MemsetOp op,
const MemorySlot &slot,
DataLayout &dataLayout) {
if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
op.getDst() != slot.ptr)
return false;
std::optional<uint64_t> memsetLen = getStaticMemsetLen(op);
return memsetLen && *memsetLen <= dataLayout.getTypeSize(slot.elemType);
}
bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
return getDst() == slot.ptr;
}
Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
RewriterBase &rewriter) {
// TODO: Support non-integer types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](IntegerType intType) -> Value {
if (intType.getWidth() == 8)
return getVal();
assert(intType.getWidth() % 8 == 0);
// Build the memset integer by repeatedly shifting the value and or-ing
// it with the previous value.
uint64_t coveredBits = 8;
Value currentValue =
rewriter.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
while (coveredBits < intType.getWidth()) {
Value shiftBy =
rewriter.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
Value shifted =
rewriter.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
currentValue =
rewriter.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
coveredBits *= 2;
}
return currentValue;
})
.Default([](Type) -> Value {
llvm_unreachable(
"getStored should not be called on memset to unsupported type");
});
}
bool LLVM::MemsetOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses) {
// TODO: Support non-integer types.
bool canConvertType =
TypeSwitch<Type, bool>(slot.elemType)
.Case([](IntegerType intType) {
return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
})
.Default([](Type) { return false; });
if (!canConvertType)
return false;
if (getIsVolatile())
return false;
DataLayout layout = DataLayout::closest(*this);
return getStaticMemsetLen(*this) == layout.getTypeSize(slot.elemType);
}
DeletionKind LLVM::MemsetOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
return DeletionKind::Delete;
}
LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
DataLayout dataLayout = DataLayout::closest(*this);
return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
}
bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
if (&slot.elemType.getDialect() != getOperation()->getDialect())
return false;
if (getIsVolatile())
return false;
if (!slot.elemType.cast<DestructurableTypeInterface>()
.getSubelementIndexMap())
return false;
DataLayout dataLayout = DataLayout::closest(*this);
return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
}
DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter) {
std::optional<DenseMap<Attribute, Type>> types =
slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
IntegerAttr memsetLenAttr;
bool successfulMatch =
matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
(void)successfulMatch;
assert(successfulMatch);
bool packed = false;
if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
packed = structType.isPacked();
Type i32 = IntegerType::get(getContext(), 32);
DataLayout dataLayout = DataLayout::closest(*this);
uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
uint64_t covered = 0;
for (size_t i = 0; i < types->size(); i++) {
// Create indices on the fly to get elements in the right order.
Attribute index = IntegerAttr::get(i32, i);
Type elemType = types->at(index);
uint64_t typeSize = dataLayout.getTypeSize(elemType);
if (!packed)
covered =
llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));
if (covered >= memsetLen)
break;
// If this subslot is used, apply a new memset to it.
// Otherwise, only compute its offset within the original memset.
if (subslots.contains(index)) {
uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
Value newMemsetSizeValue =
rewriter
.create<LLVM::ConstantOp>(
getLen().getLoc(),
IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
.getResult();
rewriter.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
getVal(), newMemsetSizeValue,
getIsVolatile());
}
covered += typeSize;
}
return DeletionKind::Delete;
}
//===----------------------------------------------------------------------===//
// Interfaces for destructurable types
//===----------------------------------------------------------------------===//

@@ -23,6 +23,7 @@
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
@@ -160,7 +161,12 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
return getMemRef() == slot.ptr;
}
Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; }
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
Value memref::LoadOp::getStored(const MemorySlot &slot,
RewriterBase &rewriter) {
llvm_unreachable("getStored should not be called on LoadOp");
}
bool memref::LoadOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
@@ -222,9 +228,12 @@ DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
Value memref::StoreOp::getStored(const MemorySlot &slot) {
if (getMemRef() != slot.ptr)
return {};
bool memref::StoreOp::storesTo(const MemorySlot &slot) {
return getMemRef() == slot.ptr;
}
Value memref::StoreOp::getStored(const MemorySlot &slot,
RewriterBase &rewriter) {
return getValue();
}

@@ -172,12 +172,13 @@ private:
/// Computes the reaching definition for all the operations that require
/// promotion. `reachingDef` is the value the slot should contain at the
/// beginning of the block. This method returns the reached definition at the
/// end of the block.
/// end of the block. This method must only be called at most once per block.
Value computeReachingDefInBlock(Block *block, Value reachingDef);
/// Computes the reaching definition for all the operations that require
/// promotion. `reachingDef` corresponds to the initial value the
/// slot will contain before any write, typically a poison value.
/// This method must only be called at most once per region.
void computeReachingDefInRegion(Region *region, Value reachingDef);
/// Removes the blocking uses of the slot, in topological order.
@@ -326,7 +327,7 @@ SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
// If we store to the slot, further loads will see that value.
// Because we did not meet any load before, the value is not live-in.
if (memOp.getStored(slot))
if (memOp.storesTo(slot))
break;
}
}
@@ -365,7 +366,7 @@ void MemorySlotPromotionAnalyzer::computeMergePoints(
SmallPtrSet<Block *, 16> definingBlocks;
for (Operation *user : slot.ptr.getUsers())
if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
if (storeOp.getStored(slot))
if (storeOp.storesTo(slot))
definingBlocks.insert(user->getBlock());
idfCalculator.setDefiningBlocks(definingBlocks);
@@ -416,13 +417,21 @@ MemorySlotPromotionAnalyzer::computeInfo() {
Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
Value reachingDef) {
for (Operation &op : block->getOperations()) {
SmallVector<Operation *> blockOps;
for (Operation &op : block->getOperations())
blockOps.push_back(&op);
for (Operation *op : blockOps) {
if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
if (info.userToBlockingUses.contains(memOp))
reachingDefs.insert({memOp, reachingDef});
if (Value stored = memOp.getStored(slot))
if (memOp.storesTo(slot)) {
rewriter.setInsertionPointAfter(memOp);
Value stored = memOp.getStored(slot, rewriter);
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
}
}
}

@@ -0,0 +1,145 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @basic_memset
llvm.func @basic_memset() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
%memset_len = llvm.mlir.constant(4 : i32) : i32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
// CHECK-NOT: "llvm.intr.memset"
// CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
// CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
// CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
// CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
// CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
// CHECK-NOT: "llvm.intr.memset"
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: llvm.return %[[VALUE_32]] : i32
llvm.return %2 : i32
}
// -----
// CHECK-LABEL: llvm.func @allow_dynamic_value_memset
// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
llvm.func @allow_dynamic_value_memset(%memset_value: i8) -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%memset_len = llvm.mlir.constant(4 : i32) : i32
// CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
// CHECK-NOT: "llvm.intr.memset"
// CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
// CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
// CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
// CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
// CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
// CHECK-NOT: "llvm.intr.memset"
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: llvm.return %[[VALUE_32]] : i32
llvm.return %2 : i32
}
// -----
// CHECK-LABEL: llvm.func @exotic_target_memset
llvm.func @exotic_target_memset() -> i40 {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
%memset_len = llvm.mlir.constant(5 : i32) : i32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
// CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
// CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
// CHECK-NOT: "llvm.intr.memset"
// CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i40
// CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
// CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
// CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
// CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
// CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[VALUE_32]], %[[C32]]
// CHECK: %[[VALUE_COMPL:.*]] = llvm.or %[[VALUE_32]], %[[SHIFTED_COMPL]]
// CHECK-NOT: "llvm.intr.memset"
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
// CHECK: llvm.return %[[VALUE_COMPL]] : i40
llvm.return %2 : i40
}
// -----
// CHECK-LABEL: llvm.func @no_volatile_memset
llvm.func @no_volatile_memset() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
%memset_len = llvm.mlir.constant(4 : i32) : i32
// CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = true}>
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = true}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
llvm.return %2 : i32
}
// -----
// CHECK-LABEL: llvm.func @no_partial_memset
llvm.func @no_partial_memset() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
%memset_len = llvm.mlir.constant(2 : i32) : i32
// CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
llvm.return %2 : i32
}
// -----
// CHECK-LABEL: llvm.func @no_overflowing_memset
llvm.func @no_overflowing_memset() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(6 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
%memset_len = llvm.mlir.constant(6 : i32) : i32
// CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
llvm.return %2 : i32
}
// -----
// CHECK-LABEL: llvm.func @only_byte_aligned_integers_memset
llvm.func @only_byte_aligned_integers_memset() -> i10 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i10
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i10 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
%memset_len = llvm.mlir.constant(2 : i32) : i32
// CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i10
llvm.return %2 : i10
}

@@ -0,0 +1,237 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(sroa))" --split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @memset
llvm.func @memset() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// After SROA, only one i32 will be actually used, so only 4 bytes will be set.
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
// 16 bytes means it will span over the first 4 i32 entries
%memset_len = llvm.mlir.constant(16 : i32) : i32
// CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
%3 = llvm.load %2 : !llvm.ptr -> i32
llvm.return %3 : i32
}
// -----
// CHECK-LABEL: llvm.func @memset_partial
llvm.func @memset_partial() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// After SROA, only the second i32 will be actually used. As the memset writes up
// to half of it, only 2 bytes will be set.
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
// 6 bytes means it will span over the first i32 and half of the second i32.
%memset_len = llvm.mlir.constant(6 : i32) : i32
// CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
%3 = llvm.load %2 : !llvm.ptr -> i32
llvm.return %3 : i32
}
// -----
// CHECK-LABEL: llvm.func @memset_full
llvm.func @memset_full() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// After SROA, only one i32 will be actually used, so only 4 bytes will be set.
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
// 40 bytes means it will span over the entire array
%memset_len = llvm.mlir.constant(40 : i32) : i32
// CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
%3 = llvm.load %2 : !llvm.ptr -> i32
llvm.return %3 : i32
}
// -----
// CHECK-LABEL: llvm.func @memset_too_much
llvm.func @memset_too_much() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<10 x i32>
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(41 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
// 41 bytes means it will span over the entire array, and then some
%memset_len = llvm.mlir.constant(41 : i32) : i32
// CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
%3 = llvm.load %2 : !llvm.ptr -> i32
llvm.return %3 : i32
}
// -----
// CHECK-LABEL: llvm.func @memset_no_volatile
llvm.func @memset_no_volatile() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<10 x i32>
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(16 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
%memset_len = llvm.mlir.constant(16 : i32) : i32
// CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = true}>
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = true}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
%3 = llvm.load %2 : !llvm.ptr -> i32
llvm.return %3 : i32
}
// -----
// CHECK-LABEL: llvm.func @indirect_memset
llvm.func @indirect_memset() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
// This memset will only cover the selected element.
%memset_len = llvm.mlir.constant(4 : i32) : i32
%2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
// CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
"llvm.intr.memset"(%2, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%3 = llvm.load %2 : !llvm.ptr -> i32
llvm.return %3 : i32
}
// -----
// CHECK-LABEL: llvm.func @invalid_indirect_memset
llvm.func @invalid_indirect_memset() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.struct<"foo", (i32, i32)>
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(6 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
// This memset will go slightly beyond one of the elements.
%memset_len = llvm.mlir.constant(6 : i32) : i32
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0]
%2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
// CHECK: "llvm.intr.memset"(%[[GEP]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
"llvm.intr.memset"(%2, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%3 = llvm.load %2 : !llvm.ptr -> i32
llvm.return %3 : i32
}
// -----
// CHECK-LABEL: llvm.func @memset_double_use
llvm.func @memset_double_use() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
// CHECK-DAG: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// After SROA, only one i32 will be actually used, so only 4 bytes will be set.
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
// 8 bytes means it will span over the two i32 entries.
%memset_len = llvm.mlir.constant(8 : i32) : i32
// We expect two generated memsets, one for each field.
// CHECK-NOT: "llvm.intr.memset"
// CHECK-DAG: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
// CHECK-DAG: "llvm.intr.memset"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
// CHECK-NOT: "llvm.intr.memset"
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f32)>
%3 = llvm.load %2 : !llvm.ptr -> i32
%4 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f32)>
%5 = llvm.load %4 : !llvm.ptr -> f32
// We use this exotic bitcast to use the f32 easily. Semantics do not matter here.
%6 = llvm.bitcast %5 : f32 to i32
%7 = llvm.add %3, %6 : i32
llvm.return %7 : i32
}
// -----
// CHECK-LABEL: llvm.func @memset_considers_alignment
llvm.func @memset_considers_alignment() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// After SROA, only 32-bit values will be actually used, so only 4 bytes will be set.
// CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i8, i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
// 8 bytes means it will span over the i8 and the i32 entry.
// Because of padding, the f32 entry will not be touched.
%memset_len = llvm.mlir.constant(8 : i32) : i32
// Even though both 32-bit fields are used, only one memset should be generated,
// as the f32 field is not touched by the initial memset because of padding.
// CHECK-NOT: "llvm.intr.memset"
// CHECK: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
// CHECK-NOT: "llvm.intr.memset"
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i8, i32, f32)>
%3 = llvm.load %2 : !llvm.ptr -> i32
%4 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i8, i32, f32)>
%5 = llvm.load %4 : !llvm.ptr -> f32
// We use this exotic bitcast to use the f32 easily. Semantics do not matter here.
%6 = llvm.bitcast %5 : f32 to i32
%7 = llvm.add %3, %6 : i32
llvm.return %7 : i32
}
// -----
// CHECK-LABEL: llvm.func @memset_considers_packing
llvm.func @memset_considers_packing() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
// CHECK-DAG: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
// CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// After SROA, only 32-bit values will be actually used, so only 4 bytes will be set.
// CHECK-DAG: %[[MEMSET_LEN_WHOLE:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK-DAG: %[[MEMSET_LEN_PARTIAL:.*]] = llvm.mlir.constant(3 : i32) : i32
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i8, i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
%memset_value = llvm.mlir.constant(42 : i8) : i8
// 8 bytes means it will span over all the fields, because there is no padding as the struct is packed.
%memset_len = llvm.mlir.constant(8 : i32) : i32
// Now all fields are touched by the memset.
// CHECK-NOT: "llvm.intr.memset"
// CHECK: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN_WHOLE]]) <{isVolatile = false}>
// CHECK: "llvm.intr.memset"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN_PARTIAL]]) <{isVolatile = false}>
// CHECK-NOT: "llvm.intr.memset"
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i8, i32, f32)>
%3 = llvm.load %2 : !llvm.ptr -> i32
%4 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i8, i32, f32)>
%5 = llvm.load %4 : !llvm.ptr -> f32
// We use this exotic bitcast to use the f32 easily. Semantics do not matter here.
%6 = llvm.bitcast %5 : f32 to i32
%7 = llvm.add %3, %6 : i32
llvm.return %7 : i32
}