[mlir][NFC] Update the Builtin dialect to use "Both" accessors

Differential Revision: https://reviews.llvm.org/D121189
This commit is contained in:
River Riddle
2022-03-07 19:13:02 -08:00
parent 171850c55a
commit f8d5c73c82
14 changed files with 48 additions and 51 deletions

View File

@@ -67,12 +67,9 @@ public:
matchAndRewrite(mlir::FuncOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.startRootUpdate(op);
auto result = fir::NameUniquer::deconstruct(op.sym_name());
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
auto newName = mangleExternalName(result);
op.sym_nameAttr(rewriter.getStringAttr(newName));
SymbolTable::setSymbolName(op, newName);
}
auto result = fir::NameUniquer::deconstruct(op.getSymName());
if (fir::NameUniquer::isExternalFacingUniquedName(result))
op.setSymNameAttr(rewriter.getStringAttr(mangleExternalName(result)));
rewriter.finalizeRootUpdate(op);
return success();
}
@@ -165,7 +162,7 @@ void ExternalNameConversionPass::runOnOperation() {
});
target.addDynamicallyLegalOp<mlir::FuncOp>([](mlir::FuncOp op) {
return !fir::NameUniquer::needExternalNameMangling(op.sym_name());
return !fir::NameUniquer::needExternalNameMangling(op.getSymName());
});
target.addDynamicallyLegalOp<fir::GlobalOp>([](fir::GlobalOp op) {

View File

@@ -34,6 +34,7 @@ def Builtin_Dialect : Dialect {
public:
}];
let emitAccessorPrefix = kEmitAccessorPrefix_Both;
}
#endif // BUILTIN_BASE

View File

@@ -76,7 +76,7 @@ def FuncOp : Builtin_Op<"func", [
}];
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttr:$type,
TypeAttrOf<FunctionType>:$type,
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);
@@ -110,12 +110,6 @@ def FuncOp : Builtin_Op<"func", [
/// compatible.
void cloneInto(FuncOp dest, BlockAndValueMapping &mapper);
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// CallableOpInterface
//===------------------------------------------------------------------===//
@@ -144,7 +138,7 @@ def FuncOp : Builtin_Op<"func", [
LogicalResult verifyType() {
auto type = getTypeAttr().getValue();
if (!type.isa<FunctionType>())
return emitOpError("requires '" + getTypeAttrName() +
return emitOpError("requires '" + FunctionOpInterface::getTypeAttrName() +
"' attribute of function type");
return success();
}
@@ -188,16 +182,16 @@ def ModuleOp : Builtin_Op<"module", [
let arguments = (ins OptionalAttr<SymbolNameAttr>:$sym_name,
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region SizedRegion<1>:$body);
let regions = (region SizedRegion<1>:$bodyRegion);
let assemblyFormat = "($sym_name^)? attr-dict-with-keyword $body";
let assemblyFormat = "($sym_name^)? attr-dict-with-keyword $bodyRegion";
let builders = [OpBuilder<(ins CArg<"Optional<StringRef>", "{}">:$name)>];
let extraClassDeclaration = [{
/// Construct a module from the given location with an optional name.
static ModuleOp create(Location loc, Optional<StringRef> name = llvm::None);
/// Return the name of this module if present.
Optional<StringRef> getName() { return sym_name(); }
Optional<StringRef> getName() { return getSymName(); }
//===------------------------------------------------------------------===//
// SymbolOpInterface Methods

View File

@@ -208,7 +208,7 @@ template <typename ConcreteOp>
LogicalResult verifyTrait(ConcreteOp op) {
if (!op.getTypeAttr())
return op.emitOpError("requires a type attribute '")
<< ConcreteOp::getTypeAttrName() << '\'';
<< function_interface_impl::getTypeAttrName() << '\'';
if (failed(op.verifyType()))
return failure();

View File

@@ -1268,6 +1268,11 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> {
let constBuilderCall = "::mlir::TypeAttr::get($0)";
}
class TypeAttrOf<Type ty>
: TypeAttrBase<ty.cppClassName, "type attribute of " # ty.description> {
let constBuilderCall = "::mlir::TypeAttr::get($0)";
}
// The mere presence of unit attributes has a meaning. Therefore, unit
// attributes are always treated as optional and accessors to them return
// "true" if the attribute is present and "false" otherwise.

View File

@@ -492,18 +492,18 @@ struct UnrealizedConversionCastOpLowering
matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> convertedTypes;
if (succeeded(typeConverter->convertTypes(op.outputs().getTypes(),
if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
convertedTypes)) &&
convertedTypes == adaptor.inputs().getTypes()) {
rewriter.replaceOp(op, adaptor.inputs());
convertedTypes == adaptor.getInputs().getTypes()) {
rewriter.replaceOp(op, adaptor.getInputs());
return success();
}
convertedTypes.clear();
if (succeeded(typeConverter->convertTypes(adaptor.inputs().getTypes(),
if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
convertedTypes)) &&
convertedTypes == op.outputs().getType()) {
rewriter.replaceOp(op, adaptor.inputs());
convertedTypes == op.getOutputs().getType()) {
rewriter.replaceOp(op, adaptor.getInputs());
return success();
}
return failure();

View File

@@ -37,15 +37,15 @@ struct UnrealizedConversionCastPassthrough
auto users = op->getUsers();
if (!llvm::all_of(users, [&](Operation *user) {
if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
return other.getResultTypes() == op.inputs().getTypes() &&
other.inputs() == op.outputs();
return other.getResultTypes() == op.getInputs().getTypes() &&
other.getInputs() == op.getOutputs();
return false;
})) {
return rewriter.notifyMatchFailure(op, "live unrealized conversion cast");
}
for (Operation *user : users)
rewriter.replaceOp(user, op.inputs());
rewriter.replaceOp(user, op.getInputs());
rewriter.eraseOp(op);
return success();

View File

@@ -463,8 +463,7 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
ArrayRef<Type> computeFuncInputTypes =
computeFunc.func.type().cast<FunctionType>().getInputs();
ArrayRef<Type> computeFuncInputTypes = computeFunc.func.getType().getInputs();
// Compared to the parallel compute function async dispatch function takes
// additional !async.group argument. Also instead of a single `blockIndex` it
@@ -541,7 +540,7 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
operands[1] = midIndex;
operands[2] = end;
executeBuilder.create<func::CallOp>(executeLoc, func.sym_name(),
executeBuilder.create<func::CallOp>(executeLoc, func.getSymName(),
func.getCallableResults(), operands);
executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
};
@@ -562,7 +561,7 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
SmallVector<Value> computeFuncOperands = {blockStart};
computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
b.create<func::CallOp>(computeFunc.func.sym_name(),
b.create<func::CallOp>(computeFunc.func.getSymName(),
computeFunc.func.getCallableResults(),
computeFuncOperands);
b.create<func::ReturnOp>(ValueRange());
@@ -609,7 +608,7 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
SmallVector<Value> operands = {c0, blockSize};
appendBlockComputeOperands(operands);
b.create<func::CallOp>(parallelComputeFunction.func.sym_name(),
b.create<func::CallOp>(parallelComputeFunction.func.getSymName(),
parallelComputeFunction.func.getCallableResults(),
operands);
b.create<scf::YieldOp>();
@@ -628,7 +627,7 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
SmallVector<Value> operands = {group, c0, blockCount, blockSize};
appendBlockComputeOperands(operands);
b.create<func::CallOp>(asyncDispatchFunction.sym_name(),
b.create<func::CallOp>(asyncDispatchFunction.getSymName(),
asyncDispatchFunction.getCallableResults(),
operands);
@@ -687,7 +686,7 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
// Call parallel compute function inside the async.execute region.
auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
Location executeLoc, ValueRange executeArgs) {
executeBuilder.create<func::CallOp>(executeLoc, compute.sym_name(),
executeBuilder.create<func::CallOp>(executeLoc, compute.getSymName(),
compute.getCallableResults(),
computeFuncOperands(iv));
executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
@@ -704,7 +703,7 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);
// Call parallel compute function for the first block in the caller thread.
b.create<func::CallOp>(compute.sym_name(), compute.getCallableResults(),
b.create<func::CallOp>(compute.getSymName(), compute.getCallableResults(),
computeFuncOperands(c0));
// Wait for the completion of all async compute operations.

View File

@@ -180,7 +180,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
// `async.await` op lowering will create resume blocks for async
// continuations, and will conditionally branch to cleanup or suspend blocks.
for (Block &block : func.body().getBlocks()) {
for (Block &block : func.getBody().getBlocks()) {
if (&block == entryBlock || &block == cleanupBlock ||
&block == suspendBlock)
continue;
@@ -677,7 +677,7 @@ funcsToCoroutines(ModuleOp module,
// this dict between the passes is ugly.
if (isAllowedToBlock(func) ||
outlinedFunctions.find(func) == outlinedFunctions.end()) {
for (Operation &op : func.body().getOps()) {
for (Operation &op : func.getBody().getOps()) {
if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
funcWorklist.push_back(func);
break;

View File

@@ -149,7 +149,7 @@ getFuncOpAnalysisState(const BufferizationState &state, FuncOp funcOp) {
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
func::ReturnOp returnOp;
for (Block &b : funcOp.body()) {
for (Block &b : funcOp.getBody()) {
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
if (returnOp)
return nullptr;
@@ -460,7 +460,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
// 3. Rewrite the bbArgs.
// Iterate on the original `numArgs` and replace them in order.
// This guarantees the argument order still matches after the rewrite.
Block &frontBlock = funcOp.body().front();
Block &frontBlock = funcOp.getBody().front();
unsigned numArgs = frontBlock.getNumArguments();
for (unsigned idx = 0; idx < numArgs; ++idx) {
auto bbArg = frontBlock.getArgument(0);
@@ -527,7 +527,7 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
// For each FuncOp, the number of CallOpInterface it contains.
DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult {
if (!funcOp.body().empty()) {
if (!funcOp.getBody().empty()) {
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
if (!returnOp)
return funcOp->emitError()
@@ -624,7 +624,7 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
argumentTypes.push_back(desiredMemrefType);
// If funcOp's body is not empty, change the bbArg type and propagate.
if (!funcOp.body().empty()) {
if (!funcOp.getBody().empty()) {
BlockArgument bbArg = funcOp.getArgument(argNumber);
bbArg.setType(desiredMemrefType);
OpBuilder b(bbArg.getContext());
@@ -886,7 +886,7 @@ struct CallOpInterface
// 4. Create the new CallOp.
Operation *newCallOp = rewriter.create<func::CallOp>(
callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs());
// Get replacement values for non-tensor / non-equivalent results.
for (unsigned i = 0; i < replacementValues.size(); ++i) {
@@ -1009,7 +1009,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
// Analyze ops.
for (FuncOp funcOp : moduleState.orderedFuncOps) {
// No body => no analysis.
if (funcOp.body().empty())
if (funcOp.getBody().empty())
continue;
// Now analyzing function.
@@ -1037,7 +1037,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
// Bufferize function bodies.
for (FuncOp funcOp : moduleState.orderedFuncOps) {
// No body => no analysis.
if (funcOp.body().empty())
if (funcOp.getBody().empty())
continue;
if (failed(bufferizeOp(funcOp, state)))

View File

@@ -283,7 +283,7 @@ public:
IRRewriter rewriter(func.getContext());
propagateShapesInRegion(func.body());
propagateShapesInRegion(func.getBody());
// Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
// the FuncOp type.

View File

@@ -101,7 +101,8 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
ArrayRef<DictionaryAttr> argAttrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
state.addAttribute(function_interface_impl::getTypeAttrName(),
TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();
@@ -287,8 +288,8 @@ LogicalResult ModuleOp::verify() {
LogicalResult
UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
SmallVectorImpl<OpFoldResult> &foldResults) {
OperandRange operands = inputs();
ResultRange results = outputs();
OperandRange operands = getInputs();
ResultRange results = getOutputs();
if (operands.getType() == results.getType()) {
foldResults.append(operands.begin(), operands.end());

View File

@@ -30,7 +30,7 @@ struct TestPrintInvalidPass
void runOnOperation() override {
Location loc = getOperation().getLoc();
OpBuilder builder(getOperation().body());
OpBuilder builder(getOperation().getBodyRegion());
auto funcOp = builder.create<FuncOp>(
loc, "test", FunctionType::get(getOperation().getContext(), {}, {}));
funcOp.addEntryBlock();

View File

@@ -41,7 +41,7 @@ protected:
// Create ValueShapeRange on the arith.addi operation.
ValueShapeRange addiRange() {
auto &fnBody = mapFn.body();
auto &fnBody = mapFn.getBody();
return std::next(fnBody.front().begin())->getOperands();
}