[mlir] Add a new SymbolUserOpInterface class

The initial goal of this interface is to fix the current problems with verifying symbol user operations, but can extend beyond that in the future. The current problems with the verification of symbol uses are:
* Extremely inefficient:
Most current symbol users perform the symbol lookup using the slow O(N) string compare methods, which can lead to extremely long verification times in large modules.
* Invalid/break the constraints of verification pass
If the symbol reference is not-flat(and even if it is flat in some cases) a verifier for an operation is not permitted to touch the referenced operation because it may be in the process of being mutated by a different thread within the pass manager.

The new SymbolUserOpInterface exposes a method `verifySymbolUses` that will be invoked from the parent symbol table to allow for verifying the constraints of any referenced symbols. This method is passed a `SymbolTableCollection` to allow for O(1) lookups of any necessary symbol operation.

Differential Revision: https://reviews.llvm.org/D89512
This commit is contained in:
River Riddle
2020-10-16 11:57:00 -07:00
parent 7bc7d0ac7a
commit 71eeb5ec4d
7 changed files with 127 additions and 37 deletions

View File

@@ -231,4 +231,12 @@ format of the header for each interface section goes as follows:
##### SymbolInterfaces
* `SymbolOpInterface` - Used to represent [`Symbol`](SymbolsAndSymbolTables.md#symbol) operations which reside immediately within a region that defines a [`SymbolTable`](SymbolsAndSymbolTables.md#symbol-table).
* `SymbolOpInterface` - Used to represent
[`Symbol`](SymbolsAndSymbolTables.md#symbol) operations which reside
immediately within a region that defines a
[`SymbolTable`](SymbolsAndSymbolTables.md#symbol-table).
* `SymbolUserOpInterface` - Used to represent operations that reference
[`Symbol`](SymbolsAndSymbolTables.md#symbol) operations. This provides the
ability to perform safe and efficient verification of symbol uses, as well
as additional functionality.

View File

@@ -142,6 +142,10 @@ See the `LangRef` definition of the
[`SymbolRefAttr`](LangRef.md#symbol-reference-attribute) for more information
about the structure of this attribute.
Operations that reference a `Symbol` and want to perform verification and
general mutation of the symbol should implement the `SymbolUserOpInterface` to
ensure that symbol accesses are legal and efficient.
### Manipulating a Symbol
As described above, `SymbolRefs` act as an auxiliary way of defining uses of

View File

@@ -15,6 +15,7 @@
include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -733,7 +734,9 @@ def BranchOp : Std_Op<"br",
// CallOp
//===----------------------------------------------------------------------===//
def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
def CallOp : Std_Op<"call",
[CallOpInterface, MemRefsNormalizable,
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "call operation";
let description = [{
The `call` operation represents a direct call to a function that is within
@@ -788,6 +791,7 @@ def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
let assemblyFormat = [{
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
}];
let verifier = ?;
}
//===----------------------------------------------------------------------===//

View File

@@ -158,6 +158,27 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
}];
}
//===----------------------------------------------------------------------===//
// SymbolUserOpInterface
//===----------------------------------------------------------------------===//
def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
let description = [{
This interface describes an operation that may use a `Symbol`. This
interface allows for users of symbols to hook into verification and other
symbol related utilities that are either costly or otherwise disallowed
within a traditional operation.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<"Verify the symbol uses held by this operation.",
"LogicalResult", "verifySymbolUses",
(ins "::mlir::SymbolTableCollection &":$symbolTable)
>,
];
}
//===----------------------------------------------------------------------===//
// Symbol Traits
//===----------------------------------------------------------------------===//

View File

@@ -236,6 +236,21 @@ public:
LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
SmallVectorImpl<Operation *> &symbols);
/// Returns the operation registered with the given symbol name within the
/// closest parent operation of, or including, 'from' with the
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
/// found.
Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
template <typename T>
T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
}
template <typename T>
T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
}
/// Lookup, or create, a symbol table for an operation.
SymbolTable &getSymbolTable(Operation *op);

View File

@@ -740,34 +740,33 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
// CallOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(CallOp op) {
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Check that the callee attribute was specified.
auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
auto fnAttr = getAttrOfType<FlatSymbolRefAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' symbol reference attribute");
auto fn =
op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
return emitOpError("requires a 'callee' symbol reference attribute");
FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
if (!fn)
return op.emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
return emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
// Verify that the operand and result types match the callee.
auto fnType = fn.getType();
if (fnType.getNumInputs() != op.getNumOperands())
return op.emitOpError("incorrect number of operands for callee");
if (fnType.getNumInputs() != getNumOperands())
return emitOpError("incorrect number of operands for callee");
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
if (op.getOperand(i).getType() != fnType.getInput(i))
return op.emitOpError("operand type mismatch: expected operand type ")
if (getOperand(i).getType() != fnType.getInput(i))
return emitOpError("operand type mismatch: expected operand type ")
<< fnType.getInput(i) << ", but provided "
<< op.getOperand(i).getType() << " for operand number " << i;
<< getOperand(i).getType() << " for operand number " << i;
if (fnType.getNumResults() != op.getNumResults())
return op.emitOpError("incorrect number of results for callee");
if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee");
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
if (op.getResult(i).getType() != fnType.getResult(i))
return op.emitOpError("result type mismatch");
if (getResult(i).getType() != fnType.getResult(i))
return emitOpError("result type mismatch");
return success();
}

View File

@@ -68,6 +68,30 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
return success();
}
/// Walk all of the operations within the given set of regions, without
/// traversing into any nested symbol tables. Stops walking if the result of the
/// callback is anything other than `WalkResult::advance`.
static Optional<WalkResult>
walkSymbolTable(MutableArrayRef<Region> regions,
function_ref<Optional<WalkResult>(Operation *)> callback) {
SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
while (!worklist.empty()) {
for (Operation &op : worklist.pop_back_val()->getOps()) {
Optional<WalkResult> result = callback(&op);
if (result != WalkResult::advance())
return result;
// If this op defines a new symbol table scope, we can't traverse. Any
// symbol references nested within 'op' are different semantically.
if (!op.hasTrait<OpTrait::SymbolTable>()) {
for (Region &region : op.getRegions())
worklist.push_back(&region);
}
}
}
return WalkResult::advance();
}
//===----------------------------------------------------------------------===//
// SymbolTable
//===----------------------------------------------------------------------===//
@@ -347,7 +371,18 @@ LogicalResult detail::verifySymbolTable(Operation *op) {
.append("see existing symbol definition here");
}
}
return success();
// Verify any nested symbol user operations.
SymbolTableCollection symbolTable;
auto verifySymbolUserFn = [&](Operation *op) -> Optional<WalkResult> {
if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
return WalkResult(user.verifySymbolUses(symbolTable));
return WalkResult::advance();
};
Optional<WalkResult> result =
walkSymbolTable(op->getRegions(), verifySymbolUserFn);
return success(result && !result->wasInterrupted());
}
LogicalResult detail::verifySymbol(Operation *op) {
@@ -452,25 +487,13 @@ static WalkResult walkSymbolRefs(
static Optional<WalkResult> walkSymbolUses(
MutableArrayRef<Region> regions,
function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
while (!worklist.empty()) {
for (Operation &op : worklist.pop_back_val()->getOps()) {
if (walkSymbolRefs(&op, callback).wasInterrupted())
return WalkResult::interrupt();
return walkSymbolTable(regions, [&](Operation *op) -> Optional<WalkResult> {
// Check that this isn't a potentially unknown symbol table.
if (isPotentiallyUnknownSymbolTable(op))
return llvm::None;
// Check that this isn't a potentially unknown symbol table.
if (isPotentiallyUnknownSymbolTable(&op))
return llvm::None;
// If this op defines a new symbol table scope, we can't traverse. Any
// symbol references nested within 'op' are different semantically.
if (!op.hasTrait<OpTrait::SymbolTable>()) {
for (Region &region : op.getRegions())
worklist.push_back(&region);
}
}
}
return WalkResult::advance();
return walkSymbolRefs(op, callback);
});
}
/// Walk all of the uses, for any symbol, that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
@@ -927,6 +950,22 @@ SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
}
/// Returns the operation registered with the given symbol name within the
/// closest parent operation of, or including, 'from' with the
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
/// found.
Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
StringRef symbol) {
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
}
Operation *
SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
SymbolRefAttr symbol) {
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
}
/// Lookup, or create, a symbol table for an operation.
SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
auto it = symbolTables.try_emplace(op, nullptr);