diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 1c2208062668..2ea1c6cad9a0 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -136,7 +136,7 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { // Get or create the declaration of the printf function in the module. - Function printfFunc = getPrintf(op->getFunction().getModule()); + Function printfFunc = getPrintf(op->getParentOfType()); auto print = cast(op); auto loc = print.getLoc(); diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 4f5ca5e495a5..927726696393 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -99,10 +99,6 @@ public: /// nullptr if this is a top-level block. Operation *getContainingOp(); - /// Returns the function that this block is part of, even if the block is - /// nested under an operation region. - Function getFunction(); - /// Insert this block (which must not already be in a function) right before /// the specified block. void insertBefore(Block *block); diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 70cd9b41ebb8..6913b7638d74 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -71,6 +71,11 @@ public: /// Return the operation that this refers to. Operation *getOperation() { return state; } + /// Return the closes surrounding parent operation that is of type 'OpTy'. + template OpTy getParentOfType() { + return getOperation()->getParentOfType(); + } + /// Return the context this operation belongs to. MLIRContext *getContext() { return getOperation()->getContext(); } diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index b31dbda34f13..6e17ef063f86 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -125,10 +125,14 @@ public: /// or nullptr if this is a top-level operation. Operation *getParentOp(); - /// Returns the function that this operation is part of. - /// The function is determined by traversing the chain of parent operations. - /// Returns nullptr if the operation is unlinked. - Function getFunction(); + /// Return the closest surrounding parent operation that is of type 'OpTy'. + template OpTy getParentOfType() { + auto *op = this; + while ((op = op->getParentOp())) + if (auto parentOp = llvm::dyn_cast(op)) + return parentOp; + return OpTy(); + } /// Replace any uses of 'from' with 'to' within this operation. void replaceUsesOfWith(Value *from, Value *to); diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 37aebbdc496b..b5dbd539eb06 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -72,9 +72,6 @@ public: IRObjectWithUseList::replaceAllUsesWith(newValue); } - /// Return the function that this Value is defined in. - Function getFunction(); - /// If this value is the result of an operation, return the operation that /// defines it. Operation *getDefiningOp(); @@ -128,17 +125,11 @@ public: return const_cast(value)->getKind() == Kind::BlockArgument; } - /// Return the function that this argument is defined in. - Function getFunction(); - Block *getOwner() { return owner; } /// Returns the number of this argument. unsigned getArgNumber(); - /// Returns if the current argument is a function argument. - bool isFunctionArgument(); - private: friend class Block; // For access to private constructor. BlockArgument(Type type, Block *owner) diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index d98904346e56..d11d525dce60 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -307,7 +307,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) { if (inserted) { reorderedDims.push_back(v); } - return getAffineDimExpr(iterPos->second, v->getFunction().getContext()) + return getAffineDimExpr(iterPos->second, v->getContext()) .cast(); } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index f915c4878269..2a52706c2770 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -442,13 +442,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Insert the `malloc` declaration if it is not already present. Function mallocFunc = - op->getFunction().getModule().getNamedFunction("malloc"); + op->getParentOfType().getModule().getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(getIndexType(), getVoidPtrType()); mallocFunc = Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); - op->getFunction().getModule().push_back(mallocFunc); + op->getParentOfType().getModule().push_back(mallocFunc); } // Allocate the underlying buffer and store a pointer to it in the MemRef @@ -503,11 +503,12 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { OperandAdaptor transformed(operands); // Insert the `free` declaration if it is not already present. - Function freeFunc = op->getFunction().getModule().getNamedFunction("free"); + Function freeFunc = + op->getParentOfType().getModule().getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); - op->getFunction().getModule().push_back(freeFunc); + op->getParentOfType().getModule().push_back(freeFunc); } auto type = transformed.memref()->getType().cast(); diff --git a/mlir/lib/EDSC/CoreAPIs.cpp b/mlir/lib/EDSC/CoreAPIs.cpp index 8a94dad8ae65..578b86736582 100644 --- a/mlir/lib/EDSC/CoreAPIs.cpp +++ b/mlir/lib/EDSC/CoreAPIs.cpp @@ -98,6 +98,6 @@ mlir_attr_t makeBoolAttr(mlir_context_t context, bool value) { } unsigned getFunctionArity(mlir_func_t function) { - auto *f = reinterpret_cast(function); - return f->getNumArguments(); + auto f = mlir::Function::getFromOpaquePointer(function); + return f.getNumArguments(); } diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp index 6cf57b42f457..92034c5d2886 100644 --- a/mlir/lib/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/GPU/IR/GPUDialect.cpp @@ -426,7 +426,7 @@ LogicalResult LaunchFuncOp::verify() { return emitOpError("attribute 'kernel' must be a function"); } - auto module = getOperation()->getFunction().getModule(); + auto module = getParentOfType(); Function kernelFunc = module.getNamedFunction(kernel()); if (!kernelFunc) return emitError() << "kernel function '" << kernelAttr << "' is undefined"; diff --git a/mlir/lib/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/GPU/Transforms/KernelOutlining.cpp index 6cb920a10b3c..4f110ac286ae 100644 --- a/mlir/lib/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/GPU/Transforms/KernelOutlining.cpp @@ -64,7 +64,7 @@ static Function outlineKernelFunc(gpu::LaunchOp launchOp) { FunctionType type = FunctionType::get(kernelOperandTypes, {}, launchOp.getContext()); std::string kernelFuncName = - Twine(launchOp.getOperation()->getFunction().getName(), "_kernel").str(); + Twine(launchOp.getParentOfType().getName(), "_kernel").str(); Function outlinedFunc = Function::create(loc, kernelFuncName, type); outlinedFunc.getBody().takeBody(launchOp.getBody()); Builder builder(launchOp.getContext()); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 065462273b43..bd6137b41b06 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1421,8 +1421,8 @@ void OperationPrinter::print(Block *block, bool printBlockArgs, os << ':'; // Print out some context information about the predecessors of this block. - if (!block->getFunction()) { - os << "\t// block is not in a function!"; + if (!block->getParent()) { + os << "\t// block is not in a region!"; } else if (block->hasNoPredecessors()) { os << "\t// no predecessors"; } else if (auto *pred = block->getSinglePredecessor()) { diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index e17b13d98fe4..93f5fe6e9767 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -50,13 +50,8 @@ Operation *Block::getContainingOp() { return getParent() ? getParent()->getContainingOp() : nullptr; } -Function Block::getFunction() { - auto *parent = getParent(); - return parent ? parent->getParentOfType() : nullptr; -} - -/// Insert this block (which must not already be in a function) right before -/// the specified block. +/// Insert this block (which must not already be in a region) right before the +/// specified block. void Block::insertBefore(Block *block) { assert(!getParent() && "already inserted into a block!"); assert(block->getParent() && "cannot insert before a block without a parent"); @@ -254,11 +249,11 @@ void Block::walk(Block::iterator begin, Block::iterator end, /// invalidated. Block *Block::splitBlock(iterator splitBefore) { // Start by creating a new basic block, and insert it immediate after this - // one in the containing function. + // one in the containing region. auto newBB = new Block(); getParent()->getBlocks().insert(std::next(Region::iterator(this)), newBB); - // Move all of the operations from the split point to the end of the function + // Move all of the operations from the split point to the end of the region // into the new block. newBB->getOperations().splice(newBB->end(), getOperations(), splitBefore, end()); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 40b759fcfd50..ba9e3cf17b9e 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -281,10 +281,6 @@ Operation *Operation::getParentOp() { return block ? block->getContainingOp() : nullptr; } -Function Operation::getFunction() { - return block ? block->getFunction() : nullptr; -} - /// Replace any uses of 'from' with 'to' within this operation. void Operation::replaceUsesOfWith(Value *from, Value *to) { if (from == to) diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 65a98f7ee59e..669f641b7341 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -29,21 +29,9 @@ Operation *Value::getDefiningOp() { return nullptr; } -/// Return the function that this Value is defined in. -Function Value::getFunction() { - switch (getKind()) { - case Value::Kind::BlockArgument: - return cast(this)->getFunction(); - case Value::Kind::OpResult: - return getDefiningOp()->getFunction(); - } - llvm_unreachable("Unknown Value Kind"); -} - Location Value::getLoc() { - if (auto *op = getDefiningOp()) { + if (auto *op = getDefiningOp()) return op->getLoc(); - } return UnknownLoc::get(getContext()); } @@ -78,20 +66,3 @@ void IRObjectWithUseList::dropAllUses() { use_begin()->drop(); } } - -//===----------------------------------------------------------------------===// -// BlockArgument implementation. -//===----------------------------------------------------------------------===// - -/// Return the function that this argument is defined in. -Function BlockArgument::getFunction() { - if (auto *owner = getOwner()) - return owner->getFunction(); - return nullptr; -} - -/// Returns if the current argument is a function argument. -bool BlockArgument::isFunctionArgument() { - auto containingFn = getFunction(); - return containingFn && &containingFn.front() == getOwner(); -} diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index f42f2860d5d7..a3d89c3c42bb 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -170,7 +170,7 @@ public: LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); // Insert the `malloc` declaration if it is not already present. - auto module = op->getFunction().getModule(); + auto module = op->getParentOfType(); Function mallocFunc = module.getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy); @@ -231,7 +231,7 @@ public: auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); // Insert the `free` declaration if it is not already present. - auto module = op->getFunction().getModule(); + auto module = op->getParentOfType(); Function freeFunc = module.getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(voidPtrTy, {}); @@ -602,7 +602,7 @@ static Function getLLVMLibraryCallDeclaration(Operation *op, PatternRewriter &rewriter) { assert(isa(op)); auto fnName = LinalgOp::getLibraryCallName(); - auto module = op->getFunction().getModule(); + auto module = op->getParentOfType(); if (auto f = module.getNamedFunction(fnName)) { return f; } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 3be51e67186d..63a01e254cdf 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -431,8 +431,7 @@ static LogicalResult verify(CallOp op) { auto fnAttr = op.getAttrOfType("callee"); if (!fnAttr) return op.emitOpError("requires a 'callee' function attribute"); - auto fn = op.getOperation()->getFunction().getModule().getNamedFunction( - fnAttr.getValue()); + auto fn = op.getParentOfType().getNamedFunction(fnAttr.getValue()); if (!fn) return op.emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; @@ -1098,8 +1097,8 @@ static LogicalResult verify(ConstantOp &op) { return op.emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. - auto fn = op.getOperation()->getFunction().getModule().getNamedFunction( - fnAttr.getValue()); + auto fn = + op.getParentOfType().getNamedFunction(fnAttr.getValue()); if (!fn) return op.emitOpError("reference to undefined function 'bar'"); @@ -2029,7 +2028,9 @@ static void print(OpAsmPrinter *p, ReturnOp op) { } static LogicalResult verify(ReturnOp op) { - auto function = op.getOperation()->getFunction(); + // TODO(b/137008268): Return op should verify that it is nested directly + // within a function operation. + auto function = op.getParentOfType(); // The operand number and types must match the function signature. const auto &results = function.getType().getResults(); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index e867dc70ed34..830546db497d 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -217,8 +217,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED emitRemarkForBlock(Block &block) { - auto *op = block.getContainingOp(); - return op ? op->emitRemark() : block.getFunction().emitRemark(); + return block.getContainingOp()->emitRemark(); } /// Creates a buffer in the faster memory space for the specified region; @@ -250,7 +249,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, OpBuilder &b = region.isWrite() ? epilogue : prologue; // Builder to create constants at the top level. - auto func = block->getFunction(); + auto func = block->getParent()->getParentOfType(); OpBuilder top(func.getBody()); auto loc = region.loc; @@ -765,10 +764,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { if (totalDmaBuffersSizeInBytes > fastMemCapacityBytes) { StringRef str = "Total size of all DMA buffers' for this block " "exceeds fast memory capacity\n"; - if (auto *op = block->getContainingOp()) - op->emitError(str); - else - block->getFunction().emitError(str); + block->getContainingOp()->emitError(str); } return totalDmaBuffersSizeInBytes; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 1eee40b88da7..b2557a6c6fd3 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -859,7 +859,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // Create builder to insert alloc op just before 'forOp'. OpBuilder b(forInst); // Builder to create constants at the top level. - OpBuilder top(forInst->getFunction().getBody()); + OpBuilder top(forInst->getParentOfType().getBody()); // Create new memref type based on slice bounds. auto *oldMemRef = cast(srcStoreOpInst).getMemRef(); auto oldMemRefType = oldMemRef->getType().cast(); @@ -1750,7 +1750,7 @@ public: }; // Search for siblings which load the same memref function argument. - auto fn = dstNode->op->getFunction(); + auto fn = dstNode->op->getParentOfType(); for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) { for (auto *user : fn.getArgument(i)->getUsers()) { if (auto loadOp = dyn_cast(user)) { diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index f59f1006ec51..fcac60c6a921 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -635,8 +635,8 @@ static bool emitSlice(MaterializationState *state, } } - LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); - LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs())); + LLVM_DEBUG(dbgs() << "\nFunction is now\n"); + LLVM_DEBUG((*slice)[0]->getParentOfType().print(dbgs())); // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 4ddf93c2232e..65847fc8bee8 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -125,7 +125,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { Operation *op = forOp.getOperation(); if (!iv->use_empty()) { if (forOp.hasConstantLowerBound()) { - OpBuilder topBuilder(op->getFunction().getBody()); + OpBuilder topBuilder(op->getParentOfType().getBody()); auto constOp = topBuilder.create( forOp.getLoc(), forOp.getConstantLowerBound()); iv->replaceAllUsesWith(constOp); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 16f4effca15d..c1a4dcb7ebbb 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -81,11 +81,12 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, std::unique_ptr domInfo; std::unique_ptr postDomInfo; if (domInstFilter) - domInfo = llvm::make_unique(domInstFilter->getFunction()); + domInfo = llvm::make_unique( + domInstFilter->getParentOfType()); if (postDomInstFilter) - postDomInfo = - llvm::make_unique(postDomInstFilter->getFunction()); + postDomInfo = llvm::make_unique( + postDomInstFilter->getParentOfType()); // The ops where memref replacement succeeds are replaced with new ones. SmallVector opsToErase;