mirror of
https://github.com/intel/llvm.git
synced 2026-02-03 02:26:27 +08:00
Remove lowerAffineConstructs and lowerControlFlow in favor of providing patterns.
These methods don't compose well with the rest of conversion framework, and create artificial breaks in conversion. Replace these methods with two(populateAffineToStdConversionPatterns and populateLoopToStdConversionPatterns respectively) that populate a list of patterns to perform the same behavior. PiperOrigin-RevId: 258219277
This commit is contained in:
committed by
Mehdi Amini
parent
e7a2ef21f9
commit
2087bf6386
@@ -410,27 +410,19 @@ struct LinalgTypeConverter : public LLVMTypeConverter {
|
||||
} // end anonymous namespace
|
||||
|
||||
LogicalResult linalg::convertToLLVM(mlir::ModuleOp module) {
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
if (failed(mlir::lowerAffineConstructs(func)))
|
||||
return failure();
|
||||
if (failed(mlir::lowerControlFlow(func)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Convert Linalg ops to the LLVM IR dialect using the converter defined
|
||||
// above.
|
||||
LinalgTypeConverter converter(module.getContext());
|
||||
OwningRewritePatternList patterns;
|
||||
populateAffineToStdConversionPatterns(patterns, module.getContext());
|
||||
populateLoopToStdConversionPatterns(patterns, module.getContext());
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
|
||||
|
||||
ConversionTarget target(*module.getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
if (failed(applyConversionPatterns(module, target, converter,
|
||||
std::move(patterns))))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
return applyConversionPatterns(module, target, converter,
|
||||
std::move(patterns));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -148,18 +148,12 @@ static void populateLinalg3ToLLVMConversionPatterns(
|
||||
}
|
||||
|
||||
LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
|
||||
// Remove affine constructs.
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
if (failed(lowerAffineConstructs(func)))
|
||||
return failure();
|
||||
if (failed(mlir::lowerControlFlow(func)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Convert Linalg ops to the LLVM IR dialect using the converter defined
|
||||
// above.
|
||||
LinalgTypeConverter converter(module.getContext());
|
||||
OwningRewritePatternList patterns;
|
||||
populateAffineToStdConversionPatterns(patterns, module.getContext());
|
||||
populateLoopToStdConversionPatterns(patterns, module.getContext());
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
|
||||
populateLinalg3ToLLVMConversionPatterns(patterns, module.getContext());
|
||||
|
||||
@@ -18,16 +18,27 @@
|
||||
#ifndef MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
|
||||
#define MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
class FuncOp;
|
||||
class FunctionPassBase;
|
||||
struct LogicalResult;
|
||||
class ModulePassBase;
|
||||
class MLIRContext;
|
||||
class RewritePattern;
|
||||
|
||||
/// Lowers loop.for, loop.if and loop.terminator ops to CFG.
|
||||
LogicalResult lowerControlFlow(FuncOp func);
|
||||
// Owning list of rewriting patterns.
|
||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
|
||||
/// Collect a set of patterns to lower from loop.for, loop.if, and
|
||||
/// loop.terminator to CFG operations within the Standard dialect, in particular
|
||||
/// convert structured control flow into CFG branch-based control flow.
|
||||
void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx);
|
||||
|
||||
/// Creates a pass to convert loop.for, loop.if and loop.terminator ops to CFG.
|
||||
ModulePassBase *createConvertToCFGPass();
|
||||
FunctionPassBase *createConvertToCFGPass();
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -19,25 +19,32 @@
|
||||
#define MLIR_TRANSFORMS_LOWERAFFINE_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
class AffineExpr;
|
||||
class AffineForOp;
|
||||
class FuncOp;
|
||||
class Location;
|
||||
struct LogicalResult;
|
||||
class MLIRContext;
|
||||
class OpBuilder;
|
||||
class RewritePattern;
|
||||
class Value;
|
||||
|
||||
// Owning list of rewriting patterns.
|
||||
using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
|
||||
/// Emit code that computes the given affine expression using standard
|
||||
/// arithmetic operations applied to the provided dimension and symbol values.
|
||||
Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
|
||||
ArrayRef<Value *> dimValues,
|
||||
ArrayRef<Value *> symbolValues);
|
||||
|
||||
/// Convert from the Affine dialect to the Standard dialect, in particular
|
||||
/// convert structured affine control flow into CFG branch-based control flow.
|
||||
LogicalResult lowerAffineConstructs(FuncOp function);
|
||||
/// Collect a set of patterns to convert from the Affine dialect to the Standard
|
||||
/// dialect, in particular convert structured affine control flow into CFG
|
||||
/// branch-based control flow.
|
||||
void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx);
|
||||
|
||||
/// Emit code that computes the lower bound of the given affine loop using
|
||||
/// standard arithmetic operations.
|
||||
|
||||
@@ -43,8 +43,8 @@ using namespace mlir::loop;
|
||||
|
||||
namespace {
|
||||
|
||||
struct ControlFlowToCFGPass : public ModulePass<ControlFlowToCFGPass> {
|
||||
void runOnModule() override;
|
||||
struct ControlFlowToCFGPass : public FunctionPass<ControlFlowToCFGPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Create a CFG subgraph for the loop around its body blocks (if the body
|
||||
@@ -270,22 +270,23 @@ IfLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
LogicalResult mlir::lowerControlFlow(FuncOp func) {
|
||||
OwningRewritePatternList patterns;
|
||||
void mlir::populateLoopToStdConversionPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
RewriteListBuilder<ForLowering, IfLowering, TerminatorLowering>::build(
|
||||
patterns, func.getContext());
|
||||
ConversionTarget target(*func.getContext());
|
||||
patterns, ctx);
|
||||
}
|
||||
|
||||
void ControlFlowToCFGPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
populateLoopToStdConversionPatterns(patterns, &getContext());
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
return applyConversionPatterns(func, target, std::move(patterns));
|
||||
if (failed(
|
||||
applyConversionPatterns(getFunction(), target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
void ControlFlowToCFGPass::runOnModule() {
|
||||
for (auto func : getModule().getOps<FuncOp>())
|
||||
if (failed(mlir::lowerControlFlow(func)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
ModulePassBase *mlir::createConvertToCFGPass() {
|
||||
FunctionPassBase *mlir::createConvertToCFGPass() {
|
||||
return new ControlFlowToCFGPass();
|
||||
}
|
||||
|
||||
|
||||
@@ -1052,10 +1052,6 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
||||
return signalPassFailure();
|
||||
|
||||
ModuleOp m = getModule();
|
||||
for (auto func : m.getOps<FuncOp>())
|
||||
if (failed(mlir::lowerControlFlow(func)))
|
||||
signalPassFailure();
|
||||
|
||||
LLVM::ensureDistinctSuccessors(m);
|
||||
std::unique_ptr<LLVMTypeConverter> typeConverter =
|
||||
typeConverterMaker(&getContext());
|
||||
@@ -1063,6 +1059,7 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
||||
return signalPassFailure();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
populateLoopToStdConversionPatterns(patterns, m.getContext());
|
||||
patternListFiller(*typeConverter, patterns);
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
|
||||
@@ -806,13 +806,12 @@ void LowerLinalgToLLVMPass::runOnModule() {
|
||||
for (auto f : module.getOps<FuncOp>()) {
|
||||
lowerLinalgSubViewOps(f);
|
||||
lowerLinalgForToCFG(f);
|
||||
if (failed(lowerAffineConstructs(f)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
// Convert to the LLVM IR dialect using the converter defined above.
|
||||
OwningRewritePatternList patterns;
|
||||
LinalgTypeConverter converter(&getContext());
|
||||
populateAffineToStdConversionPatterns(patterns, &getContext());
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
|
||||
|
||||
|
||||
@@ -506,21 +506,25 @@ public:
|
||||
|
||||
} // end namespace
|
||||
|
||||
LogicalResult mlir::lowerAffineConstructs(FuncOp function) {
|
||||
OwningRewritePatternList patterns;
|
||||
void mlir::populateAffineToStdConversionPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
|
||||
AffineDmaWaitLowering, AffineLoadLowering,
|
||||
AffineStoreLowering, AffineForLowering, AffineIfLowering,
|
||||
AffineTerminatorLowering>::build(patterns,
|
||||
function.getContext());
|
||||
ConversionTarget target(*function.getContext());
|
||||
target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
|
||||
return applyConversionPatterns(function, target, std::move(patterns));
|
||||
AffineTerminatorLowering>::build(patterns, ctx);
|
||||
}
|
||||
|
||||
namespace {
|
||||
class LowerAffinePass : public FunctionPass<LowerAffinePass> {
|
||||
void runOnFunction() override { lowerAffineConstructs(getFunction()); }
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
populateAffineToStdConversionPatterns(patterns, &getContext());
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
|
||||
if (failed(applyConversionPatterns(getFunction(), target,
|
||||
std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
||||
Reference in New Issue
Block a user