Remove lowerAffineConstructs and lowerControlFlow in favor of providing patterns.

These methods don't compose well with the rest of conversion framework, and create artificial breaks in conversion. Replace these methods with two(populateAffineToStdConversionPatterns and populateLoopToStdConversionPatterns respectively) that populate a list of patterns to perform the same behavior.

PiperOrigin-RevId: 258219277
This commit is contained in:
River Riddle
2019-07-15 12:52:44 -07:00
committed by Mehdi Amini
parent e7a2ef21f9
commit 2087bf6386
8 changed files with 61 additions and 56 deletions

View File

@@ -410,27 +410,19 @@ struct LinalgTypeConverter : public LLVMTypeConverter {
} // end anonymous namespace
LogicalResult linalg::convertToLLVM(mlir::ModuleOp module) {
for (auto func : module.getOps<FuncOp>()) {
if (failed(mlir::lowerAffineConstructs(func)))
return failure();
if (failed(mlir::lowerControlFlow(func)))
return failure();
}
// Convert Linalg ops to the LLVM IR dialect using the converter defined
// above.
LinalgTypeConverter converter(module.getContext());
OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, module.getContext());
populateLoopToStdConversionPatterns(patterns, module.getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
ConversionTarget target(*module.getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
if (failed(applyConversionPatterns(module, target, converter,
std::move(patterns))))
return failure();
return success();
return applyConversionPatterns(module, target, converter,
std::move(patterns));
}
namespace {

View File

@@ -148,18 +148,12 @@ static void populateLinalg3ToLLVMConversionPatterns(
}
LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
// Remove affine constructs.
for (auto func : module.getOps<FuncOp>()) {
if (failed(lowerAffineConstructs(func)))
return failure();
if (failed(mlir::lowerControlFlow(func)))
return failure();
}
// Convert Linalg ops to the LLVM IR dialect using the converter defined
// above.
LinalgTypeConverter converter(module.getContext());
OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, module.getContext());
populateLoopToStdConversionPatterns(patterns, module.getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
populateLinalg3ToLLVMConversionPatterns(patterns, module.getContext());

View File

@@ -18,16 +18,27 @@
#ifndef MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
#define MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
#include <memory>
#include <vector>
namespace mlir {
class FuncOp;
class FunctionPassBase;
struct LogicalResult;
class ModulePassBase;
class MLIRContext;
class RewritePattern;
/// Lowers loop.for, loop.if and loop.terminator ops to CFG.
LogicalResult lowerControlFlow(FuncOp func);
// Owning list of rewriting patterns.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
/// Collect a set of patterns to lower from loop.for, loop.if, and
/// loop.terminator to CFG operations within the Standard dialect, in particular
/// convert structured control flow into CFG branch-based control flow.
void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns,
MLIRContext *ctx);
/// Creates a pass to convert loop.for, loop.if and loop.terminator ops to CFG.
ModulePassBase *createConvertToCFGPass();
FunctionPassBase *createConvertToCFGPass();
} // namespace mlir

View File

@@ -19,25 +19,32 @@
#define MLIR_TRANSFORMS_LOWERAFFINE_H
#include "mlir/Support/LLVM.h"
#include <vector>
namespace mlir {
class AffineExpr;
class AffineForOp;
class FuncOp;
class Location;
struct LogicalResult;
class MLIRContext;
class OpBuilder;
class RewritePattern;
class Value;
// Owning list of rewriting patterns.
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
/// Emit code that computes the given affine expression using standard
/// arithmetic operations applied to the provided dimension and symbol values.
Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
ArrayRef<Value *> dimValues,
ArrayRef<Value *> symbolValues);
/// Convert from the Affine dialect to the Standard dialect, in particular
/// convert structured affine control flow into CFG branch-based control flow.
LogicalResult lowerAffineConstructs(FuncOp function);
/// Collect a set of patterns to convert from the Affine dialect to the Standard
/// dialect, in particular convert structured affine control flow into CFG
/// branch-based control flow.
void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
MLIRContext *ctx);
/// Emit code that computes the lower bound of the given affine loop using
/// standard arithmetic operations.

View File

@@ -43,8 +43,8 @@ using namespace mlir::loop;
namespace {
struct ControlFlowToCFGPass : public ModulePass<ControlFlowToCFGPass> {
void runOnModule() override;
struct ControlFlowToCFGPass : public FunctionPass<ControlFlowToCFGPass> {
void runOnFunction() override;
};
// Create a CFG subgraph for the loop around its body blocks (if the body
@@ -270,22 +270,23 @@ IfLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
return matchSuccess();
}
LogicalResult mlir::lowerControlFlow(FuncOp func) {
OwningRewritePatternList patterns;
void mlir::populateLoopToStdConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
RewriteListBuilder<ForLowering, IfLowering, TerminatorLowering>::build(
patterns, func.getContext());
ConversionTarget target(*func.getContext());
patterns, ctx);
}
void ControlFlowToCFGPass::runOnFunction() {
OwningRewritePatternList patterns;
populateLoopToStdConversionPatterns(patterns, &getContext());
ConversionTarget target(getContext());
target.addLegalDialect<StandardOpsDialect>();
return applyConversionPatterns(func, target, std::move(patterns));
if (failed(
applyConversionPatterns(getFunction(), target, std::move(patterns))))
signalPassFailure();
}
void ControlFlowToCFGPass::runOnModule() {
for (auto func : getModule().getOps<FuncOp>())
if (failed(mlir::lowerControlFlow(func)))
return signalPassFailure();
}
ModulePassBase *mlir::createConvertToCFGPass() {
FunctionPassBase *mlir::createConvertToCFGPass() {
return new ControlFlowToCFGPass();
}

View File

@@ -1052,10 +1052,6 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
return signalPassFailure();
ModuleOp m = getModule();
for (auto func : m.getOps<FuncOp>())
if (failed(mlir::lowerControlFlow(func)))
signalPassFailure();
LLVM::ensureDistinctSuccessors(m);
std::unique_ptr<LLVMTypeConverter> typeConverter =
typeConverterMaker(&getContext());
@@ -1063,6 +1059,7 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
return signalPassFailure();
OwningRewritePatternList patterns;
populateLoopToStdConversionPatterns(patterns, m.getContext());
patternListFiller(*typeConverter, patterns);
ConversionTarget target(getContext());

View File

@@ -806,13 +806,12 @@ void LowerLinalgToLLVMPass::runOnModule() {
for (auto f : module.getOps<FuncOp>()) {
lowerLinalgSubViewOps(f);
lowerLinalgForToCFG(f);
if (failed(lowerAffineConstructs(f)))
signalPassFailure();
}
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
LinalgTypeConverter converter(&getContext());
populateAffineToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());

View File

@@ -506,21 +506,25 @@ public:
} // end namespace
LogicalResult mlir::lowerAffineConstructs(FuncOp function) {
OwningRewritePatternList patterns;
void mlir::populateAffineToStdConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
AffineDmaWaitLowering, AffineLoadLowering,
AffineStoreLowering, AffineForLowering, AffineIfLowering,
AffineTerminatorLowering>::build(patterns,
function.getContext());
ConversionTarget target(*function.getContext());
target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
return applyConversionPatterns(function, target, std::move(patterns));
AffineTerminatorLowering>::build(patterns, ctx);
}
namespace {
class LowerAffinePass : public FunctionPass<LowerAffinePass> {
void runOnFunction() override { lowerAffineConstructs(getFunction()); }
void runOnFunction() override {
OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, &getContext());
ConversionTarget target(getContext());
target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
if (failed(applyConversionPatterns(getFunction(), target,
std::move(patterns))))
signalPassFailure();
}
};
} // namespace