mirror of
https://github.com/intel/llvm.git
synced 2026-02-01 08:56:15 +08:00
Decouple running a conversion from the DialectConversion class. The DialectConversion class is only necessary for type signature changes(block arguments or function arguments). This isn't always desired when performing a dialect conversion. This allows for those conversions without this need to run per function instead of per module.
-- PiperOrigin-RevId: 249657549
This commit is contained in:
committed by
Mehdi Amini
parent
c0f41e5bb3
commit
14d1cfbccb
@@ -436,7 +436,8 @@ void linalg::convertToLLVM(mlir::Module &module) {
|
||||
|
||||
// Convert Linalg ops to the LLVM IR dialect using the converter defined
|
||||
// above.
|
||||
auto r = Lowering(getDescriptorConverters).convert(&module);
|
||||
Lowering lowering(getDescriptorConverters);
|
||||
auto r = applyConverter(module, lowering);
|
||||
(void)r;
|
||||
assert(succeeded(r) && "conversion failed");
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ void linalg::convertLinalg3ToLLVM(Module &module) {
|
||||
assert(succeeded(rr) && "affine loop lowering failed");
|
||||
|
||||
auto lowering = makeLinalgToLLVMLowering(getConversions);
|
||||
auto r = lowering->convert(&module);
|
||||
auto r = applyConverter(module, *lowering);
|
||||
(void)r;
|
||||
assert(succeeded(r) && "conversion failed");
|
||||
}
|
||||
|
||||
@@ -134,7 +134,8 @@ protected:
|
||||
/// dialect.
|
||||
struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> {
|
||||
void runOnModule() override {
|
||||
if (failed(EarlyLowering().convert(&getModule()))) {
|
||||
EarlyLowering lowering;
|
||||
if (failed(applyConverter(getModule(), lowering))) {
|
||||
getModule().getContext()->emitError(
|
||||
mlir::UnknownLoc::get(getModule().getContext()),
|
||||
"Error lowering Toy\n");
|
||||
|
||||
@@ -343,8 +343,9 @@ protected:
|
||||
/// and is targeting LLVM otherwise.
|
||||
struct LateLoweringPass : public ModulePass<LateLoweringPass> {
|
||||
void runOnModule() override {
|
||||
// Perform Toy specific lowering
|
||||
if (failed(LateLowering().convert(&getModule()))) {
|
||||
// Perform Toy specific lowering.
|
||||
LateLowering lowering;
|
||||
if (failed(applyConverter(getModule(), lowering))) {
|
||||
getModule().getContext()->emitError(
|
||||
UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n");
|
||||
signalPassFailure();
|
||||
|
||||
@@ -39,8 +39,8 @@ class Value;
|
||||
/// Base class for the dialect conversion patterns that require type changes.
|
||||
/// Specific conversions must derive this class and implement least one
|
||||
/// `rewrite` method.
|
||||
/// NOTE: These conversion patterns can only be used with the DialectConversion
|
||||
/// class.
|
||||
/// NOTE: These conversion patterns can only be used with the 'apply*' methods
|
||||
/// below.
|
||||
class DialectConversionPattern : public RewritePattern {
|
||||
public:
|
||||
/// Construct an DialectConversionPattern. `rootName` must correspond to the
|
||||
@@ -112,22 +112,10 @@ private:
|
||||
// match against the list of conversions. On the first match, call
|
||||
// `rewrite` for the operations, and advance to the next iteration. If no
|
||||
// match is found, replicate the operation as is.
|
||||
/// 3. Update all attributes of function type to point to the new functions.
|
||||
/// 4. Replace old functions with new functions in the module.
|
||||
/// If any error happened during the conversion, the pass fails as soon as
|
||||
/// possible.
|
||||
///
|
||||
/// If conversion fails for a specific function, that functions remains
|
||||
/// unmodified. Otherwise, successfully converted functions will remain
|
||||
/// converted.
|
||||
class DialectConversion {
|
||||
public:
|
||||
virtual ~DialectConversion() = default;
|
||||
|
||||
/// Run the converter on the provided module.
|
||||
LLVM_NODISCARD
|
||||
LogicalResult convert(Module *m);
|
||||
|
||||
/// Derived classes must implement this hook to produce a set of conversion
|
||||
/// patterns to apply. They may use `mlirContext` to obtain registered
|
||||
/// dialects or operations. This will be called in the beginning of the
|
||||
@@ -170,6 +158,19 @@ public:
|
||||
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs);
|
||||
};
|
||||
|
||||
/// Convert the given module with the provided dialect conversion object.
|
||||
/// If conversion fails for a specific function, those functions remains
|
||||
/// unmodified.
|
||||
LLVM_NODISCARD
|
||||
LogicalResult applyConverter(Module &module, DialectConversion &converter);
|
||||
|
||||
/// Convert the given function with the provided conversion patterns. This will
|
||||
/// convert as many of the operations within 'fn' as possible given the set of
|
||||
/// patterns.
|
||||
LLVM_NODISCARD
|
||||
LogicalResult applyConversionPatterns(Function &fn,
|
||||
OwningRewritePatternList &&patterns);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
|
||||
|
||||
@@ -1006,9 +1006,9 @@ class LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
||||
public:
|
||||
// Run the dialect converter on the module.
|
||||
void runOnModule() override {
|
||||
Module *m = &getModule();
|
||||
LLVM::ensureDistinctSuccessors(m);
|
||||
if (failed(impl.convert(m)))
|
||||
Module &m = getModule();
|
||||
LLVM::ensureDistinctSuccessors(&m);
|
||||
if (failed(applyConverter(m, impl)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
||||
@@ -608,7 +608,8 @@ void LowerLinalgToLLVMPass::runOnModule() {
|
||||
signalPassFailure();
|
||||
|
||||
// Convert to the LLVM IR dialect using the converter defined above.
|
||||
if (failed(Lowering().convert(&module)))
|
||||
Lowering lowering;
|
||||
if (failed(applyConverter(module, lowering)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
||||
@@ -226,13 +226,13 @@ void DialectConversionPattern::rewrite(Operation *op,
|
||||
// FunctionConverter
|
||||
//===----------------------------------------------------------------------===//
|
||||
namespace {
|
||||
// This class converts a single function using a given DialectConversion
|
||||
// structure.
|
||||
// This class converts a single function using the given pattern matcher. If a
|
||||
// DialectConversion object is also provided, then the types of block arguments
|
||||
// will be converted using the appropriate 'convertType' calls.
|
||||
class FunctionConverter {
|
||||
public:
|
||||
// Constructs a FunctionConverter.
|
||||
explicit FunctionConverter(MLIRContext *ctx, DialectConversion *conversion,
|
||||
RewritePatternMatcher &matcher)
|
||||
explicit FunctionConverter(MLIRContext *ctx, RewritePatternMatcher &matcher,
|
||||
DialectConversion *conversion = nullptr)
|
||||
: dialectConversion(conversion), matcher(matcher) {}
|
||||
|
||||
/// Converts the given function to the dialect using hooks defined in
|
||||
@@ -319,11 +319,15 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
|
||||
Region ®ion, RegionParent *parent) {
|
||||
assert(!region.empty() && "expected non-empty region");
|
||||
|
||||
// Create the arguments of each of the blocks in the region.
|
||||
for (Block &block : region)
|
||||
for (auto *arg : block.getArguments())
|
||||
if (failed(convertArgument(rewriter, arg, parent->getLoc())))
|
||||
return failure();
|
||||
// Create the arguments of each of the blocks in the region. If a type
|
||||
// converter was not provided, then we don't need to change any of the block
|
||||
// types.
|
||||
if (dialectConversion) {
|
||||
for (Block &block : region)
|
||||
for (auto *arg : block.getArguments())
|
||||
if (failed(convertArgument(rewriter, arg, parent->getLoc())))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Start a DFS-order traversal of the CFG to make sure defs are converted
|
||||
// before uses in dominated blocks.
|
||||
@@ -346,8 +350,8 @@ LogicalResult FunctionConverter::convertFunction(Function *f) {
|
||||
// Rewrite the function body.
|
||||
DialectConversionRewriter rewriter(f);
|
||||
if (failed(convertRegion(rewriter, f->getBody(), f))) {
|
||||
// Reset any of the converted arguments.
|
||||
rewriter.argConverter.discardRewrites();
|
||||
// Reset any of the generated rewrites.
|
||||
rewriter.discardRewrites();
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -360,24 +364,6 @@ LogicalResult FunctionConverter::convertFunction(Function *f) {
|
||||
// DialectConversion
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class represents a function to be converted. It allows for converting
|
||||
/// the body of functions and the signature in two phases.
|
||||
struct ConvertedFunction {
|
||||
ConvertedFunction(Function *fn, FunctionType newType,
|
||||
ArrayRef<NamedAttributeList> newFunctionArgAttrs)
|
||||
: fn(fn), newType(newType),
|
||||
newFunctionArgAttrs(newFunctionArgAttrs.begin(),
|
||||
newFunctionArgAttrs.end()) {}
|
||||
|
||||
/// The function to convert.
|
||||
Function *fn;
|
||||
/// The new type and argument attributes for the function.
|
||||
FunctionType newType;
|
||||
SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
// Create a function type with arguments and results converted, and argument
|
||||
// attributes passed through.
|
||||
FunctionType DialectConversion::convertFunctionSignatureType(
|
||||
@@ -403,21 +389,38 @@ FunctionType DialectConversion::convertFunctionSignatureType(
|
||||
return FunctionType::get(arguments, results, type.getContext());
|
||||
}
|
||||
|
||||
// Converts the module as follows.
|
||||
// 1. Call `convertFunction` on each function of the module and collect the
|
||||
// mapping between old and new functions.
|
||||
// 2. Remap all function attributes in the new functions to point to the new
|
||||
// functions instead of the old ones.
|
||||
// 3. Replace old functions with the new in the module.
|
||||
LogicalResult DialectConversion::convert(Module *module) {
|
||||
if (!module)
|
||||
return failure();
|
||||
//===----------------------------------------------------------------------===//
|
||||
// applyConversionPatterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class represents a function to be converted. It allows for converting
|
||||
/// the body of functions and the signature in two phases.
|
||||
struct ConvertedFunction {
|
||||
ConvertedFunction(Function *fn, FunctionType newType,
|
||||
ArrayRef<NamedAttributeList> newFunctionArgAttrs)
|
||||
: fn(fn), newType(newType),
|
||||
newFunctionArgAttrs(newFunctionArgAttrs.begin(),
|
||||
newFunctionArgAttrs.end()) {}
|
||||
|
||||
/// The function to convert.
|
||||
Function *fn;
|
||||
/// The new type and argument attributes for the function.
|
||||
FunctionType newType;
|
||||
SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Convert the given module with the provided dialect conversion object.
|
||||
/// If conversion fails for a specific function, those functions remains
|
||||
/// unmodified.
|
||||
LogicalResult mlir::applyConverter(Module &module,
|
||||
DialectConversion &converter) {
|
||||
// Grab the conversion patterns from the converter and create the pattern
|
||||
// matcher.
|
||||
MLIRContext *context = module->getContext();
|
||||
MLIRContext *context = module.getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
initConverters(patterns, context);
|
||||
converter.initConverters(patterns, context);
|
||||
RewritePatternMatcher matcher(std::move(patterns));
|
||||
|
||||
// Try to convert each of the functions within the module. Defer updating the
|
||||
@@ -426,18 +429,18 @@ LogicalResult DialectConversion::convert(Module *module) {
|
||||
// public signatures of the functions within the module before they are
|
||||
// updated.
|
||||
std::vector<ConvertedFunction> toConvert;
|
||||
toConvert.reserve(module->getFunctions().size());
|
||||
for (auto &func : *module) {
|
||||
toConvert.reserve(module.getFunctions().size());
|
||||
for (auto &func : module) {
|
||||
// Convert the function type using the dialect converter.
|
||||
SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
|
||||
FunctionType newType = convertFunctionSignatureType(
|
||||
FunctionType newType = converter.convertFunctionSignatureType(
|
||||
func.getType(), func.getAllArgAttrs(), newFunctionArgAttrs);
|
||||
if (!newType || !newType.isa<FunctionType>())
|
||||
return func.emitError("could not convert function type");
|
||||
|
||||
// Convert the body of this function.
|
||||
FunctionConverter converter(context, this, matcher);
|
||||
if (failed(converter.convertFunction(&func)))
|
||||
FunctionConverter funcConverter(context, matcher, &converter);
|
||||
if (failed(funcConverter.convertFunction(&func)))
|
||||
return failure();
|
||||
|
||||
// Add function signature to be updated.
|
||||
@@ -453,3 +456,15 @@ LogicalResult DialectConversion::convert(Module *module) {
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Convert the given function with the provided conversion patterns. This will
|
||||
/// convert as many of the operations within 'fn' as possible given the set of
|
||||
/// patterns.
|
||||
LogicalResult
|
||||
mlir::applyConversionPatterns(Function &fn,
|
||||
OwningRewritePatternList &&patterns) {
|
||||
// Convert the body of this function.
|
||||
RewritePatternMatcher matcher(std::move(patterns));
|
||||
FunctionConverter converter(fn.getContext(), matcher);
|
||||
return converter.convertFunction(&fn);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user