[MLIR][OpenMP] Add private clause to omp.parallel (#81452)

Extends the `omp.parallel` op by adding a `private` clause to model
[first]private variables. This uses the `omp.private` op to map
privatized variables to their corresponding privatizers.

Example `omp.parallel` op with a `private` clause:
```
omp.parallel private(@x.privatizer %arg0 -> %arg1 : !llvm.ptr) {
  ^bb0(%arg1: !llvm.ptr):
    // ... use %arg1 ...
    omp.terminator
}
```

Whether the variable is private or firstprivate is determined by the
attributes of the corresponding `omp.private` op.
This commit is contained in:
Kareem Ergawy
2024-02-18 09:02:06 +01:00
committed by GitHub
parent 1ecbab56dc
commit 833fea40d2
7 changed files with 255 additions and 67 deletions

View File

@@ -2640,7 +2640,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
procBindKindAttr);
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
/*privatizers=*/nullptr);
}
static mlir::omp::SectionOp

View File

@@ -276,7 +276,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
Variadic<AnyType>:$allocators_vars,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
Variadic<AnyType>:$private_vars,
OptionalAttr<SymbolRefArrayAttr>:$privatizers);
let regions = (region AnyRegion:$region);
@@ -297,7 +299,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
$allocators_vars, type($allocators_vars)
) `)`
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars),
$reductions, $private_vars, type($private_vars),
$privatizers) attr-dict
}];
let hasVerifier = 1;
}

View File

@@ -450,7 +450,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocators_vars = */ llvm::SmallVector<Value>{},
/* reduction_vars = */ llvm::SmallVector<Value>{},
/* reductions = */ ArrayAttr{},
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
/* private_vars = */ ValueRange(),
/* privatizers = */ nullptr);
{
OpBuilder::InsertionGuard guard(rewriter);

View File

@@ -430,68 +430,102 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
// Parser, printer and verifier for ReductionVarList
//===----------------------------------------------------------------------===//
ParseResult
parseReductionClause(OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
SmallVectorImpl<OpAsmParser::Argument> &privates) {
if (failed(parser.parseOptionalKeyword("reduction")))
return failure();
ParseResult parseClauseWithRegionArgs(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, ArrayAttr &symbols,
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
SmallVector<SymbolRefAttr> reductionVec;
unsigned regionArgOffset = regionPrivateArgs.size();
if (failed(
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseOperand(operands.emplace_back()) ||
parser.parseArrow() ||
parser.parseArgument(privates.emplace_back()) ||
parser.parseArgument(regionPrivateArgs.emplace_back()) ||
parser.parseColonType(types.emplace_back()))
return failure();
return success();
})))
return failure();
for (auto [prv, type] : llvm::zip_equal(privates, types)) {
auto *argsBegin = regionPrivateArgs.begin();
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
argsBegin + regionArgOffset + types.size());
for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
prv.type = type;
}
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
symbols = ArrayAttr::get(parser.getContext(), reductions);
return success();
}
static void printReductionClause(OpAsmPrinter &p, Operation *op,
ValueRange reductionArgs, ValueRange operands,
TypeRange types, ArrayAttr reductionSymbols) {
p << "reduction(";
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
ValueRange argsSubrange,
StringRef clauseName, ValueRange operands,
TypeRange types, ArrayAttr symbols) {
p << clauseName << "(";
llvm::interleaveComma(
llvm::zip_equal(reductionSymbols, operands, reductionArgs, types), p,
[&p](auto t) {
llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
auto [sym, op, arg, type] = t;
p << sym << " " << op << " -> " << arg << " : " << type;
});
p << ") ";
}
static ParseResult
parseParallelRegion(OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
static ParseResult parseParallelRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
llvm::SmallVectorImpl<Type> &privateVarsTypes,
ArrayAttr &privatizerSymbols) {
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
llvm::SmallVector<OpAsmParser::Argument> privates;
if (succeeded(parseReductionClause(parser, region, operands, types,
reductionSymbols, privates)))
return parser.parseRegion(region, privates);
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionVarOperands,
reductionVarTypes, reductionSymbols,
regionPrivateArgs)))
return failure();
}
return parser.parseRegion(region);
if (succeeded(parser.parseOptionalKeyword("private"))) {
if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands,
privateVarsTypes, privatizerSymbols,
regionPrivateArgs)))
return failure();
}
return parser.parseRegion(region, regionPrivateArgs);
}
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange operands, TypeRange types,
ArrayAttr reductionSymbols) {
if (reductionSymbols)
printReductionClause(p, op, region.front().getArguments(), operands, types,
reductionSymbols);
ValueRange reductionVarOperands,
TypeRange reductionVarTypes,
ArrayAttr reductionSymbols,
ValueRange privateVarOperands,
TypeRange privateVarTypes,
ArrayAttr privatizerSymbols) {
if (reductionSymbols) {
auto *argsBegin = region.front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin,
argsBegin + reductionVarTypes.size());
printClauseWithRegionArgs(p, op, argsSubrange, "reduction",
reductionVarOperands, reductionVarTypes,
reductionSymbols);
}
if (privatizerSymbols) {
auto *argsBegin = region.front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin + reductionVarOperands.size(),
argsBegin + reductionVarOperands.size() +
privateVarTypes.size());
printClauseWithRegionArgs(p, op, argsSubrange, "private",
privateVarOperands, privateVarTypes,
privatizerSymbols);
}
p.printRegion(region, /*printEntryBlockArgs=*/false);
}
@@ -1174,14 +1208,64 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
/*proc_bind_val=*/nullptr);
/*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
/*privatizers=*/nullptr);
state.addAttributes(attributes);
}
/// Checks the `private` clause operands of `op` against its privatizer symbols.
///
/// The op type must expose `getPrivateVars()`, `getPrivatizersAttr()` and
/// `emitError()`. Verification succeeds when there are neither private
/// variables nor privatizer symbols, or when all of the following hold:
///   * both lists have the same number of entries,
///   * every symbol resolves to a `PrivateClauseOp` reachable from `op`,
///   * each variable's type equals the type reported by its privatizer op.
/// On the first violation an error is emitted on `op` and returned.
template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op) {
  auto privateVars = op.getPrivateVars();
  auto privatizers = op.getPrivatizersAttr();

  // A missing attribute is treated the same as an empty symbol list.
  const auto numPrivateVars = privateVars.size();
  const auto numPrivatizers =
      (privatizers == nullptr) ? 0 : privatizers.size();

  // Nothing to verify; also keeps a null attribute away from the loop below.
  if (numPrivateVars == 0 && numPrivatizers == 0)
    return success();

  if (numPrivateVars != numPrivatizers)
    return op.emitError() << "inconsistent number of private variables and "
                             "privatizer op symbols, private vars: "
                          << numPrivateVars
                          << " vs. privatizer op symbols: " << numPrivatizers;

  // Walk variables and symbols pairwise; the sizes are known to match here.
  for (auto [privateVar, privatizerAttr] :
       llvm::zip_equal(privateVars, privatizers)) {
    SymbolRefAttr privatizerSym =
        privatizerAttr.template cast<SymbolRefAttr>();
    PrivateClauseOp privatizerOp =
        SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
                                                              privatizerSym);

    if (privatizerOp == nullptr)
      return op.emitError() << "failed to lookup privatizer op with symbol: '"
                            << privatizerSym << "'";

    Type varType = privateVar.getType();
    Type privatizerType = privatizerOp.getType();
    if (varType == privatizerType)
      continue;

    // The diagnostic names the clause kind recorded on the privatizer op.
    const char *clauseKind = privatizerOp.getDataSharingType() ==
                                     DataSharingClauseType::Private
                                 ? "private"
                                 : "firstprivate";
    return op.emitError()
           << "type mismatch between a " << clauseKind
           << " variable and its privatizer op, var type: " << varType
           << " vs. privatizer op type: " << privatizerType;
  }

  return success();
}
/// Verifier for `omp.parallel`.
///
/// Checks run in this order, stopping at the first failure:
///   1. the allocate and allocator operand lists have equal length,
///   2. the private clause is consistent (see `verifyPrivateVarList`),
///   3. the reduction clause passes `verifyReductionVarList` (defined
///      elsewhere in this file).
LogicalResult ParallelOp::verify() {
  const bool allocateListsMatch =
      getAllocateVars().size() == getAllocatorsVars().size();
  if (!allocateListsMatch)
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  LogicalResult privateClauseResult = verifyPrivateVarList(*this);
  if (failed(privateClauseResult))
    return failure();

  return verifyReductionVarList(*this, getReductions(), getReductionVars());
}
@@ -1279,9 +1363,10 @@ parseWsLoop(OpAsmParser &parser, Region &region,
// Parse an optional reduction clause
llvm::SmallVector<OpAsmParser::Argument> privates;
bool hasReduction = succeeded(
parseReductionClause(parser, region, reductionOperands, reductionTypes,
reductionSymbols, privates));
bool hasReduction = succeeded(parser.parseOptionalKeyword("reduction")) &&
succeeded(parseClauseWithRegionArgs(
parser, region, reductionOperands, reductionTypes,
reductionSymbols, privates));
if (parser.parseKeyword("for"))
return failure();
@@ -1328,8 +1413,9 @@ void printWsLoop(OpAsmPrinter &p, Operation *op, Region &region,
if (reductionSymbols) {
auto reductionArgs =
region.front().getArguments().drop_front(loopVarTypes.size());
printReductionClause(p, op, reductionArgs, reductionOperands,
reductionTypes, reductionSymbols);
printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
reductionOperands, reductionTypes,
reductionSymbols);
}
p << " for ";

View File

@@ -1865,3 +1865,59 @@ omp.private {type = firstprivate} @x.privatizer : f32 alloc {
^bb0(%arg0: f32):
omp.yield(%arg0 : f32)
}
// -----
// Negative test: the private clause passes an `index` operand while the
// referenced privatizer @var1.privatizer (type = private) is declared for
// !llvm.ptr, so the diagnostic must report a "private" type mismatch.
func.func @private_type_mismatch(%arg0: index) {
// expected-error @below {{type mismatch between a private variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
omp.terminator
}
return
}
omp.private {type = private} @var1.privatizer : !llvm.ptr alloc {
^bb0(%arg0: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
}
// -----
// Negative test: same type mismatch as the `private` case, but the referenced
// privatizer is declared with type = firstprivate (alloc + copy regions), so
// the diagnostic must say "firstprivate" instead of "private".
func.func @firstprivate_type_mismatch(%arg0: index) {
// expected-error @below {{type mismatch between a firstprivate variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
omp.terminator
}
return
}
omp.private {type = firstprivate} @var1.privatizer : !llvm.ptr alloc {
^bb0(%arg0: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
} copy {
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
}
// -----
// Negative test: the private clause references @var1.privatizer, but this
// test section defines no op with that symbol, so the symbol lookup
// diagnostic is required.
func.func @undefined_privatizer(%arg0: index) {
// expected-error @below {{failed to lookup privatizer op with symbol: '@var1.privatizer'}}
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
omp.terminator
}
return
}
// -----
// Negative test: written in the generic op form so that one private operand
// (last operandSegmentSizes entry is 1) can be paired with two privatizer
// symbols; the count-mismatch diagnostic is required.
// NOTE(review): the function name repeats the previous test's name and does
// not describe this case; consider renaming it to reflect the count mismatch.
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
"omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1>, privatizers = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
return
}

View File

@@ -59,7 +59,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel"(%num_threads, %data_var, %data_var) ({
omp.terminator
}) {operandSegmentSizes = array<i32: 0,1,1,1,0>} : (i32, memref<i32>, memref<i32>) -> ()
}) {operandSegmentSizes = array<i32: 0,1,1,1,0,0>} : (i32, memref<i32>, memref<i32>) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -68,22 +68,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel"(%if_cond, %data_var, %data_var) ({
omp.terminator
}) {operandSegmentSizes = array<i32: 1,0,1,1,0>} : (i1, memref<i32>, memref<i32>) -> ()
}) {operandSegmentSizes = array<i32: 1,0,1,1,0,0>} : (i1, memref<i32>, memref<i32>) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
}) {operandSegmentSizes = array<i32: 1,1,0,0,0>} : (i1, i32) -> ()
}) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (i1, i32) -> ()
omp.terminator
}) {operandSegmentSizes = array<i32: 1,1,1,1,0>, proc_bind_val = #omp<procbindkind spread>} : (i1, i32, memref<i32>, memref<i32>) -> ()
}) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_val = #omp<procbindkind spread>} : (i1, i32, memref<i32>, memref<i32>) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
}) {operandSegmentSizes = array<i32: 0,0,1,1,0>} : (memref<i32>, memref<i32>) -> ()
}) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (memref<i32>, memref<i32>) -> ()
return
}
@@ -2231,3 +2231,63 @@ func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memre
omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
return
}
// Round-trip test for a private clause with two entries: each function
// argument is mapped through its privatizer symbol to a region argument, and
// the loads inside the region must use those region arguments.
// CHECK-LABEL: parallel_op_privatizers
// CHECK-SAME: (%[[ARG0:[^[:space:]]+]]: !llvm.ptr, %[[ARG1:[^[:space:]]+]]: !llvm.ptr)
func.func @parallel_op_privatizers(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
// CHECK: omp.parallel private(
// CHECK-SAME: @x.privatizer %[[ARG0]] -> %[[ARG0_PRIV:[^[:space:]]+]] : !llvm.ptr,
// CHECK-SAME: @y.privatizer %[[ARG1]] -> %[[ARG1_PRIV:[^[:space:]]+]] : !llvm.ptr)
omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr) {
// CHECK: llvm.load %[[ARG0_PRIV]]
%0 = llvm.load %arg2 : !llvm.ptr -> i32
// CHECK: llvm.load %[[ARG1_PRIV]]
%1 = llvm.load %arg3 : !llvm.ptr -> i32
omp.terminator
}
return
}
// Privatizer of kind `private` for !llvm.ptr values; it has only an alloc
// region, which yields its single block argument unchanged.
// CHECK-LABEL: omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
// CHECK: ^bb0(%{{.*}}: {{.*}}):
^bb0(%arg0: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
}
// Privatizer of kind `firstprivate` for !llvm.ptr values; besides the alloc
// region it carries a copy region with two block arguments, which yields the
// first of them.
// CHECK-LABEL: omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
// CHECK: ^bb0(%{{.*}}: {{.*}}):
^bb0(%arg0: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
// CHECK: } copy {
} copy {
// CHECK: ^bb0(%{{.*}}: {{.*}}, %{{.*}}: {{.*}}):
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
}
// Round-trip test combining a reduction clause and a private clause on one
// omp.parallel op. The printed form must list the reduction entries first and
// the private entries second, and loads in the region must refer to the
// matching region arguments of each clause.
// NOTE(review): @add_f32 is not defined in the lines shown here; presumably it
// is declared elsewhere in this test file.
// CHECK-LABEL: parallel_op_reduction_and_private
func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !llvm.ptr, %reduc_var: !llvm.ptr, %reduc_var2: !llvm.ptr) {
// CHECK: omp.parallel
// CHECK-SAME: reduction(
// CHECK-SAME: @add_f32 %[[REDUC_VAR:[^[:space:]]+]] -> %[[REDUC_ARG:[^[:space:]]+]] : !llvm.ptr,
// CHECK-SAME: @add_f32 %[[REDUC_VAR2:[^[:space:]]+]] -> %[[REDUC_ARG2:[^[:space:]]+]] : !llvm.ptr)
//
// CHECK-SAME: private(
// CHECK-SAME: @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]] : !llvm.ptr,
// CHECK-SAME: @y.privatizer %[[PRIV_VAR2:[^[:space:]]+]] -> %[[PRIV_ARG2:[^[:space:]]+]] : !llvm.ptr)
omp.parallel reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr)
private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr) {
// CHECK: llvm.load %[[PRIV_ARG]]
%0 = llvm.load %priv_arg : !llvm.ptr -> f32
// CHECK: llvm.load %[[PRIV_ARG2]]
%1 = llvm.load %priv_arg2 : !llvm.ptr -> f32
// CHECK: llvm.load %[[REDUC_ARG]]
%2 = llvm.load %reduc_arg : !llvm.ptr -> f32
// CHECK: llvm.load %[[REDUC_ARG2]]
%3 = llvm.load %reduc_arg2 : !llvm.ptr -> f32
omp.terminator
}
return
}

View File

@@ -1,21 +0,0 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
// CHECK: omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
// CHECK: ^bb0(%arg0: {{.*}}):
^bb0(%arg0: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
}
// CHECK: omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
// CHECK: ^bb0(%arg0: {{.*}}):
^bb0(%arg0: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
// CHECK: } copy {
} copy {
// CHECK: ^bb0(%arg0: {{.*}}, %arg1: {{.*}}):
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
}