mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 08:30:34 +08:00
[mlir][scf] Implement conversion from scf.forall to scf.parallel (#94109)
There is currently no path to lower scf.forall to scf.parallel with the goal of targeting the OpenMP dialect. In the SCF->ControlFlow conversion, scf.forall is briefly converted to scf.parallel, but the scf.parallel is lowered directly to a sequential loop. This makes experimenting with scf.forall for CPU execution difficult. This change factors out the rewrite in the SCF->ControlFlow pass into a utility function that can then be used in the SCF->ControlFlow lowering and via a separate -scf-forall-to-parallel pass. --------- Co-authored-by: Spenser Bauman <sabauma@fastmail>
This commit is contained in:
@@ -68,6 +68,32 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
|
||||
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
|
||||
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>]> {
|
||||
let summary = "Converts scf.forall into a nest of scf.for operations";
|
||||
let description = [{
|
||||
Converts the `scf.forall` operation pointed to by the given handle into an
|
||||
`scf.parallel` operation.
|
||||
|
||||
The operand handle must be associated with exactly one payload operation.
|
||||
|
||||
Loops with outputs are not supported.
|
||||
|
||||
#### Return Modes
|
||||
|
||||
Consumes the operand handle. Produces a silenceable failure if the operand
|
||||
is not associated with a single `scf.forall` payload operation.
|
||||
Returns a handle to the new `scf.parallel` operation.
|
||||
Produces a silenceable failure if another number of resulting handles is
|
||||
requested.
|
||||
}];
|
||||
let arguments = (ins TransformHandleTypeInterface:$target);
|
||||
let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
|
||||
|
||||
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
|
||||
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>]> {
|
||||
|
||||
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
|
||||
/// Creates a pass that converts SCF forall loops to SCF for loops.
|
||||
std::unique_ptr<Pass> createForallToForLoopPass();
|
||||
|
||||
/// Creates a pass that converts SCF forall loops to SCF parallel loops.
|
||||
std::unique_ptr<Pass> createForallToParallelLoopPass();
|
||||
|
||||
// Creates a pass which lowers for loops into while loops.
|
||||
std::unique_ptr<Pass> createForToWhileLoopPass();
|
||||
|
||||
|
||||
@@ -125,6 +125,11 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
|
||||
let constructor = "mlir::createForallToForLoopPass()";
|
||||
}
|
||||
|
||||
def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
|
||||
let summary = "Convert SCF forall loops to SCF parallel loops";
|
||||
let constructor = "mlir::createForallToParallelLoopPass()";
|
||||
}
|
||||
|
||||
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
|
||||
let summary = "Convert SCF for loops to SCF while loops";
|
||||
let constructor = "mlir::createForToWhileLoopPass()";
|
||||
|
||||
@@ -39,6 +39,11 @@ class WhileOp;
|
||||
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
|
||||
SmallVectorImpl<Operation *> *results = nullptr);
|
||||
|
||||
/// Try converting scf.forall into an scf.parallel loop.
|
||||
/// The conversion is only supported for forall operations with no results.
|
||||
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
|
||||
ParallelOp *result = nullptr);
|
||||
|
||||
/// Fuses all adjacent scf.parallel operations with identical bounds and step
|
||||
/// into one scf.parallel operations. Uses a naive aliasing and dependency
|
||||
/// analysis.
|
||||
|
||||
@@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRSCFToControlFlow
|
||||
MLIRArithDialect
|
||||
MLIRControlFlowDialect
|
||||
MLIRSCFDialect
|
||||
MLIRSCFTransforms
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
@@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
|
||||
|
||||
LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
Location loc = forallOp.getLoc();
|
||||
if (!forallOp.getOutputs().empty())
|
||||
return rewriter.notifyMatchFailure(
|
||||
forallOp,
|
||||
"only fully bufferized scf.forall ops can be lowered to scf.parallel");
|
||||
|
||||
// Convert mixed bounds and steps to SSA values.
|
||||
SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
|
||||
rewriter, loc, forallOp.getMixedLowerBound());
|
||||
SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
|
||||
rewriter, loc, forallOp.getMixedUpperBound());
|
||||
SmallVector<Value> steps =
|
||||
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
|
||||
|
||||
// Create empty scf.parallel op.
|
||||
auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
|
||||
rewriter.eraseBlock(¶llelOp.getRegion().front());
|
||||
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
|
||||
parallelOp.getRegion().begin());
|
||||
// Replace the terminator.
|
||||
rewriter.setInsertionPointToEnd(¶llelOp.getRegion().front());
|
||||
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
|
||||
parallelOp.getRegion().front().getTerminator());
|
||||
|
||||
// Erase the scf.forall op.
|
||||
rewriter.replaceOp(forallOp, parallelOp);
|
||||
return success();
|
||||
return scf::forallToParallelLoop(rewriter, forallOp);
|
||||
}
|
||||
|
||||
void mlir::populateSCFToControlFlowConversionPatterns(
|
||||
|
||||
@@ -98,6 +98,50 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ForallToForOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
auto payload = state.getPayloadOps(getTarget());
|
||||
if (!llvm::hasSingleElement(payload))
|
||||
return emitSilenceableError() << "expected a single payload op";
|
||||
|
||||
auto target = dyn_cast<scf::ForallOp>(*payload.begin());
|
||||
if (!target) {
|
||||
DiagnosedSilenceableFailure diag =
|
||||
emitSilenceableError() << "expected the payload to be scf.forall";
|
||||
diag.attachNote((*payload.begin())->getLoc()) << "payload op";
|
||||
return diag;
|
||||
}
|
||||
|
||||
if (!target.getOutputs().empty()) {
|
||||
return emitSilenceableError()
|
||||
<< "unsupported shared outputs (didn't bufferize?)";
|
||||
}
|
||||
|
||||
if (getNumResults() != 1) {
|
||||
DiagnosedSilenceableFailure diag = emitSilenceableError()
|
||||
<< "op expects one result, given "
|
||||
<< getNumResults();
|
||||
diag.attachNote(target.getLoc()) << "payload op";
|
||||
return diag;
|
||||
}
|
||||
|
||||
scf::ParallelOp opResult;
|
||||
if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) {
|
||||
DiagnosedSilenceableFailure diag =
|
||||
emitSilenceableError() << "failed to convert forall into parallel";
|
||||
return diag;
|
||||
}
|
||||
|
||||
results.set(cast<OpResult>(getTransformed()[0]), {opResult});
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LoopOutlineOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
Bufferize.cpp
|
||||
ForallToFor.cpp
|
||||
ForallToParallel.cpp
|
||||
ForToWhile.cpp
|
||||
LoopCanonicalization.cpp
|
||||
LoopPipelining.cpp
|
||||
|
||||
86
mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
Normal file
86
mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
Normal file
@@ -0,0 +1,86 @@
|
||||
//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Transforms SCF.ForallOp's into SCF.ParallelOps's.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
|
||||
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
|
||||
scf::ForallOp forallOp,
|
||||
scf::ParallelOp *result) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(forallOp);
|
||||
|
||||
Location loc = forallOp.getLoc();
|
||||
if (!forallOp.getOutputs().empty())
|
||||
return rewriter.notifyMatchFailure(
|
||||
forallOp,
|
||||
"only fully bufferized scf.forall ops can be lowered to scf.parallel");
|
||||
|
||||
// Convert mixed bounds and steps to SSA values.
|
||||
SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
|
||||
rewriter, loc, forallOp.getMixedLowerBound());
|
||||
SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
|
||||
rewriter, loc, forallOp.getMixedUpperBound());
|
||||
SmallVector<Value> steps =
|
||||
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
|
||||
|
||||
// Create empty scf.parallel op.
|
||||
auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
|
||||
rewriter.eraseBlock(¶llelOp.getRegion().front());
|
||||
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
|
||||
parallelOp.getRegion().begin());
|
||||
// Replace the terminator.
|
||||
rewriter.setInsertionPointToEnd(¶llelOp.getRegion().front());
|
||||
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
|
||||
parallelOp.getRegion().front().getTerminator());
|
||||
|
||||
// If the mapping attribute is present, propagate to the new parallelOp.
|
||||
if (forallOp.getMapping())
|
||||
parallelOp->setAttr("mapping", *forallOp.getMapping());
|
||||
|
||||
// Erase the scf.forall op.
|
||||
rewriter.replaceOp(forallOp, parallelOp);
|
||||
|
||||
if (result)
|
||||
*result = parallelOp;
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct ForallToParallelLoop final
|
||||
: public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
|
||||
void runOnOperation() override {
|
||||
Operation *parentOp = getOperation();
|
||||
IRRewriter rewriter(parentOp->getContext());
|
||||
|
||||
parentOp->walk([&](scf::ForallOp forallOp) {
|
||||
if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
|
||||
return std::make_unique<ForallToParallelLoop>();
|
||||
}
|
||||
80
mlir/test/Dialect/SCF/forall-to-parallel.mlir
Normal file
80
mlir/test/Dialect/SCF/forall-to-parallel.mlir
Normal file
@@ -0,0 +1,80 @@
|
||||
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))' -split-input-file | FileCheck %s
|
||||
|
||||
func.func private @callee(%i: index, %j: index)
|
||||
|
||||
// CHECK-LABEL: @two_iters
|
||||
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
|
||||
func.func @two_iters(%ub1: index, %ub2: index) {
|
||||
scf.forall (%i, %j) in (%ub1, %ub2) {
|
||||
func.call @callee(%i, %j) : (index, index) -> ()
|
||||
}
|
||||
|
||||
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
|
||||
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
|
||||
// CHECK: scf.reduce
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func private @callee(%i: index, %j: index)
|
||||
|
||||
// CHECK-LABEL: @repeated
|
||||
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
|
||||
func.func @repeated(%ub1: index, %ub2: index) {
|
||||
scf.forall (%i, %j) in (%ub1, %ub2) {
|
||||
func.call @callee(%i, %j) : (index, index) -> ()
|
||||
}
|
||||
|
||||
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
|
||||
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
|
||||
// CHECK: scf.reduce
|
||||
scf.forall (%i, %j) in (%ub1, %ub2) {
|
||||
func.call @callee(%i, %j) : (index, index) -> ()
|
||||
}
|
||||
|
||||
// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
|
||||
// CHECK: func.call @callee(%[[IV3]], %[[IV4]])
|
||||
// CHECK: scf.reduce
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func private @callee(%i: index, %j: index, %k: index, %l: index)
|
||||
|
||||
// CHECK-LABEL: @nested
|
||||
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
|
||||
func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
|
||||
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) step (%{{.*}}, %{{.*}}) {
|
||||
// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB3]], %[[UB4]]) step (%{{.*}}, %{{.*}}) {
|
||||
// CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
|
||||
// CHECK: scf.reduce
|
||||
// CHECK: }
|
||||
// CHECK: scf.reduce
|
||||
// CHECK: }
|
||||
scf.forall (%i, %j) in (%ub1, %ub2) {
|
||||
scf.forall (%k, %l) in (%ub3, %ub4) {
|
||||
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mapping_attr
|
||||
func.func @mapping_attr() -> () {
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: scf.reduce
|
||||
// CHECK: {mapping = [#gpu.thread<x>]}
|
||||
|
||||
%num_threads = arith.constant 100 : index
|
||||
|
||||
scf.forall (%thread_idx) in (%num_threads) {
|
||||
scf.forall.in_parallel {
|
||||
}
|
||||
} {mapping = [#gpu.thread<x>]}
|
||||
return
|
||||
|
||||
}
|
||||
60
mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
Normal file
60
mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
Normal file
@@ -0,0 +1,60 @@
|
||||
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
|
||||
|
||||
func.func private @callee(%i: index, %j: index)
|
||||
|
||||
// CHECK-LABEL: @two_iters
|
||||
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
|
||||
func.func @two_iters(%ub1: index, %ub2: index) {
|
||||
scf.forall (%i, %j) in (%ub1, %ub2) {
|
||||
func.call @callee(%i, %j) : (index, index) -> ()
|
||||
}
|
||||
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
|
||||
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
|
||||
// CHECK: scf.reduce
|
||||
return
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func private @callee(%i: index, %j: index)
|
||||
|
||||
func.func @repeated(%ub1: index, %ub2: index) {
|
||||
scf.forall (%i, %j) in (%ub1, %ub2) {
|
||||
func.call @callee(%i, %j) : (index, index) -> ()
|
||||
}
|
||||
scf.forall (%i, %j) in (%ub1, %ub2) {
|
||||
func.call @callee(%i, %j) : (index, index) -> ()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
// expected-error @below {{expected a single payload op}}
|
||||
transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-note @below {{payload op}}
|
||||
func.func private @callee(%i: index, %j: index)
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
// expected-error @below {{expected the payload to be scf.forall}}
|
||||
transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user