mirror of
https://github.com/intel/llvm.git
synced 2026-01-19 01:15:50 +08:00
[mlir] Add a new SymbolUserOpInterface class
The initial goal of this interface is to fix the current problems with verifying symbol user operations, but can extend beyond that in the future. The current problems with the verification of symbol uses are: * Extremely inefficient: Most current symbol users perform the symbol lookup using the slow O(N) string compare methods, which can lead to extremely long verification times in large modules. * Invalid/break the constraints of verification pass If the symbol reference is not-flat(and even if it is flat in some cases) a verifier for an operation is not permitted to touch the referenced operation because it may be in the process of being mutated by a different thread within the pass manager. The new SymbolUserOpInterface exposes a method `verifySymbolUses` that will be invoked from the parent symbol table to allow for verifying the constraints of any referenced symbols. This method is passed a `SymbolTableCollection` to allow for O(1) lookups of any necessary symbol operation. Differential Revision: https://reviews.llvm.org/D89512
This commit is contained in:
@@ -231,4 +231,12 @@ format of the header for each interface section goes as follows:
|
||||
|
||||
##### SymbolInterfaces
|
||||
|
||||
* `SymbolOpInterface` - Used to represent [`Symbol`](SymbolsAndSymbolTables.md#symbol) operations which reside immediately within a region that defines a [`SymbolTable`](SymbolsAndSymbolTables.md#symbol-table).
|
||||
* `SymbolOpInterface` - Used to represent
|
||||
[`Symbol`](SymbolsAndSymbolTables.md#symbol) operations which reside
|
||||
immediately within a region that defines a
|
||||
[`SymbolTable`](SymbolsAndSymbolTables.md#symbol-table).
|
||||
|
||||
* `SymbolUserOpInterface` - Used to represent operations that reference
|
||||
[`Symbol`](SymbolsAndSymbolTables.md#symbol) operations. This provides the
|
||||
ability to perform safe and efficient verification of symbol uses, as well
|
||||
as additional functionality.
|
||||
|
||||
@@ -142,6 +142,10 @@ See the `LangRef` definition of the
|
||||
[`SymbolRefAttr`](LangRef.md#symbol-reference-attribute) for more information
|
||||
about the structure of this attribute.
|
||||
|
||||
Operations that reference a `Symbol` and want to perform verification and
|
||||
general mutation of the symbol should implement the `SymbolUserOpInterface` to
|
||||
ensure that symbol accesses are legal and efficient.
|
||||
|
||||
### Manipulating a Symbol
|
||||
|
||||
As described above, `SymbolRefs` act as an auxiliary way of defining uses of
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
@@ -733,7 +734,9 @@ def BranchOp : Std_Op<"br",
|
||||
// CallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
|
||||
def CallOp : Std_Op<"call",
|
||||
[CallOpInterface, MemRefsNormalizable,
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
|
||||
let summary = "call operation";
|
||||
let description = [{
|
||||
The `call` operation represents a direct call to a function that is within
|
||||
@@ -788,6 +791,7 @@ def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
|
||||
let assemblyFormat = [{
|
||||
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
|
||||
}];
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -158,6 +158,27 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SymbolUserOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
|
||||
let description = [{
|
||||
This interface describes an operation that may use a `Symbol`. This
|
||||
interface allows for users of symbols to hook into verification and other
|
||||
symbol related utilities that are either costly or otherwise disallowed
|
||||
within a traditional operation.
|
||||
}];
|
||||
let cppNamespace = "::mlir";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<"Verify the symbol uses held by this operation.",
|
||||
"LogicalResult", "verifySymbolUses",
|
||||
(ins "::mlir::SymbolTableCollection &":$symbolTable)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Symbol Traits
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -236,6 +236,21 @@ public:
|
||||
LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
|
||||
SmallVectorImpl<Operation *> &symbols);
|
||||
|
||||
/// Returns the operation registered with the given symbol name within the
|
||||
/// closest parent operation of, or including, 'from' with the
|
||||
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
|
||||
/// found.
|
||||
Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
|
||||
Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
|
||||
template <typename T>
|
||||
T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
|
||||
return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
|
||||
}
|
||||
template <typename T>
|
||||
T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
|
||||
return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
|
||||
}
|
||||
|
||||
/// Lookup, or create, a symbol table for an operation.
|
||||
SymbolTable &getSymbolTable(Operation *op);
|
||||
|
||||
|
||||
@@ -740,34 +740,33 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
|
||||
// CallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(CallOp op) {
|
||||
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
// Check that the callee attribute was specified.
|
||||
auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
|
||||
auto fnAttr = getAttrOfType<FlatSymbolRefAttr>("callee");
|
||||
if (!fnAttr)
|
||||
return op.emitOpError("requires a 'callee' symbol reference attribute");
|
||||
auto fn =
|
||||
op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
|
||||
return emitOpError("requires a 'callee' symbol reference attribute");
|
||||
FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
|
||||
if (!fn)
|
||||
return op.emitOpError() << "'" << fnAttr.getValue()
|
||||
<< "' does not reference a valid function";
|
||||
return emitOpError() << "'" << fnAttr.getValue()
|
||||
<< "' does not reference a valid function";
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
auto fnType = fn.getType();
|
||||
if (fnType.getNumInputs() != op.getNumOperands())
|
||||
return op.emitOpError("incorrect number of operands for callee");
|
||||
if (fnType.getNumInputs() != getNumOperands())
|
||||
return emitOpError("incorrect number of operands for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
|
||||
if (op.getOperand(i).getType() != fnType.getInput(i))
|
||||
return op.emitOpError("operand type mismatch: expected operand type ")
|
||||
if (getOperand(i).getType() != fnType.getInput(i))
|
||||
return emitOpError("operand type mismatch: expected operand type ")
|
||||
<< fnType.getInput(i) << ", but provided "
|
||||
<< op.getOperand(i).getType() << " for operand number " << i;
|
||||
<< getOperand(i).getType() << " for operand number " << i;
|
||||
|
||||
if (fnType.getNumResults() != op.getNumResults())
|
||||
return op.emitOpError("incorrect number of results for callee");
|
||||
if (fnType.getNumResults() != getNumResults())
|
||||
return emitOpError("incorrect number of results for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
|
||||
if (op.getResult(i).getType() != fnType.getResult(i))
|
||||
return op.emitOpError("result type mismatch");
|
||||
if (getResult(i).getType() != fnType.getResult(i))
|
||||
return emitOpError("result type mismatch");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -68,6 +68,30 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Walk all of the operations within the given set of regions, without
|
||||
/// traversing into any nested symbol tables. Stops walking if the result of the
|
||||
/// callback is anything other than `WalkResult::advance`.
|
||||
static Optional<WalkResult>
|
||||
walkSymbolTable(MutableArrayRef<Region> regions,
|
||||
function_ref<Optional<WalkResult>(Operation *)> callback) {
|
||||
SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
|
||||
while (!worklist.empty()) {
|
||||
for (Operation &op : worklist.pop_back_val()->getOps()) {
|
||||
Optional<WalkResult> result = callback(&op);
|
||||
if (result != WalkResult::advance())
|
||||
return result;
|
||||
|
||||
// If this op defines a new symbol table scope, we can't traverse. Any
|
||||
// symbol references nested within 'op' are different semantically.
|
||||
if (!op.hasTrait<OpTrait::SymbolTable>()) {
|
||||
for (Region ®ion : op.getRegions())
|
||||
worklist.push_back(®ion);
|
||||
}
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SymbolTable
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -347,7 +371,18 @@ LogicalResult detail::verifySymbolTable(Operation *op) {
|
||||
.append("see existing symbol definition here");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
|
||||
// Verify any nested symbol user operations.
|
||||
SymbolTableCollection symbolTable;
|
||||
auto verifySymbolUserFn = [&](Operation *op) -> Optional<WalkResult> {
|
||||
if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
|
||||
return WalkResult(user.verifySymbolUses(symbolTable));
|
||||
return WalkResult::advance();
|
||||
};
|
||||
|
||||
Optional<WalkResult> result =
|
||||
walkSymbolTable(op->getRegions(), verifySymbolUserFn);
|
||||
return success(result && !result->wasInterrupted());
|
||||
}
|
||||
|
||||
LogicalResult detail::verifySymbol(Operation *op) {
|
||||
@@ -452,25 +487,13 @@ static WalkResult walkSymbolRefs(
|
||||
static Optional<WalkResult> walkSymbolUses(
|
||||
MutableArrayRef<Region> regions,
|
||||
function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
|
||||
SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
|
||||
while (!worklist.empty()) {
|
||||
for (Operation &op : worklist.pop_back_val()->getOps()) {
|
||||
if (walkSymbolRefs(&op, callback).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
return walkSymbolTable(regions, [&](Operation *op) -> Optional<WalkResult> {
|
||||
// Check that this isn't a potentially unknown symbol table.
|
||||
if (isPotentiallyUnknownSymbolTable(op))
|
||||
return llvm::None;
|
||||
|
||||
// Check that this isn't a potentially unknown symbol table.
|
||||
if (isPotentiallyUnknownSymbolTable(&op))
|
||||
return llvm::None;
|
||||
|
||||
// If this op defines a new symbol table scope, we can't traverse. Any
|
||||
// symbol references nested within 'op' are different semantically.
|
||||
if (!op.hasTrait<OpTrait::SymbolTable>()) {
|
||||
for (Region ®ion : op.getRegions())
|
||||
worklist.push_back(®ion);
|
||||
}
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
return walkSymbolRefs(op, callback);
|
||||
});
|
||||
}
|
||||
/// Walk all of the uses, for any symbol, that are nested within the given
|
||||
/// operation 'from', invoking the provided callback for each. This does not
|
||||
@@ -927,6 +950,22 @@ SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
|
||||
return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
|
||||
}
|
||||
|
||||
/// Returns the operation registered with the given symbol name within the
|
||||
/// closest parent operation of, or including, 'from' with the
|
||||
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
|
||||
/// found.
|
||||
Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
|
||||
StringRef symbol) {
|
||||
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
|
||||
return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
|
||||
}
|
||||
Operation *
|
||||
SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
|
||||
SymbolRefAttr symbol) {
|
||||
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
|
||||
return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
|
||||
}
|
||||
|
||||
/// Lookup, or create, a symbol table for an operation.
|
||||
SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
|
||||
auto it = symbolTables.try_emplace(op, nullptr);
|
||||
|
||||
Reference in New Issue
Block a user