[MLIR] Add std.assume_alignment op.

Reviewers: ftynse, nicolasvasilache, andydavis1

Subscribers: bixia, sanjoy.google, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74378
Commit: f581e655ec (parent 28728bf06f)
Author: Tim Shen
Date: 2020-02-10 19:44:42 -08:00
7 changed files with 119 additions and 0 deletions


@@ -870,4 +870,14 @@ def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg">,
let verifier = "return ::verify(*this);";
}
def LLVM_AssumeOp : LLVM_Op<"intr.assume", []>,
Arguments<(ins LLVM_Type:$cond)> {
let llvmBuilder = [{
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::Function *fn =
llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::assume, {});
builder.CreateCall(fn, {$cond});
}];
}
#endif // LLVMIR_OPS
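
For illustration, a minimal sketch of the new wrapper in LLVM dialect IR, assuming %a and %b are hypothetical i64 values already in scope:

  %cond = llvm.icmp "eq" %a, %b : !llvm.i64
  "llvm.intr.assume"(%cond) : (!llvm.i1) -> ()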


@@ -1639,4 +1639,21 @@ def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]>
}];
}
def AssumeAlignmentOp : Std_Op<"assume_alignment"> {
let summary =
"assertion that gives alignment information to the input memref";
let description = [{
The assume_alignment operation takes a memref and an integer alignment
value, and internally annotates the buffer with that alignment. If the
buffer is not aligned to the given alignment, the behavior is undefined.
This operation does not affect the semantics of a correct program. It is
intended for optimization only, and the optimization is best-effort.
}];
let arguments = (ins AnyMemRef:$memref, PositiveI32Attr:$alignment);
let results = (outs);
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
}
#endif // STANDARD_OPS
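
As a usage sketch, assuming a locally allocated buffer (the alloc is illustrative), the op asserts 16-byte alignment and produces no result:

  %buf = alloc() : memref<4x4xf16>
  assume_alignment %buf, 16 : memref<4x4xf16>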


@@ -2501,6 +2501,45 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
}
};
struct AssumeAlignmentOpLowering
: public LLVMLegalizationPattern<AssumeAlignmentOp> {
using LLVMLegalizationPattern<AssumeAlignmentOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<AssumeAlignmentOp> transformed(operands);
Value memref = transformed.memref();
unsigned alignment = cast<AssumeAlignmentOp>(op).alignment().getZExtValue();
MemRefDescriptor memRefDescriptor(memref);
Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
// Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Note that
// the asserted memref.alignedPtr is not used anywhere else, since real
// users such as load/store/view ops always re-extract memref.alignedPtr
// as they get lowered.
//
// This relies on LLVM's CSE optimization (potentially after SROA), since
// after CSE all memref.alignedPtr instances get de-duplicated into the
// same pointer SSA value.
Value zero =
createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), 0);
Value mask = createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(),
alignment - 1);
Value ptrValue =
rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), getIndexType(), ptr);
rewriter.create<LLVM::AssumeOp>(
op->getLoc(),
rewriter.create<LLVM::ICmpOp>(
op->getLoc(), LLVM::ICmpPredicate::eq,
rewriter.create<LLVM::AndOp>(op->getLoc(), ptrValue, mask), zero));
rewriter.eraseOp(op);
return matchSuccess();
}
};
} // namespace
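
For reference, a sketch of what AssumeAlignmentOpLowering emits for a 16-byte alignment, with illustrative value names (%ptr stands for the extracted aligned pointer):

  %zero = llvm.mlir.constant(0 : index) : !llvm.i64
  %mask = llvm.mlir.constant(15 : index) : !llvm.i64
  %addr = llvm.ptrtoint %ptr : !llvm<"half*"> to !llvm.i64
  %bits = llvm.and %addr, %mask : !llvm.i64
  %cond = llvm.icmp "eq" %bits, %zero : !llvm.i64
  "llvm.intr.assume"(%cond) : (!llvm.i1) -> ()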
static void ensureDistinctSuccessors(Block &bb) {
@@ -2612,6 +2651,7 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
bool useAlloca) {
// clang-format off
patterns.insert<
AssumeAlignmentOpLowering,
DimOpLowering,
LoadOpLowering,
MemRefCastOpLowering,


@@ -2764,6 +2764,17 @@ SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
return success();
}
//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(AssumeAlignmentOp op) {
unsigned alignment = op.alignment().getZExtValue();
if (!llvm::isPowerOf2_32(alignment))
return op.emitOpError("alignment must be a power of 2");
return success();
}
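
The power-of-2 requirement is what keeps the mask comparison in the lowering sound. As a worked example: for alignment 16 the lowering masks the address with 16 - 1 = 15 = 0b1111, which is zero exactly for 16-byte-aligned addresses; for a non-power-of-2 value such as 12, the mask would be 11 = 0b1011, and (addr & 11) == 0 does not express 12-byte alignment.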
namespace {
/// Pattern to rewrite a subview op with constant size arguments.


@@ -855,3 +855,18 @@ module {
// CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
// CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
}
// -----
// CHECK-LABEL: func @assume_alignment
func @assume_alignment(%0 : memref<4x4xf16>) {
// CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : !llvm.i64
// CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm<"half*"> to !llvm.i64
// CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK]] : !llvm.i64
// CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : !llvm.i64
// CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (!llvm.i1) -> ()
assume_alignment %0, 16 : memref<4x4xf16>
return
}


@@ -740,3 +740,11 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
tensor_store %1, %0 : memref<4x4xi32>
return
}
// CHECK-LABEL: func @assume_alignment
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
func @assume_alignment(%0: memref<4x4xf16>) {
// CHECK: assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
assume_alignment %0, 16 : memref<4x4xf16>
return
}


@@ -1036,3 +1036,21 @@ func @invalid_memref_cast() {
%2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
return
}
// -----
// alignment is not a power of 2.
func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{alignment must be a power of 2}}
std.assume_alignment %0, 12 : memref<4x4xf16>
return
}
// -----
// alignment value of 0.
func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{'std.assume_alignment' op attribute 'alignment' failed to satisfy constraint: positive 32-bit integer attribute}}
std.assume_alignment %0, 0 : memref<4x4xf16>
return
}