Decouple running a conversion from the DialectConversion class. The DialectConversion class is only necessary for type signature changes(block arguments or function arguments). This isn't always desired when performing a dialect conversion. This allows for those conversions without this need to run per function instead of per module.

--

PiperOrigin-RevId: 249657549
This commit is contained in:
River Riddle
2019-05-23 09:23:33 -07:00
committed by Mehdi Amini
parent c0f41e5bb3
commit 14d1cfbccb
8 changed files with 89 additions and 69 deletions

View File

@@ -436,7 +436,8 @@ void linalg::convertToLLVM(mlir::Module &module) {
// Convert Linalg ops to the LLVM IR dialect using the converter defined
// above.
auto r = Lowering(getDescriptorConverters).convert(&module);
Lowering lowering(getDescriptorConverters);
auto r = applyConverter(module, lowering);
(void)r;
assert(succeeded(r) && "conversion failed");
}

View File

@@ -143,7 +143,7 @@ void linalg::convertLinalg3ToLLVM(Module &module) {
assert(succeeded(rr) && "affine loop lowering failed");
auto lowering = makeLinalgToLLVMLowering(getConversions);
auto r = lowering->convert(&module);
auto r = applyConverter(module, *lowering);
(void)r;
assert(succeeded(r) && "conversion failed");
}

View File

@@ -134,7 +134,8 @@ protected:
/// dialect.
struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> {
void runOnModule() override {
if (failed(EarlyLowering().convert(&getModule()))) {
EarlyLowering lowering;
if (failed(applyConverter(getModule(), lowering))) {
getModule().getContext()->emitError(
mlir::UnknownLoc::get(getModule().getContext()),
"Error lowering Toy\n");

View File

@@ -343,8 +343,9 @@ protected:
/// and is targeting LLVM otherwise.
struct LateLoweringPass : public ModulePass<LateLoweringPass> {
void runOnModule() override {
// Perform Toy specific lowering
if (failed(LateLowering().convert(&getModule()))) {
// Perform Toy specific lowering.
LateLowering lowering;
if (failed(applyConverter(getModule(), lowering))) {
getModule().getContext()->emitError(
UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n");
signalPassFailure();

View File

@@ -39,8 +39,8 @@ class Value;
/// Base class for the dialect conversion patterns that require type changes.
/// Specific conversions must derive this class and implement least one
/// `rewrite` method.
/// NOTE: These conversion patterns can only be used with the DialectConversion
/// class.
/// NOTE: These conversion patterns can only be used with the 'apply*' methods
/// below.
class DialectConversionPattern : public RewritePattern {
public:
/// Construct an DialectConversionPattern. `rootName` must correspond to the
@@ -112,22 +112,10 @@ private:
// match against the list of conversions. On the first match, call
// `rewrite` for the operations, and advance to the next iteration. If no
// match is found, replicate the operation as is.
/// 3. Update all attributes of function type to point to the new functions.
/// 4. Replace old functions with new functions in the module.
/// If any error happened during the conversion, the pass fails as soon as
/// possible.
///
/// If conversion fails for a specific function, that functions remains
/// unmodified. Otherwise, successfully converted functions will remain
/// converted.
class DialectConversion {
public:
virtual ~DialectConversion() = default;
/// Run the converter on the provided module.
LLVM_NODISCARD
LogicalResult convert(Module *m);
/// Derived classes must implement this hook to produce a set of conversion
/// patterns to apply. They may use `mlirContext` to obtain registered
/// dialects or operations. This will be called in the beginning of the
@@ -170,6 +158,19 @@ public:
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs);
};
/// Convert the given module with the provided dialect conversion object.
/// If conversion fails for a specific function, those functions remains
/// unmodified.
LLVM_NODISCARD
LogicalResult applyConverter(Module &module, DialectConversion &converter);
/// Convert the given function with the provided conversion patterns. This will
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
LLVM_NODISCARD
LogicalResult applyConversionPatterns(Function &fn,
OwningRewritePatternList &&patterns);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_

View File

@@ -1006,9 +1006,9 @@ class LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
public:
// Run the dialect converter on the module.
void runOnModule() override {
Module *m = &getModule();
LLVM::ensureDistinctSuccessors(m);
if (failed(impl.convert(m)))
Module &m = getModule();
LLVM::ensureDistinctSuccessors(&m);
if (failed(applyConverter(m, impl)))
signalPassFailure();
}

View File

@@ -608,7 +608,8 @@ void LowerLinalgToLLVMPass::runOnModule() {
signalPassFailure();
// Convert to the LLVM IR dialect using the converter defined above.
if (failed(Lowering().convert(&module)))
Lowering lowering;
if (failed(applyConverter(module, lowering)))
signalPassFailure();
}

View File

@@ -226,13 +226,13 @@ void DialectConversionPattern::rewrite(Operation *op,
// FunctionConverter
//===----------------------------------------------------------------------===//
namespace {
// This class converts a single function using a given DialectConversion
// structure.
// This class converts a single function using the given pattern matcher. If a
// DialectConversion object is also provided, then the types of block arguments
// will be converted using the appropriate 'convertType' calls.
class FunctionConverter {
public:
// Constructs a FunctionConverter.
explicit FunctionConverter(MLIRContext *ctx, DialectConversion *conversion,
RewritePatternMatcher &matcher)
explicit FunctionConverter(MLIRContext *ctx, RewritePatternMatcher &matcher,
DialectConversion *conversion = nullptr)
: dialectConversion(conversion), matcher(matcher) {}
/// Converts the given function to the dialect using hooks defined in
@@ -319,11 +319,15 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
Region &region, RegionParent *parent) {
assert(!region.empty() && "expected non-empty region");
// Create the arguments of each of the blocks in the region.
for (Block &block : region)
for (auto *arg : block.getArguments())
if (failed(convertArgument(rewriter, arg, parent->getLoc())))
return failure();
// Create the arguments of each of the blocks in the region. If a type
// converter was not provided, then we don't need to change any of the block
// types.
if (dialectConversion) {
for (Block &block : region)
for (auto *arg : block.getArguments())
if (failed(convertArgument(rewriter, arg, parent->getLoc())))
return failure();
}
// Start a DFS-order traversal of the CFG to make sure defs are converted
// before uses in dominated blocks.
@@ -346,8 +350,8 @@ LogicalResult FunctionConverter::convertFunction(Function *f) {
// Rewrite the function body.
DialectConversionRewriter rewriter(f);
if (failed(convertRegion(rewriter, f->getBody(), f))) {
// Reset any of the converted arguments.
rewriter.argConverter.discardRewrites();
// Reset any of the generated rewrites.
rewriter.discardRewrites();
return failure();
}
@@ -360,24 +364,6 @@ LogicalResult FunctionConverter::convertFunction(Function *f) {
// DialectConversion
//===----------------------------------------------------------------------===//
namespace {
/// This class represents a function to be converted. It allows for converting
/// the body of functions and the signature in two phases.
struct ConvertedFunction {
ConvertedFunction(Function *fn, FunctionType newType,
ArrayRef<NamedAttributeList> newFunctionArgAttrs)
: fn(fn), newType(newType),
newFunctionArgAttrs(newFunctionArgAttrs.begin(),
newFunctionArgAttrs.end()) {}
/// The function to convert.
Function *fn;
/// The new type and argument attributes for the function.
FunctionType newType;
SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
};
} // end anonymous namespace
// Create a function type with arguments and results converted, and argument
// attributes passed through.
FunctionType DialectConversion::convertFunctionSignatureType(
@@ -403,21 +389,38 @@ FunctionType DialectConversion::convertFunctionSignatureType(
return FunctionType::get(arguments, results, type.getContext());
}
// Converts the module as follows.
// 1. Call `convertFunction` on each function of the module and collect the
// mapping between old and new functions.
// 2. Remap all function attributes in the new functions to point to the new
// functions instead of the old ones.
// 3. Replace old functions with the new in the module.
LogicalResult DialectConversion::convert(Module *module) {
if (!module)
return failure();
//===----------------------------------------------------------------------===//
// applyConversionPatterns
//===----------------------------------------------------------------------===//
namespace {
/// This class represents a function to be converted. It allows for converting
/// the body of functions and the signature in two phases.
struct ConvertedFunction {
ConvertedFunction(Function *fn, FunctionType newType,
ArrayRef<NamedAttributeList> newFunctionArgAttrs)
: fn(fn), newType(newType),
newFunctionArgAttrs(newFunctionArgAttrs.begin(),
newFunctionArgAttrs.end()) {}
/// The function to convert.
Function *fn;
/// The new type and argument attributes for the function.
FunctionType newType;
SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
};
} // end anonymous namespace
/// Convert the given module with the provided dialect conversion object.
/// If conversion fails for a specific function, those functions remains
/// unmodified.
LogicalResult mlir::applyConverter(Module &module,
DialectConversion &converter) {
// Grab the conversion patterns from the converter and create the pattern
// matcher.
MLIRContext *context = module->getContext();
MLIRContext *context = module.getContext();
OwningRewritePatternList patterns;
initConverters(patterns, context);
converter.initConverters(patterns, context);
RewritePatternMatcher matcher(std::move(patterns));
// Try to convert each of the functions within the module. Defer updating the
@@ -426,18 +429,18 @@ LogicalResult DialectConversion::convert(Module *module) {
// public signatures of the functions within the module before they are
// updated.
std::vector<ConvertedFunction> toConvert;
toConvert.reserve(module->getFunctions().size());
for (auto &func : *module) {
toConvert.reserve(module.getFunctions().size());
for (auto &func : module) {
// Convert the function type using the dialect converter.
SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
FunctionType newType = convertFunctionSignatureType(
FunctionType newType = converter.convertFunctionSignatureType(
func.getType(), func.getAllArgAttrs(), newFunctionArgAttrs);
if (!newType || !newType.isa<FunctionType>())
return func.emitError("could not convert function type");
// Convert the body of this function.
FunctionConverter converter(context, this, matcher);
if (failed(converter.convertFunction(&func)))
FunctionConverter funcConverter(context, matcher, &converter);
if (failed(funcConverter.convertFunction(&func)))
return failure();
// Add function signature to be updated.
@@ -453,3 +456,15 @@ LogicalResult DialectConversion::convert(Module *module) {
return success();
}
/// Convert the given function with the provided conversion patterns. This will
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
LogicalResult
mlir::applyConversionPatterns(Function &fn,
OwningRewritePatternList &&patterns) {
// Convert the body of this function.
RewritePatternMatcher matcher(std::move(patterns));
FunctionConverter converter(fn.getContext(), matcher);
return converter.convertFunction(&fn);
}