From 98eead81868c1ba017cc5d8dbea11285d2eadc4c Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Sat, 9 May 2020 17:52:35 -0700 Subject: [PATCH] [mlir][Value] Add v.getDefiningOp() Summary: This makes a common pattern of `dyn_cast_or_null(v.getDefiningOp())` more concise. Differential Revision: https://reviews.llvm.org/D79681 --- mlir/docs/Tutorials/Toy/Ch-3.md | 3 +-- mlir/examples/toy/Ch3/mlir/ToyCombine.cpp | 3 +-- mlir/examples/toy/Ch4/mlir/ToyCombine.cpp | 3 +-- mlir/examples/toy/Ch5/mlir/ToyCombine.cpp | 3 +-- mlir/examples/toy/Ch6/mlir/ToyCombine.cpp | 3 +-- mlir/examples/toy/Ch7/mlir/ToyCombine.cpp | 3 +-- mlir/include/mlir/IR/Value.h | 7 ++++++ mlir/lib/Analysis/AffineAnalysis.cpp | 2 +- mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 2 +- mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp | 10 ++++---- .../LegalizeStandardForSPIRV.cpp | 5 ++-- mlir/lib/Dialect/Affine/EDSC/Builders.cpp | 2 +- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 4 ++-- .../Affine/Transforms/SuperVectorize.cpp | 2 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 9 ++++---- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 4 ++-- .../Dialect/Linalg/Transforms/Promotion.cpp | 4 ++-- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +- mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 2 +- mlir/lib/Dialect/SCF/SCF.cpp | 4 ++-- .../Transforms/ParallelLoopSpecialization.cpp | 2 +- mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 12 ++++------ mlir/lib/Dialect/Vector/VectorOps.cpp | 4 ++-- mlir/lib/Transforms/Utils/LoopUtils.cpp | 23 +++++++------------ 26 files changed, 56 insertions(+), 66 deletions(-) diff --git a/mlir/docs/Tutorials/Toy/Ch-3.md b/mlir/docs/Tutorials/Toy/Ch-3.md index cc31454fe533..d6a72b071647 100644 --- a/mlir/docs/Tutorials/Toy/Ch-3.md +++ b/mlir/docs/Tutorials/Toy/Ch-3.md @@ -91,8 +91,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); - TransposeOp transposeInputOp = - llvm::dyn_cast_or_null(transposeInput.getDefiningOp()); + TransposeOp transposeInputOp = transposeInput.getDefiningOp(); // Input defined by another transpose? If not, no match. if (!transposeInputOp) diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp index 8529ea0f24ee..6b789c8d27d1 100644 --- a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp @@ -40,8 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); - TransposeOp transposeInputOp = - llvm::dyn_cast_or_null(transposeInput.getDefiningOp()); + TransposeOp transposeInputOp = transposeInput.getDefiningOp(); // Input defined by another transpose? If not, no match. if (!transposeInputOp) diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index 0dd38b2c31a4..c979f2d5fae3 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); - TransposeOp transposeInputOp = - llvm::dyn_cast_or_null(transposeInput.getDefiningOp()); + TransposeOp transposeInputOp = transposeInput.getDefiningOp(); // Input defined by another transpose? If not, no match. if (!transposeInputOp) diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp index 0dd38b2c31a4..c979f2d5fae3 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); - TransposeOp transposeInputOp = - llvm::dyn_cast_or_null(transposeInput.getDefiningOp()); + TransposeOp transposeInputOp = transposeInput.getDefiningOp(); // Input defined by another transpose? If not, no match. if (!transposeInputOp) diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp index 0dd38b2c31a4..c979f2d5fae3 100644 --- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp @@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); - TransposeOp transposeInputOp = - llvm::dyn_cast_or_null(transposeInput.getDefiningOp()); + TransposeOp transposeInputOp = transposeInput.getDefiningOp(); // Input defined by another transpose? If not, no match. if (!transposeInputOp) diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp index fafc3876db27..d48b989578cf 100644 --- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -63,8 +63,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); - TransposeOp transposeInputOp = - llvm::dyn_cast_or_null(transposeInput.getDefiningOp()); + TransposeOp transposeInputOp = transposeInput.getDefiningOp(); // Input defined by another transpose? If not, no match. if (!transposeInputOp) diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 78517309468d..74f504c25156 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -116,6 +116,13 @@ public: /// defines it. Operation *getDefiningOp() const; + /// If this value is the result of an operation of type OpTy, return the + /// operation that defines it. + template + OpTy getDefiningOp() const { + return llvm::dyn_cast_or_null(getDefiningOp()); + } + /// If this value is the result of an operation, use it as a location, /// otherwise return an unknown location. Location getLoc() const; diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 185be49930b7..5a395937101f 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -453,7 +453,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, auto symbol = operands[i]; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. - if (auto cOp = dyn_cast_or_null(symbol.getDefiningOp())) + if (auto cOp = symbol.getDefiningOp()) dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), cOp.getValue()); } diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index b43cd6bd7be6..5c3f33d0a693 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -665,7 +665,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) { // Add top level symbol. addSymbolId(getNumSymbolIds(), id); // Check if the symbol is a constant. - if (auto constOp = dyn_cast_or_null(id.getDefiningOp())) + if (auto constOp = id.getDefiningOp()) setIdToConstant(id, constOp.getValue()); } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 5151569d8067..e6d7127762d5 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -64,7 +64,7 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { assert(cst->containsId(value) && "value expected to be present"); if (isValidSymbol(value)) { // Check if the symbol is a constant. - if (auto cOp = dyn_cast_or_null(value.getDefiningOp())) + if (auto cOp = value.getDefiningOp()) cst->setIdToConstant(value, cOp.getValue()); } else if (auto loop = getForInductionVarOwner(value)) { if (failed(cst->addAffineForOpDomain(loop))) diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp index b52c264d8bab..3821b4a2cf34 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp @@ -219,7 +219,7 @@ struct LoopToGpuConverter { // Return true if the value is obviously a constant "one". static bool isConstantOne(Value value) { - if (auto def = dyn_cast_or_null(value.getDefiningOp())) + if (auto def = value.getDefiningOp()) return def.getValue() == 1; return false; } @@ -505,11 +505,11 @@ struct ParallelToGpuLaunchLowering : public OpRewritePattern { /// `upperBound`. static Value deriveStaticUpperBound(Value upperBound, PatternRewriter &rewriter) { - if (auto op = dyn_cast_or_null(upperBound.getDefiningOp())) { + if (auto op = upperBound.getDefiningOp()) { return op; } - if (auto minOp = dyn_cast_or_null(upperBound.getDefiningOp())) { + if (auto minOp = upperBound.getDefiningOp()) { for (const AffineExpr &result : minOp.map().getResults()) { if (auto constExpr = result.dyn_cast()) { return rewriter.create(minOp.getLoc(), @@ -518,7 +518,7 @@ static Value deriveStaticUpperBound(Value upperBound, } } - if (auto multiplyOp = dyn_cast_or_null(upperBound.getDefiningOp())) { + if (auto multiplyOp = upperBound.getDefiningOp()) { if (auto lhs = dyn_cast_or_null( deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter) .getDefiningOp())) @@ -607,7 +607,7 @@ static LogicalResult processParallelLoop( launchIndependent](Value val) -> Value { if (launchIndependent(val)) return val; - if (ConstantOp constOp = dyn_cast_or_null(val.getDefiningOp())) + if (ConstantOp constOp = val.getDefiningOp()) return rewriter.create(constOp.getLoc(), constOp.getValue()); return {}; }; diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp index 6d9974233a9f..7ee82f9e18bf 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -110,7 +110,7 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter, LogicalResult LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, PatternRewriter &rewriter) const { - auto subViewOp = dyn_cast_or_null(loadOp.memref().getDefiningOp()); + auto subViewOp = loadOp.memref().getDefiningOp(); if (!subViewOp) { return failure(); } @@ -131,8 +131,7 @@ LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, LogicalResult StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, PatternRewriter &rewriter) const { - auto subViewOp = - dyn_cast_or_null(storeOp.memref().getDefiningOp()); + auto subViewOp = storeOp.memref().getDefiningOp(); if (!subViewOp) { return failure(); } diff --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp index 50e26574b7d5..98e6be955cba 100644 --- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp @@ -93,7 +93,7 @@ categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims, unsigned &numSymbols) { AffineExpr d; Value resultVal = nullptr; - if (auto constant = dyn_cast_or_null(val.getDefiningOp())) { + if (auto constant = val.getDefiningOp()) { d = getAffineConstantExpr(constant.getValue(), context); } else if (isValidSymbol(val) && !isValidDim(val)) { d = getAffineSymbolExpr(numSymbols++, context); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index c6d67723ecd1..16f4a3c6068e 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -591,7 +591,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, // 2. Compose AffineApplyOps and dispatch dims or symbols. for (unsigned i = 0, e = operands.size(); i < e; ++i) { auto t = operands[i]; - auto affineApply = dyn_cast_or_null(t.getDefiningOp()); + auto affineApply = t.getDefiningOp(); if (affineApply) { // a. Compose affine.apply operations. LLVM_DEBUG(affineApply.getOperation()->print( @@ -912,7 +912,7 @@ void AffineApplyOp::getCanonicalizationPatterns( static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { - auto cast = dyn_cast_or_null(operand.get().getDefiningOp()); + auto cast = operand.get().getDefiningOp(); if (cast && !cast.getOperand().getType().isa()) { operand.set(cast.getOperand()); folded = true; diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index 7b58d3a5ca0d..fe669624f6cb 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -965,7 +965,7 @@ static Value vectorizeOperand(Value operand, Operation *op, return nullptr; } // 3. vectorize constant. - if (auto constant = dyn_cast_or_null(operand.getDefiningOp())) { + if (auto constant = operand.getDefiningOp()) { return vectorizeConstant( op, constant, VectorType::get(state->strategy->vectorSizes, operand.getType())); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 83b65798ed9e..3a055d04b962 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -425,9 +425,8 @@ static LogicalResult verify(LandingpadOp op) { } else { // catch - global addresses only. // Bitcast ops should have global addresses as their args. - if (auto bcOp = dyn_cast_or_null(value.getDefiningOp())) { - if (auto addrOp = - dyn_cast_or_null(bcOp.arg().getDefiningOp())) + if (auto bcOp = value.getDefiningOp()) { + if (auto addrOp = bcOp.arg().getDefiningOp()) continue; return op.emitError("constant clauses expected") .attachNote(bcOp.getLoc()) @@ -435,9 +434,9 @@ static LogicalResult verify(LandingpadOp op) { "bitcast used in clauses for landingpad"; } // NullOp and AddressOfOp allowed - if (dyn_cast_or_null(value.getDefiningOp())) + if (value.getDefiningOp()) continue; - if (dyn_cast_or_null(value.getDefiningOp())) + if (value.getDefiningOp()) continue; return op.emitError("clause #") << idx << " is not a known constant - null, addressof, bitcast"; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 5803824a3162..fc2353e4087e 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -52,7 +52,7 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op); static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { - auto castOp = dyn_cast_or_null(operand.get().getDefiningOp()); + auto castOp = operand.get().getDefiningOp(); if (castOp && canFoldIntoConsumerOp(castOp)) { operand.set(castOp.getOperand()); folded = true; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index b85c586633cb..d541ed2a4f2d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -319,8 +319,8 @@ fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, // Must be a subview or a slice to guarantee there are loops we can fuse // into. - auto subView = dyn_cast_or_null(consumedView.getDefiningOp()); - auto slice = dyn_cast_or_null(consumedView.getDefiningOp()); + auto subView = consumedView.getDefiningOp(); + auto slice = consumedView.getDefiningOp(); if (!subView && !slice) { LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); continue; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index 33479717a645..03f8d9e3fd18 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -88,7 +88,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( /// Otherwise return size. static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc, Value size) { - auto affineMinOp = dyn_cast_or_null(size.getDefiningOp()); + auto affineMinOp = size.getDefiningOp(); if (!affineMinOp) return size; int64_t minConst = std::numeric_limits::max(); @@ -112,7 +112,7 @@ static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers, alignment_attr = IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue()); if (!dynamicBuffers) - if (auto cst = dyn_cast_or_null(size.getDefiningOp())) + if (auto cst = size.getDefiningOp()) return std_alloc( MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)), ValueRange{}, alignment_attr); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 1fdbcdcb94fa..462c2ef0c9ba 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -287,7 +287,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, // accesses, unless we statically know the subview size divides the view // size evenly. int64_t viewSize = viewType.getDimSize(r); - auto sizeCst = dyn_cast_or_null(size.getDefiningOp()); + auto sizeCst = size.getDefiningOp(); if (ShapedType::isDynamic(viewSize) || !sizeCst || (viewSize % sizeCst.getValue()) != 0) { // Compute min(size, dim - offset) to avoid out-of-bounds accesses. diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index f67a9f7fbc22..b0dc1fa10679 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -36,7 +36,7 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context) OpFoldResult StorageCastOp::fold(ArrayRef operands) { // Matches x -> [scast -> scast] -> y, replacing the second scast with the // value of x if the casts invert each other. - auto srcScastOp = dyn_cast_or_null(arg().getDefiningOp()); + auto srcScastOp = arg().getDefiningOp(); if (!srcScastOp || srcScastOp.arg().getType() != getType()) return OpFoldResult(); return srcScastOp.arg(); diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp index d93e1b835529..591179455c94 100644 --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -55,7 +55,7 @@ void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, } static LogicalResult verify(ForOp op) { - if (auto cst = dyn_cast_or_null(op.step().getDefiningOp())) + if (auto cst = op.step().getDefiningOp()) if (cst.getValue() <= 0) return op.emitOpError("constant step operand must be positive"); @@ -403,7 +403,7 @@ static LogicalResult verify(ParallelOp op) { // Check whether all constant step values are positive. for (Value stepValue : stepValues) - if (auto cst = dyn_cast_or_null(stepValue.getDefiningOp())) + if (auto cst = stepValue.getDefiningOp()) if (cst.getValue() <= 0) return op.emitOpError("constant step operand must be positive"); diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopSpecialization.cpp index 3c3140c052ee..94dba40a6436 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopSpecialization.cpp @@ -29,7 +29,7 @@ static void specializeLoopForUnrolling(ParallelOp op) { SmallVector constantIndices; constantIndices.reserve(op.upperBound().size()); for (auto bound : op.upperBound()) { - auto minOp = dyn_cast_or_null(bound.getDefiningOp()); + auto minOp = bound.getDefiningOp(); if (!minOp) return; int64_t minConstant = std::numeric_limits::max(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index cb7fd0e6e2ea..553be944ab30 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -209,7 +209,7 @@ static detail::op_matcher m_ConstantIndex() { static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { - auto cast = dyn_cast_or_null(operand.get().getDefiningOp()); + auto cast = operand.get().getDefiningOp(); if (cast && !cast.getOperand().getType().isa()) { operand.set(cast.getOperand()); folded = true; @@ -1696,7 +1696,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) { OpFoldResult IndexCastOp::fold(ArrayRef cstOperands) { // Fold IndexCast(IndexCast(x)) -> x - auto cast = dyn_cast_or_null(getOperand().getDefiningOp()); + auto cast = getOperand().getDefiningOp(); if (cast && cast.getOperand().getType() == getType()) return cast.getOperand(); @@ -2617,8 +2617,7 @@ OpFoldResult SubViewOp::fold(ArrayRef) { auto folds = [](Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { - auto castOp = - dyn_cast_or_null(operand.get().getDefiningOp()); + auto castOp = operand.get().getDefiningOp(); if (castOp && canFoldIntoConsumerOp(castOp)) { operand.set(castOp.getOperand()); folded = true; @@ -2890,12 +2889,11 @@ struct ViewOpMemrefCastFolder : public OpRewritePattern { LogicalResult matchAndRewrite(ViewOp viewOp, PatternRewriter &rewriter) const override { Value memrefOperand = viewOp.getOperand(0); - MemRefCastOp memrefCastOp = - dyn_cast_or_null(memrefOperand.getDefiningOp()); + MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp(); if (!memrefCastOp) return failure(); Value allocOperand = memrefCastOp.getOperand(); - AllocOp allocOp = dyn_cast_or_null(allocOperand.getDefiningOp()); + AllocOp allocOp = allocOperand.getDefiningOp(); if (!allocOp) return failure(); rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), allocOperand, diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index f00f2843bd18..96f8597baa34 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1611,7 +1611,7 @@ public: // Return if the input of 'transposeOp' is not defined by another transpose. TransposeOp parentTransposeOp = - dyn_cast_or_null(transposeOp.vector().getDefiningOp()); + transposeOp.vector().getDefiningOp(); if (!parentTransposeOp) return failure(); @@ -1684,7 +1684,7 @@ OpFoldResult TupleGetOp::fold(ArrayRef operands) { // into: // %t = vector.tuple .., %e_i, .. // one less use // %x = %e_i - if (auto tupleOp = dyn_cast_or_null(getOperand().getDefiningOp())) + if (auto tupleOp = getOperand().getDefiningOp()) return tupleOp.getOperand(getIndex()); return {}; } diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 209ed696c45a..0d1966fcaea5 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -193,12 +193,9 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { /// Promotes the loop body of a forOp to its containing block if the forOp /// it can be determined that the loop has a single iteration. LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) { - auto lbCstOp = - dyn_cast_or_null(forOp.lowerBound().getDefiningOp()); - auto ubCstOp = - dyn_cast_or_null(forOp.upperBound().getDefiningOp()); - auto stepCstOp = - dyn_cast_or_null(forOp.step().getDefiningOp()); + auto lbCstOp = forOp.lowerBound().getDefiningOp(); + auto ubCstOp = forOp.upperBound().getDefiningOp(); + auto stepCstOp = forOp.step().getDefiningOp(); if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.getValue() < 0 || ubCstOp.getValue() < 0 || stepCstOp.getValue() < 0) return failure(); @@ -590,12 +587,9 @@ LogicalResult mlir::loopUnrollByFactor(scf::ForOp forOp, Value stepUnrolled; bool generateEpilogueLoop = true; - auto lbCstOp = - dyn_cast_or_null(forOp.lowerBound().getDefiningOp()); - auto ubCstOp = - dyn_cast_or_null(forOp.upperBound().getDefiningOp()); - auto stepCstOp = - dyn_cast_or_null(forOp.step().getDefiningOp()); + auto lbCstOp = forOp.lowerBound().getDefiningOp(); + auto ubCstOp = forOp.upperBound().getDefiningOp(); + auto stepCstOp = forOp.step().getDefiningOp(); if (lbCstOp && ubCstOp && stepCstOp) { // Constant loop bounds computation. int64_t lbCst = lbCstOp.getValue(); @@ -1313,12 +1307,11 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder, // Check if the loop is already known to have a constant zero lower bound or // a constant one step. bool isZeroBased = false; - if (auto ubCst = - dyn_cast_or_null(lowerBound.getDefiningOp())) + if (auto ubCst = lowerBound.getDefiningOp()) isZeroBased = ubCst.getValue() == 0; bool isStepOne = false; - if (auto stepCst = dyn_cast_or_null(step.getDefiningOp())) + if (auto stepCst = step.getDefiningOp()) isStepOne = stepCst.getValue() == 1; // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)