[mlir][scf] Implement conversion from scf.forall to scf.parallel (#94109)

There is currently no path to lower scf.forall to scf.parallel with the
goal of targeting the OpenMP dialect.

In the SCF->ControlFlow conversion, scf.forall is briefly converted to
scf.parallel, but the scf.parallel is lowered directly to a sequential
loop. This makes experimenting with scf.forall for CPU execution
difficult.

This change factors out the rewrite in the SCF->ControlFlow pass into a
utility function that can then be used in the SCF->ControlFlow lowering
and via a separate -scf-forall-to-parallel pass.

---------

Co-authored-by: Spenser Bauman <sabauma@fastmail>
This commit is contained in:
Spenser Bauman
2024-06-04 15:41:09 -04:00
committed by GitHub
parent e775efcec4
commit 0b665c3dd2
11 changed files with 313 additions and 27 deletions

View File

@@ -68,6 +68,32 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}
def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let summary = "Converts scf.forall into a nest of scf.for operations";
let description = [{
Converts the `scf.forall` operation pointed to by the given handle into an
`scf.parallel` operation.
The operand handle must be associated with exactly one payload operation.
Loops with outputs are not supported.
#### Return Modes
Consumes the operand handle. Produces a silenceable failure if the operand
is not associated with a single `scf.forall` payload operation.
Returns a handle to the new `scf.parallel` operation.
Produces a silenceable failure if another number of resulting handles is
requested.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {

View File

@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
/// Creates a pass that converts SCF forall loops to SCF for loops.
std::unique_ptr<Pass> createForallToForLoopPass();
/// Creates a pass that converts SCF forall loops to SCF parallel loops.
std::unique_ptr<Pass> createForallToParallelLoopPass();
// Creates a pass which lowers for loops into while loops.
std::unique_ptr<Pass> createForToWhileLoopPass();

View File

@@ -125,6 +125,11 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
let constructor = "mlir::createForallToForLoopPass()";
}
def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
let summary = "Convert SCF forall loops to SCF parallel loops";
let constructor = "mlir::createForallToParallelLoopPass()";
}
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
let summary = "Convert SCF for loops to SCF while loops";
let constructor = "mlir::createForToWhileLoopPass()";

View File

@@ -39,6 +39,11 @@ class WhileOp;
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
SmallVectorImpl<Operation *> *results = nullptr);
/// Try converting scf.forall into an scf.parallel loop.
/// The conversion is only supported for forall operations with no results.
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
ParallelOp *result = nullptr);
/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
/// analysis.

View File

@@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRSCFToControlFlow
MLIRArithDialect
MLIRControlFlowDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRTransforms
)

View File

@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
@@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
PatternRewriter &rewriter) const {
Location loc = forallOp.getLoc();
if (!forallOp.getOutputs().empty())
return rewriter.notifyMatchFailure(
forallOp,
"only fully bufferized scf.forall ops can be lowered to scf.parallel");
// Convert mixed bounds and steps to SSA values.
SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
rewriter, loc, forallOp.getMixedLowerBound());
SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
rewriter, loc, forallOp.getMixedUpperBound());
SmallVector<Value> steps =
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
// Create empty scf.parallel op.
auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
rewriter.eraseBlock(&parallelOp.getRegion().front());
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
parallelOp.getRegion().begin());
// Replace the terminator.
rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
parallelOp.getRegion().front().getTerminator());
// Erase the scf.forall op.
rewriter.replaceOp(forallOp, parallelOp);
return success();
return scf::forallToParallelLoop(rewriter, forallOp);
}
void mlir::populateSCFToControlFlowConversionPatterns(

View File

@@ -98,6 +98,50 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// ForallToForOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto payload = state.getPayloadOps(getTarget());
if (!llvm::hasSingleElement(payload))
return emitSilenceableError() << "expected a single payload op";
auto target = dyn_cast<scf::ForallOp>(*payload.begin());
if (!target) {
DiagnosedSilenceableFailure diag =
emitSilenceableError() << "expected the payload to be scf.forall";
diag.attachNote((*payload.begin())->getLoc()) << "payload op";
return diag;
}
if (!target.getOutputs().empty()) {
return emitSilenceableError()
<< "unsupported shared outputs (didn't bufferize?)";
}
if (getNumResults() != 1) {
DiagnosedSilenceableFailure diag = emitSilenceableError()
<< "op expects one result, given "
<< getNumResults();
diag.attachNote(target.getLoc()) << "payload op";
return diag;
}
scf::ParallelOp opResult;
if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) {
DiagnosedSilenceableFailure diag =
emitSilenceableError() << "failed to convert forall into parallel";
return diag;
}
results.set(cast<OpResult>(getTransformed()[0]), {opResult});
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// LoopOutlineOp
//===----------------------------------------------------------------------===//

View File

@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ForallToFor.cpp
ForallToParallel.cpp
ForToWhile.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp

View File

@@ -0,0 +1,86 @@
//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.ForallOp's into SCF.ParallelOps's.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
scf::ForallOp forallOp,
scf::ParallelOp *result) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(forallOp);
Location loc = forallOp.getLoc();
if (!forallOp.getOutputs().empty())
return rewriter.notifyMatchFailure(
forallOp,
"only fully bufferized scf.forall ops can be lowered to scf.parallel");
// Convert mixed bounds and steps to SSA values.
SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
rewriter, loc, forallOp.getMixedLowerBound());
SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
rewriter, loc, forallOp.getMixedUpperBound());
SmallVector<Value> steps =
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
// Create empty scf.parallel op.
auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
rewriter.eraseBlock(&parallelOp.getRegion().front());
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
parallelOp.getRegion().begin());
// Replace the terminator.
rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
parallelOp.getRegion().front().getTerminator());
// If the mapping attribute is present, propagate to the new parallelOp.
if (forallOp.getMapping())
parallelOp->setAttr("mapping", *forallOp.getMapping());
// Erase the scf.forall op.
rewriter.replaceOp(forallOp, parallelOp);
if (result)
*result = parallelOp;
return success();
}
namespace {
struct ForallToParallelLoop final
: public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
void runOnOperation() override {
Operation *parentOp = getOperation();
IRRewriter rewriter(parentOp->getContext());
parentOp->walk([&](scf::ForallOp forallOp) {
if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
return signalPassFailure();
}
});
}
};
} // namespace
std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
return std::make_unique<ForallToParallelLoop>();
}

View File

@@ -0,0 +1,80 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))' -split-input-file | FileCheck %s
func.func private @callee(%i: index, %j: index)
// CHECK-LABEL: @two_iters
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
func.func @two_iters(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
// CHECK: scf.reduce
return
}
// -----
func.func private @callee(%i: index, %j: index)
// CHECK-LABEL: @repeated
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
func.func @repeated(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
// CHECK: scf.reduce
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
// CHECK: func.call @callee(%[[IV3]], %[[IV4]])
// CHECK: scf.reduce
return
}
// -----
func.func private @callee(%i: index, %j: index, %k: index, %l: index)
// CHECK-LABEL: @nested
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) step (%{{.*}}, %{{.*}}) {
// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB3]], %[[UB4]]) step (%{{.*}}, %{{.*}}) {
// CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
// CHECK: scf.reduce
// CHECK: }
// CHECK: scf.reduce
// CHECK: }
scf.forall (%i, %j) in (%ub1, %ub2) {
scf.forall (%k, %l) in (%ub3, %ub4) {
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
}
}
return
}
// -----
// CHECK-LABEL: @mapping_attr
func.func @mapping_attr() -> () {
// CHECK: scf.parallel
// CHECK: scf.reduce
// CHECK: {mapping = [#gpu.thread<x>]}
%num_threads = arith.constant 100 : index
scf.forall (%thread_idx) in (%num_threads) {
scf.forall.in_parallel {
}
} {mapping = [#gpu.thread<x>]}
return
}

View File

@@ -0,0 +1,60 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
func.func private @callee(%i: index, %j: index)
// CHECK-LABEL: @two_iters
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
func.func @two_iters(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
// CHECK: scf.reduce
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
// -----
func.func private @callee(%i: index, %j: index)
func.func @repeated(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected a single payload op}}
transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
// -----
// expected-note @below {{payload op}}
func.func private @callee(%i: index, %j: index)
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected the payload to be scf.forall}}
transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}