[mlir][Value] Add v.getDefiningOp<OpTy>()

Summary:
This makes a common pattern of
`dyn_cast_or_null<OpTy>(v.getDefiningOp())` more concise.
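
For example, with an illustrative value `v` of type `mlir::Value` (the names below are
for illustration only, not taken from the diff):

  // Before: two steps, using the _or_null form because a Value may be a block
  // argument and thus have no defining op.
  TransposeOp transposeInputOp =
      llvm::dyn_cast_or_null<TransposeOp>(v.getDefiningOp());

  // After: one call; still yields a null TransposeOp when `v` is a block
  // argument or is defined by some other kind of operation.
  TransposeOp transposeInputOp = v.getDefiningOp<TransposeOp>();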

Differential Revision: https://reviews.llvm.org/D79681
commit 98eead8186
parent 51e6fc44d0
Author: Sean Silva
Date:   2020-05-09 17:52:35 -07:00

26 changed files with 56 additions and 66 deletions

View File

@@ -91,8 +91,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
mlir::Value transposeInput = op.getOperand();
- TransposeOp transposeInputOp =
- llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+ TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)

View File

@@ -40,8 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
mlir::Value transposeInput = op.getOperand();
- TransposeOp transposeInputOp =
- llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+ TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)

View File

@@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
mlir::Value transposeInput = op.getOperand();
- TransposeOp transposeInputOp =
- llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+ TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)

View File

@@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
mlir::Value transposeInput = op.getOperand();
- TransposeOp transposeInputOp =
- llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+ TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)

View File

@@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
mlir::Value transposeInput = op.getOperand();
- TransposeOp transposeInputOp =
- llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+ TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)

View File

@@ -63,8 +63,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
mlir::Value transposeInput = op.getOperand();
- TransposeOp transposeInputOp =
- llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
+ TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)

View File

@@ -116,6 +116,13 @@ public:
/// defines it.
Operation *getDefiningOp() const;
+ /// If this value is the result of an operation of type OpTy, return the
+ /// operation that defines it.
+ template <typename OpTy>
+ OpTy getDefiningOp() const {
+ return llvm::dyn_cast_or_null<OpTy>(getDefiningOp());
+ }
+
/// If this value is the result of an operation, use it as a location,
/// otherwise return an unknown location.
Location getLoc() const;
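
A minimal usage sketch for the new templated accessor (the helper function below is
hypothetical, written only to illustrate the call):

  // Returns the constant index value if `v` is produced by a ConstantIndexOp,
  // and None otherwise (e.g. when `v` is a block argument).
  Optional<int64_t> getConstantIndex(Value v) {
    if (auto cst = v.getDefiningOp<ConstantIndexOp>())
      return cst.getValue();
    return llvm::None;
  }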

View File

@@ -453,7 +453,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
auto symbol = operands[i];
assert(isValidSymbol(symbol));
// Check if the symbol is a constant.
- if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(symbol.getDefiningOp()))
+ if (auto cOp = symbol.getDefiningOp<ConstantIndexOp>())
dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
cOp.getValue());
}

View File

@@ -665,7 +665,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
// Add top level symbol.
addSymbolId(getNumSymbolIds(), id);
// Check if the symbol is a constant.
- if (auto constOp = dyn_cast_or_null<ConstantIndexOp>(id.getDefiningOp()))
+ if (auto constOp = id.getDefiningOp<ConstantIndexOp>())
setIdToConstant(id, constOp.getValue());
}

View File

@@ -64,7 +64,7 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
assert(cst->containsId(value) && "value expected to be present");
if (isValidSymbol(value)) {
// Check if the symbol is a constant.
- if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(value.getDefiningOp()))
+ if (auto cOp = value.getDefiningOp<ConstantIndexOp>())
cst->setIdToConstant(value, cOp.getValue());
} else if (auto loop = getForInductionVarOwner(value)) {
if (failed(cst->addAffineForOpDomain(loop)))

View File

@@ -219,7 +219,7 @@ struct LoopToGpuConverter {
// Return true if the value is obviously a constant "one".
static bool isConstantOne(Value value) {
- if (auto def = dyn_cast_or_null<ConstantIndexOp>(value.getDefiningOp()))
+ if (auto def = value.getDefiningOp<ConstantIndexOp>())
return def.getValue() == 1;
return false;
}
@@ -505,11 +505,11 @@ struct ParallelToGpuLaunchLowering : public OpRewritePattern<ParallelOp> {
/// `upperBound`.
static Value deriveStaticUpperBound(Value upperBound,
PatternRewriter &rewriter) {
- if (auto op = dyn_cast_or_null<ConstantIndexOp>(upperBound.getDefiningOp())) {
+ if (auto op = upperBound.getDefiningOp<ConstantIndexOp>()) {
return op;
}
- if (auto minOp = dyn_cast_or_null<AffineMinOp>(upperBound.getDefiningOp())) {
+ if (auto minOp = upperBound.getDefiningOp<AffineMinOp>()) {
for (const AffineExpr &result : minOp.map().getResults()) {
if (auto constExpr = result.dyn_cast<AffineConstantExpr>()) {
return rewriter.create<ConstantIndexOp>(minOp.getLoc(),
@@ -518,7 +518,7 @@ static Value deriveStaticUpperBound(Value upperBound,
}
}
- if (auto multiplyOp = dyn_cast_or_null<MulIOp>(upperBound.getDefiningOp())) {
+ if (auto multiplyOp = upperBound.getDefiningOp<MulIOp>()) {
if (auto lhs = dyn_cast_or_null<ConstantIndexOp>(
deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter)
.getDefiningOp()))
@@ -607,7 +607,7 @@ static LogicalResult processParallelLoop(
launchIndependent](Value val) -> Value {
if (launchIndependent(val))
return val;
- if (ConstantOp constOp = dyn_cast_or_null<ConstantOp>(val.getDefiningOp()))
+ if (ConstantOp constOp = val.getDefiningOp<ConstantOp>())
return rewriter.create<ConstantOp>(constOp.getLoc(), constOp.getValue());
return {};
};

View File

@@ -110,7 +110,7 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
LogicalResult
LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
PatternRewriter &rewriter) const {
- auto subViewOp = dyn_cast_or_null<SubViewOp>(loadOp.memref().getDefiningOp());
+ auto subViewOp = loadOp.memref().getDefiningOp<SubViewOp>();
if (!subViewOp) {
return failure();
}
@@ -131,8 +131,7 @@ LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
LogicalResult
StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
PatternRewriter &rewriter) const {
- auto subViewOp =
- dyn_cast_or_null<SubViewOp>(storeOp.memref().getDefiningOp());
+ auto subViewOp = storeOp.memref().getDefiningOp<SubViewOp>();
if (!subViewOp) {
return failure();
}

View File

@@ -93,7 +93,7 @@ categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
unsigned &numSymbols) {
AffineExpr d;
Value resultVal = nullptr;
- if (auto constant = dyn_cast_or_null<ConstantIndexOp>(val.getDefiningOp())) {
+ if (auto constant = val.getDefiningOp<ConstantIndexOp>()) {
d = getAffineConstantExpr(constant.getValue(), context);
} else if (isValidSymbol(val) && !isValidDim(val)) {
d = getAffineSymbolExpr(numSymbols++, context);

View File

@@ -591,7 +591,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
// 2. Compose AffineApplyOps and dispatch dims or symbols.
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
auto t = operands[i];
- auto affineApply = dyn_cast_or_null<AffineApplyOp>(t.getDefiningOp());
+ auto affineApply = t.getDefiningOp<AffineApplyOp>();
if (affineApply) {
// a. Compose affine.apply operations.
LLVM_DEBUG(affineApply.getOperation()->print(
@@ -912,7 +912,7 @@ void AffineApplyOp::getCanonicalizationPatterns(
static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
- auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+ auto cast = operand.get().getDefiningOp<MemRefCastOp>();
if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;

View File

@@ -965,7 +965,7 @@ static Value vectorizeOperand(Value operand, Operation *op,
return nullptr;
}
// 3. vectorize constant.
- if (auto constant = dyn_cast_or_null<ConstantOp>(operand.getDefiningOp())) {
+ if (auto constant = operand.getDefiningOp<ConstantOp>()) {
return vectorizeConstant(
op, constant,
VectorType::get(state->strategy->vectorSizes, operand.getType()));

View File

@@ -425,9 +425,8 @@ static LogicalResult verify(LandingpadOp op) {
} else {
// catch - global addresses only.
// Bitcast ops should have global addresses as their args.
- if (auto bcOp = dyn_cast_or_null<BitcastOp>(value.getDefiningOp())) {
- if (auto addrOp =
- dyn_cast_or_null<AddressOfOp>(bcOp.arg().getDefiningOp()))
+ if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
+ if (auto addrOp = bcOp.arg().getDefiningOp<AddressOfOp>())
continue;
return op.emitError("constant clauses expected")
.attachNote(bcOp.getLoc())
@@ -435,9 +434,9 @@ static LogicalResult verify(LandingpadOp op) {
"bitcast used in clauses for landingpad";
}
// NullOp and AddressOfOp allowed
- if (dyn_cast_or_null<NullOp>(value.getDefiningOp()))
+ if (value.getDefiningOp<NullOp>())
continue;
- if (dyn_cast_or_null<AddressOfOp>(value.getDefiningOp()))
+ if (value.getDefiningOp<AddressOfOp>())
continue;
return op.emitError("clause #")
<< idx << " is not a known constant - null, addressof, bitcast";

View File

@@ -52,7 +52,7 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
- auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+ auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
if (castOp && canFoldIntoConsumerOp(castOp)) {
operand.set(castOp.getOperand());
folded = true;

View File

@@ -319,8 +319,8 @@ fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
// Must be a subview or a slice to guarantee there are loops we can fuse
// into.
- auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
- auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
+ auto subView = consumedView.getDefiningOp<SubViewOp>();
+ auto slice = consumedView.getDefiningOp<SliceOp>();
if (!subView && !slice) {
LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
continue;

View File

@@ -88,7 +88,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
/// Otherwise return size.
static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
Value size) {
- auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
+ auto affineMinOp = size.getDefiningOp<AffineMinOp>();
if (!affineMinOp)
return size;
int64_t minConst = std::numeric_limits<int64_t>::max();
@@ -112,7 +112,7 @@ static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
alignment_attr =
IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue());
if (!dynamicBuffers)
- if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
+ if (auto cst = size.getDefiningOp<ConstantIndexOp>())
return std_alloc(
MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)),
ValueRange{}, alignment_attr);

View File

@@ -287,7 +287,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
// accesses, unless we statically know the subview size divides the view
// size evenly.
int64_t viewSize = viewType.getDimSize(r);
- auto sizeCst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp());
+ auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
if (ShapedType::isDynamic(viewSize) || !sizeCst ||
(viewSize % sizeCst.getValue()) != 0) {
// Compute min(size, dim - offset) to avoid out-of-bounds accesses.

View File

@@ -36,7 +36,7 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context)
OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
// Matches x -> [scast -> scast] -> y, replacing the second scast with the
// value of x if the casts invert each other.
- auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg().getDefiningOp());
+ auto srcScastOp = arg().getDefiningOp<StorageCastOp>();
if (!srcScastOp || srcScastOp.arg().getType() != getType())
return OpFoldResult();
return srcScastOp.arg();

View File

@@ -55,7 +55,7 @@ void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
}
static LogicalResult verify(ForOp op) {
- if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp()))
+ if (auto cst = op.step().getDefiningOp<ConstantIndexOp>())
if (cst.getValue() <= 0)
return op.emitOpError("constant step operand must be positive");
@@ -403,7 +403,7 @@ static LogicalResult verify(ParallelOp op) {
// Check whether all constant step values are positive.
for (Value stepValue : stepValues)
- if (auto cst = dyn_cast_or_null<ConstantIndexOp>(stepValue.getDefiningOp()))
+ if (auto cst = stepValue.getDefiningOp<ConstantIndexOp>())
if (cst.getValue() <= 0)
return op.emitOpError("constant step operand must be positive");

View File

@@ -29,7 +29,7 @@ static void specializeLoopForUnrolling(ParallelOp op) {
SmallVector<int64_t, 2> constantIndices;
constantIndices.reserve(op.upperBound().size());
for (auto bound : op.upperBound()) {
- auto minOp = dyn_cast_or_null<AffineMinOp>(bound.getDefiningOp());
+ auto minOp = bound.getDefiningOp<AffineMinOp>();
if (!minOp)
return;
int64_t minConstant = std::numeric_limits<int64_t>::max();

View File

@@ -209,7 +209,7 @@ static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
- auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+ auto cast = operand.get().getDefiningOp<MemRefCastOp>();
if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;
@@ -1696,7 +1696,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) {
OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
// Fold IndexCast(IndexCast(x)) -> x
- auto cast = dyn_cast_or_null<IndexCastOp>(getOperand().getDefiningOp());
+ auto cast = getOperand().getDefiningOp<IndexCastOp>();
if (cast && cast.getOperand().getType() == getType())
return cast.getOperand();
@@ -2617,8 +2617,7 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
auto folds = [](Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
- auto castOp =
- dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+ auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
if (castOp && canFoldIntoConsumerOp(castOp)) {
operand.set(castOp.getOperand());
folded = true;
@@ -2890,12 +2889,11 @@ struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
LogicalResult matchAndRewrite(ViewOp viewOp,
PatternRewriter &rewriter) const override {
Value memrefOperand = viewOp.getOperand(0);
- MemRefCastOp memrefCastOp =
- dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp());
+ MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp<MemRefCastOp>();
if (!memrefCastOp)
return failure();
Value allocOperand = memrefCastOp.getOperand();
- AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
+ AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
if (!allocOp)
return failure();
rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,

View File

@@ -1611,7 +1611,7 @@ public:
// Return if the input of 'transposeOp' is not defined by another transpose.
TransposeOp parentTransposeOp =
- dyn_cast_or_null<TransposeOp>(transposeOp.vector().getDefiningOp());
+ transposeOp.vector().getDefiningOp<TransposeOp>();
if (!parentTransposeOp)
return failure();
@@ -1684,7 +1684,7 @@ OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) {
// into:
// %t = vector.tuple .., %e_i, .. // one less use
// %x = %e_i
- if (auto tupleOp = dyn_cast_or_null<TupleOp>(getOperand().getDefiningOp()))
+ if (auto tupleOp = getOperand().getDefiningOp<TupleOp>())
return tupleOp.getOperand(getIndex());
return {};
}

View File

@@ -193,12 +193,9 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
- auto lbCstOp =
- dyn_cast_or_null<ConstantIndexOp>(forOp.lowerBound().getDefiningOp());
- auto ubCstOp =
- dyn_cast_or_null<ConstantIndexOp>(forOp.upperBound().getDefiningOp());
- auto stepCstOp =
- dyn_cast_or_null<ConstantIndexOp>(forOp.step().getDefiningOp());
+ auto lbCstOp = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
+ auto ubCstOp = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
+ auto stepCstOp = forOp.step().getDefiningOp<ConstantIndexOp>();
if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.getValue() < 0 ||
ubCstOp.getValue() < 0 || stepCstOp.getValue() < 0)
return failure();
@@ -590,12 +587,9 @@ LogicalResult mlir::loopUnrollByFactor(scf::ForOp forOp,
Value stepUnrolled;
bool generateEpilogueLoop = true;
- auto lbCstOp =
- dyn_cast_or_null<ConstantIndexOp>(forOp.lowerBound().getDefiningOp());
- auto ubCstOp =
- dyn_cast_or_null<ConstantIndexOp>(forOp.upperBound().getDefiningOp());
- auto stepCstOp =
- dyn_cast_or_null<ConstantIndexOp>(forOp.step().getDefiningOp());
+ auto lbCstOp = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
+ auto ubCstOp = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
+ auto stepCstOp = forOp.step().getDefiningOp<ConstantIndexOp>();
if (lbCstOp && ubCstOp && stepCstOp) {
// Constant loop bounds computation.
int64_t lbCst = lbCstOp.getValue();
@@ -1313,12 +1307,11 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
bool isZeroBased = false;
- if (auto ubCst =
- dyn_cast_or_null<ConstantIndexOp>(lowerBound.getDefiningOp()))
+ if (auto ubCst = lowerBound.getDefiningOp<ConstantIndexOp>())
isZeroBased = ubCst.getValue() == 0;
bool isStepOne = false;
- if (auto stepCst = dyn_cast_or_null<ConstantIndexOp>(step.getDefiningOp()))
+ if (auto stepCst = step.getDefiningOp<ConstantIndexOp>())
isStepOne = stepCst.getValue() == 1;
// Compute the number of iterations the loop executes: ceildiv(ub - lb, step)