mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[MLIR] Add type checking capability to RegionBranchOpInterface
- Add function `verifyTypes` that Op's can call to do type checking verification along the control flow edges described the Op's RegionBranchOpInterface. - We cannot rely on the verify methods on the OpInterface because the interface functions assume valid Ops, so they may crash if invoked on unverified Ops. (For example, scf.for getSuccessorRegions() calls getRegionIterArgs(), which dereferences getBody() block. If the scf.for is invalid with no body, this can lead to a segfault). `verifyTypes` can be called post op-verification to avoid this. Differential Revision: https://reviews.llvm.org/D82829
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -24,8 +25,9 @@ using namespace mlir;
|
||||
/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
|
||||
/// successor if 'operandIndex' is within the range of 'operands', or None if
|
||||
/// `operandIndex` isn't a successor operand index.
|
||||
Optional<BlockArgument> mlir::detail::getBranchSuccessorArgument(
|
||||
Optional<OperandRange> operands, unsigned operandIndex, Block *successor) {
|
||||
Optional<BlockArgument>
|
||||
detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
|
||||
unsigned operandIndex, Block *successor) {
|
||||
// Check that the operands are valid.
|
||||
if (!operands || operands->empty())
|
||||
return llvm::None;
|
||||
@@ -43,8 +45,8 @@ Optional<BlockArgument> mlir::detail::getBranchSuccessorArgument(
|
||||
|
||||
/// Verify that the given operands match those of the given successor block.
|
||||
LogicalResult
|
||||
mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
|
||||
Optional<OperandRange> operands) {
|
||||
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
|
||||
Optional<OperandRange> operands) {
|
||||
if (!operands)
|
||||
return success();
|
||||
|
||||
@@ -66,3 +68,139 @@ mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RegionBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Verify that types match along all region control flow edges originating from
|
||||
/// `sourceNo` (region # if source is a region, llvm::None if source is parent
|
||||
/// op). `getInputsTypesForRegion` is a function that returns the types of the
|
||||
/// inputs that flow from `sourceIndex' to the given region.
|
||||
static LogicalResult verifyTypesAlongAllEdges(
|
||||
Operation *op, Optional<unsigned> sourceNo,
|
||||
function_ref<TypeRange(Optional<unsigned>)> getInputsTypesForRegion) {
|
||||
auto regionInterface = cast<RegionBranchOpInterface>(op);
|
||||
|
||||
SmallVector<RegionSuccessor, 2> successors;
|
||||
unsigned numInputs;
|
||||
if (sourceNo) {
|
||||
Region &srcRegion = op->getRegion(sourceNo.getValue());
|
||||
numInputs = srcRegion.getNumArguments();
|
||||
} else {
|
||||
numInputs = op->getNumOperands();
|
||||
}
|
||||
SmallVector<Attribute, 2> operands(numInputs, nullptr);
|
||||
regionInterface.getSuccessorRegions(sourceNo, operands, successors);
|
||||
|
||||
for (RegionSuccessor &succ : successors) {
|
||||
Optional<unsigned> succRegionNo;
|
||||
if (!succ.isParent())
|
||||
succRegionNo = succ.getSuccessor()->getRegionNumber();
|
||||
|
||||
auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
|
||||
diag << "from ";
|
||||
if (sourceNo)
|
||||
diag << "Region #" << sourceNo.getValue();
|
||||
else
|
||||
diag << op->getName();
|
||||
|
||||
diag << " to ";
|
||||
if (succRegionNo)
|
||||
diag << "Region #" << succRegionNo.getValue();
|
||||
else
|
||||
diag << op->getName();
|
||||
return diag;
|
||||
};
|
||||
|
||||
TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo);
|
||||
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
|
||||
if (sourceTypes.size() != succInputsTypes.size()) {
|
||||
InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
|
||||
return printEdgeName(diag)
|
||||
<< " has " << sourceTypes.size()
|
||||
<< " source operands, but target successor needs "
|
||||
<< succInputsTypes.size();
|
||||
}
|
||||
|
||||
for (auto typesIdx :
|
||||
llvm::enumerate(llvm::zip(sourceTypes, succInputsTypes))) {
|
||||
Type sourceType = std::get<0>(typesIdx.value());
|
||||
Type inputType = std::get<1>(typesIdx.value());
|
||||
if (sourceType != inputType) {
|
||||
InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
|
||||
return printEdgeName(diag)
|
||||
<< " source #" << typesIdx.index() << " type " << sourceType
|
||||
<< " should match input #" << typesIdx.index() << " type "
|
||||
<< inputType;
|
||||
}
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verify that types match along control flow edges described the given op.
|
||||
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
|
||||
auto regionInterface = cast<RegionBranchOpInterface>(op);
|
||||
|
||||
auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
|
||||
if (regionNo.hasValue()) {
|
||||
return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
|
||||
.getTypes();
|
||||
}
|
||||
|
||||
// If the successor of a parent op is the parent itself
|
||||
// RegionBranchOpInterface does not have an API to query what the entry
|
||||
// operands will be in that case. Vend out the result types of the op in
|
||||
// that case so that type checking succeeds for this case.
|
||||
return op->getResultTypes();
|
||||
};
|
||||
|
||||
// Verify types along control flow edges originating from the parent.
|
||||
if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
|
||||
return failure();
|
||||
|
||||
// RegionBranchOpInterface should not be implemented by Ops that do not have
|
||||
// attached regions.
|
||||
assert(op->getNumRegions() != 0);
|
||||
|
||||
// Verify types along control flow edges originating from each region.
|
||||
for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
|
||||
Region ®ion = op->getRegion(regionNo);
|
||||
|
||||
// Since the interface cannnot distinguish between different ReturnLike
|
||||
// ops within the region branching to different successors, all ReturnLike
|
||||
// ops in this region should have the same operand types. We will then use
|
||||
// one of them as the representative for type matching.
|
||||
|
||||
Operation *regionReturn = nullptr;
|
||||
for (Block &block : region) {
|
||||
Operation *terminator = block.getTerminator();
|
||||
if (!terminator->hasTrait<OpTrait::ReturnLike>())
|
||||
continue;
|
||||
|
||||
if (!regionReturn) {
|
||||
regionReturn = terminator;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Found more than one ReturnLike terminator. Make sure the operand types
|
||||
// match with the first one.
|
||||
if (regionReturn->getOperandTypes() != terminator->getOperandTypes())
|
||||
return op->emitOpError("Region #")
|
||||
<< regionNo
|
||||
<< " operands mismatch between return-like terminators";
|
||||
}
|
||||
|
||||
auto inputTypesFromRegion = [&](Optional<unsigned> regionNo) -> TypeRange {
|
||||
// All successors get the same set of operands.
|
||||
return regionReturn ? TypeRange(regionReturn->getOperands().getTypes())
|
||||
: TypeRange();
|
||||
};
|
||||
|
||||
if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user