mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 01:07:04 +08:00
[MLIR] Add std.assume_alignment op.

Reviewers: ftynse, nicolasvasilache, andydavis1

Subscribers: bixia, sanjoy.google, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74378
This commit is contained in:
@@ -870,4 +870,14 @@ def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg">,
|
||||
let verifier = "return ::verify(*this);";
|
||||
}
|
||||
|
||||
// MLIR binding for the `llvm.assume` intrinsic: takes a single i1 operand
// `$cond` and, on translation to LLVM IR, emits a call to the intrinsic's
// declaration in the current module.
def LLVM_AssumeOp : LLVM_Op<"intr.assume", []>,
                    Arguments<(ins LLVM_Type:$cond)> {
  // Translation snippet run by the MLIR-to-LLVM-IR exporter.
  let llvmBuilder = [{
    llvm::Module *module = builder.GetInsertBlock()->getModule();
    // Get (or insert) the module-level declaration of llvm.assume.
    llvm::Function *fn =
        llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::assume, {});
    builder.CreateCall(fn, {$cond});
  }];
}
|
||||
|
||||
#endif // LLVMIR_OPS
|
||||
|
||||
@@ -1639,4 +1639,21 @@ def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]>
|
||||
}];
|
||||
}
|
||||
|
||||
// Standard-dialect op asserting that a memref's buffer has a given byte
// alignment. Produces no results; it only carries information for
// optimization (see the lowering to llvm.intr.assume elsewhere in this
// change).
def AssumeAlignmentOp : Std_Op<"assume_alignment"> {
  let summary =
      "assertion that gives alignment information to the input memref";
  let description = [{
    The assume alignment operation takes a memref and a integer of alignment
    value, and internally annotates the buffer with the given alignment. If
    the buffer isn't aligned to the given alignment, the behavior is undefined.

    This operation doesn't affect the semantics of a correct program. It's for
    optimization only, and the optimization is best-effort.
  }];
  // PositiveI32Attr rejects 0; the custom verifier additionally requires the
  // alignment to be a power of 2.
  let arguments = (ins AnyMemRef:$memref, PositiveI32Attr:$alignment);
  let results = (outs);

  // Printed as: assume_alignment %memref, <alignment> : memref<...>
  let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
}
|
||||
|
||||
#endif // STANDARD_OPS
|
||||
|
||||
@@ -2501,6 +2501,45 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Lowers std.assume_alignment to `llvm.intr.assume((ptr & (align-1)) == 0)`
/// on the memref descriptor's aligned pointer, then erases the op (it has no
/// results). The alignment is known to be a power of 2 (enforced by the op's
/// verifier), so `alignment - 1` is a valid low-bit mask.
struct AssumeAlignmentOpLowering
    : public LLVMLegalizationPattern<AssumeAlignmentOp> {
  using LLVMLegalizationPattern<AssumeAlignmentOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // `operands` holds the already-converted (LLVM-typed) values.
    OperandAdaptor<AssumeAlignmentOp> transformed(operands);
    Value memref = transformed.memref();
    unsigned alignment = cast<AssumeAlignmentOp>(op).alignment().getZExtValue();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the same
    // pointer SSA value.
    Value zero =
        createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), 0);
    Value mask = createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(),
                                         alignment - 1);
    // Pointer comparisons are done in the integer domain: ptrtoint, mask,
    // then compare against zero.
    Value ptrValue =
        rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), getIndexType(), ptr);
    rewriter.create<LLVM::AssumeOp>(
        op->getLoc(),
        rewriter.create<LLVM::ICmpOp>(
            op->getLoc(), LLVM::ICmpPredicate::eq,
            rewriter.create<LLVM::AndOp>(op->getLoc(), ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return matchSuccess();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
static void ensureDistinctSuccessors(Block &bb) {
|
||||
@@ -2612,6 +2651,7 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
|
||||
bool useAlloca) {
|
||||
// clang-format off
|
||||
patterns.insert<
|
||||
AssumeAlignmentOpLowering,
|
||||
DimOpLowering,
|
||||
LoadOpLowering,
|
||||
MemRefCastOpLowering,
|
||||
|
||||
@@ -2764,6 +2764,17 @@ SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AssumeAlignmentOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Verifies std.assume_alignment: the alignment attribute must be a power of
/// two. (Zero is already rejected by the PositiveI32Attr constraint on the
/// attribute itself, and isPowerOf2_32 rejects it as well.)
static LogicalResult verify(AssumeAlignmentOp op) {
  if (llvm::isPowerOf2_32(op.alignment().getZExtValue()))
    return success();
  return op.emitOpError("alignment must be power of 2");
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Pattern to rewrite a subview op with constant size arguments.
|
||||
|
||||
@@ -855,3 +855,18 @@ module {
|
||||
// CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
|
||||
// CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Checks the std-to-LLVM lowering of assume_alignment: alignment 16 becomes
// mask 15 applied to the descriptor's aligned pointer, fed to llvm.intr.assume.
// CHECK-LABEL: func @assume_alignment
func @assume_alignment(%0 : memref<4x4xf16>) {
  // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">
  // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
  // CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : !llvm.i64
  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm<"half*"> to !llvm.i64
  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : !llvm.i64
  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : !llvm.i64
  // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (!llvm.i1) -> ()
  assume_alignment %0, 16 : memref<4x4xf16>
  return
}
||||
|
||||
@@ -740,3 +740,11 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
|
||||
tensor_store %1, %0 : memref<4x4xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// Round-trip parse/print test for a valid std.assume_alignment op.
// CHECK-LABEL: func @assume_alignment
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
func @assume_alignment(%0: memref<4x4xf16>) {
  // CHECK: assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
  assume_alignment %0, 16 : memref<4x4xf16>
  return
}
|
||||
|
||||
@@ -1036,3 +1036,21 @@ func @invalid_memref_cast() {
|
||||
%2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: the custom verifier rejects a non-power-of-2 alignment.
// alignment is not power of 2.
func @assume_alignment(%0: memref<4x4xf16>) {
  // expected-error@+1 {{alignment must be power of 2}}
  std.assume_alignment %0, 12 : memref<4x4xf16>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: zero alignment is rejected by the PositiveI32Attr
// constraint before the custom verifier runs.
// 0 alignment value.
func @assume_alignment(%0: memref<4x4xf16>) {
  // expected-error@+1 {{'std.assume_alignment' op attribute 'alignment' failed to satisfy constraint: positive 32-bit integer attribute}}
  std.assume_alignment %0, 0 : memref<4x4xf16>
  return
}
|
||||
|
||||
Reference in New Issue
Block a user