[mlir][OpenMP] add attribute for privatization barrier (#140089)

A barrier is needed at the end of initialization/copying of private
variables if any of those variables is lastprivate. This ensures that
all firstprivate variables receive the original value of the variable
before the lastprivate clause overwrites it.

Previously this barrier was added by the flang fontend, but there is not
a reliable way to put the barrier in the correct place for delayed
privatization, and the OpenMP dialect could some day have other users.
It is important that there are safe ways to use the constructs available
in the dialect.

lastprivate is currently not modelled in the OpenMP dialect, and so
there is no way to reliably determine whether there were lastprivate
variables. Therefore the frontend will have to provide this information
through this new attribute.

Part of a series of patches to fix
https://github.com/llvm/llvm-project/issues/136357
This commit is contained in:
Tom Eccles
2025-05-22 15:24:02 +01:00
committed by GitHub
parent 03cc50fd7d
commit a24ed7d477
5 changed files with 159 additions and 97 deletions

View File

@@ -1102,7 +1102,10 @@ class OpenMP_PrivateClauseSkip<
let arguments = (ins
Variadic<AnyType>:$private_vars,
OptionalAttr<SymbolRefArrayAttr>:$private_syms
OptionalAttr<SymbolRefArrayAttr>:$private_syms,
// Set this attribute if a barrier is needed after initialization and
// copying of lastprivate variables.
UnitAttr:$private_needs_barrier
);
// TODO: Add description.

View File

@@ -213,8 +213,8 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
$private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
$private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
}];
let hasVerifier = 1;
@@ -258,8 +258,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
$private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
$private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
}];
let hasVerifier = 1;
@@ -317,8 +317,8 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
$private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
$private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
}];
let hasVerifier = 1;
@@ -350,7 +350,7 @@ def SingleOp : OpenMP_Op<"single", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateRegion>($region, $private_vars, type($private_vars),
$private_syms) attr-dict
$private_syms, $private_needs_barrier) attr-dict
}];
let hasVerifier = 1;
@@ -505,8 +505,8 @@ def LoopOp : OpenMP_Op<"loop", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
$private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
$private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
}];
let builders = [
@@ -557,8 +557,8 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
$private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
$private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
}];
let hasVerifier = 1;
@@ -611,8 +611,8 @@ def SimdOp : OpenMP_Op<"simd", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
$private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
$private_syms, $private_needs_barrier, $reduction_mod, $reduction_vars,
type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
}];
let hasVerifier = 1;
@@ -690,7 +690,7 @@ def DistributeOp : OpenMP_Op<"distribute", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateRegion>($region, $private_vars, type($private_vars),
$private_syms) attr-dict
$private_syms, $private_needs_barrier) attr-dict
}];
let hasVerifier = 1;
@@ -740,7 +740,7 @@ def TaskOp
custom<InReductionPrivateRegion>(
$region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $private_vars,
type($private_vars), $private_syms) attr-dict
type($private_vars), $private_syms, $private_needs_barrier) attr-dict
}];
let hasVerifier = 1;
@@ -816,8 +816,9 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
custom<InReductionPrivateReductionRegion>(
$region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $private_vars,
type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
type($private_vars), $private_syms, $private_needs_barrier,
$reduction_mod, $reduction_vars, type($reduction_vars),
$reduction_byref, $reduction_syms) attr-dict
}];
let extraClassDeclaration = [{
@@ -1324,7 +1325,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
$host_eval_vars, type($host_eval_vars), $in_reduction_vars,
type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
$map_vars, type($map_vars), $private_vars, type($private_vars),
$private_syms, $private_maps) attr-dict
$private_syms, $private_needs_barrier, $private_maps) attr-dict
}];
let hasVerifier = 1;

View File

@@ -450,6 +450,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
/* private_needs_barrier = */ nullptr,
/* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
/* reduction_mod = */ nullptr,
/* reduction_vars = */ llvm::SmallVector<Value>{},

View File

@@ -581,11 +581,14 @@ struct PrivateParseArgs {
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
llvm::SmallVectorImpl<Type> &types;
ArrayAttr &syms;
UnitAttr &needsBarrier;
DenseI64ArrayAttr *mapIndices;
PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
SmallVectorImpl<Type> &types, ArrayAttr &syms,
UnitAttr &needsBarrier,
DenseI64ArrayAttr *mapIndices = nullptr)
: vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
: vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
mapIndices(mapIndices) {}
};
struct ReductionParseArgs {
@@ -613,6 +616,10 @@ struct AllRegionParseArgs {
};
} // namespace
static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
return "private_barrier";
}
static ParseResult parseClauseWithRegionArgs(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
@@ -620,7 +627,8 @@ static ParseResult parseClauseWithRegionArgs(
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
DenseBoolArrayAttr *byref = nullptr,
ReductionModifierAttr *modifier = nullptr) {
ReductionModifierAttr *modifier = nullptr,
UnitAttr *needsBarrier = nullptr) {
SmallVector<SymbolRefAttr> symbolVec;
SmallVector<int64_t> mapIndicesVec;
SmallVector<bool> isByRefVec;
@@ -688,6 +696,12 @@ static ParseResult parseClauseWithRegionArgs(
if (parser.parseRParen())
return failure();
if (needsBarrier) {
if (parser.parseOptionalKeyword(getPrivateNeedsBarrierSpelling())
.succeeded())
*needsBarrier = mlir::UnitAttr::get(parser.getContext());
}
auto *argsBegin = regionPrivateArgs.begin();
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
argsBegin + regionArgOffset + types.size());
@@ -735,7 +749,8 @@ static ParseResult parseBlockArgClause(
if (failed(parseClauseWithRegionArgs(
parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
&privateArgs->syms, privateArgs->mapIndices)))
&privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
/*modifier=*/nullptr, &privateArgs->needsBarrier)))
return failure();
}
return success();
@@ -824,7 +839,7 @@ static ParseResult parseTargetOpRegion(
SmallVectorImpl<Type> &mapTypes,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
DenseI64ArrayAttr &privateMaps) {
UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
AllRegionParseArgs args;
args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
@@ -832,7 +847,7 @@ static ParseResult parseTargetOpRegion(
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
&privateMaps);
privateNeedsBarrier, &privateMaps);
return parseBlockArgRegion(parser, region, args);
}
@@ -842,11 +857,13 @@ static ParseResult parseInReductionPrivateRegion(
SmallVectorImpl<Type> &inReductionTypes,
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
UnitAttr &privateNeedsBarrier) {
AllRegionParseArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
privateNeedsBarrier);
return parseBlockArgRegion(parser, region, args);
}
@@ -857,14 +874,15 @@ static ParseResult parseInReductionPrivateReductionRegion(
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
ReductionModifierAttr &reductionMod,
UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
AllRegionParseArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
privateNeedsBarrier);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
reductionSyms, &reductionMod);
return parseBlockArgRegion(parser, region, args);
@@ -873,9 +891,11 @@ static ParseResult parseInReductionPrivateReductionRegion(
static ParseResult parsePrivateRegion(
OpAsmParser &parser, Region &region,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
UnitAttr &privateNeedsBarrier) {
AllRegionParseArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
privateNeedsBarrier);
return parseBlockArgRegion(parser, region, args);
}
@@ -883,12 +903,13 @@ static ParseResult parsePrivateReductionRegion(
OpAsmParser &parser, Region &region,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
ReductionModifierAttr &reductionMod,
UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
AllRegionParseArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
privateNeedsBarrier);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
reductionSyms, &reductionMod);
return parseBlockArgRegion(parser, region, args);
@@ -931,10 +952,12 @@ struct PrivatePrintArgs {
ValueRange vars;
TypeRange types;
ArrayAttr syms;
UnitAttr needsBarrier;
DenseI64ArrayAttr mapIndices;
PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
DenseI64ArrayAttr mapIndices)
: vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
: vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
mapIndices(mapIndices) {}
};
struct ReductionPrintArgs {
ValueRange vars;
@@ -964,7 +987,7 @@ static void printClauseWithRegionArgs(
ValueRange argsSubrange, ValueRange operands, TypeRange types,
ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
DenseBoolArrayAttr byref = nullptr,
ReductionModifierAttr modifier = nullptr) {
ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
if (argsSubrange.empty())
return;
@@ -1006,6 +1029,9 @@ static void printClauseWithRegionArgs(
p << " : ";
llvm::interleaveComma(types, p);
p << ") ";
if (needsBarrier)
p << getPrivateNeedsBarrierSpelling() << " ";
}
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
@@ -1020,9 +1046,10 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
StringRef clauseName, ValueRange argsSubrange,
std::optional<PrivatePrintArgs> privateArgs) {
if (privateArgs)
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
privateArgs->vars, privateArgs->types,
privateArgs->syms, privateArgs->mapIndices);
printClauseWithRegionArgs(
p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
/*modifier=*/nullptr, privateArgs->needsBarrier);
}
static void
@@ -1068,23 +1095,23 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
// These parseXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
static void
printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
ValueRange hostEvalVars, TypeRange hostEvalTypes,
ValueRange inReductionVars, TypeRange inReductionTypes,
DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange mapVars,
TypeRange mapTypes, ValueRange privateVars,
TypeRange privateTypes, ArrayAttr privateSyms,
DenseI64ArrayAttr privateMaps) {
static void printTargetOpRegion(
OpAsmPrinter &p, Operation *op, Region &region,
ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
ValueRange hostEvalVars, TypeRange hostEvalTypes,
ValueRange inReductionVars, TypeRange inReductionTypes,
DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
DenseI64ArrayAttr privateMaps) {
AllRegionPrintArgs args;
args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
privateNeedsBarrier, privateMaps);
printBlockArgRegion(p, op, region, args);
}
@@ -1092,11 +1119,12 @@ static void printInReductionPrivateRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
ArrayAttr privateSyms) {
ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
privateNeedsBarrier,
/*mapIndices=*/nullptr);
printBlockArgRegion(p, op, region, args);
}
@@ -1105,13 +1133,15 @@ static void printInReductionPrivateReductionRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
ValueRange reductionVars, TypeRange reductionTypes,
DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
ReductionModifierAttr reductionMod, ValueRange reductionVars,
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
privateNeedsBarrier,
/*mapIndices=*/nullptr);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
reductionSyms, reductionMod);
@@ -1120,21 +1150,24 @@ static void printInReductionPrivateReductionRegion(
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange privateVars, TypeRange privateTypes,
ArrayAttr privateSyms) {
ArrayAttr privateSyms,
UnitAttr privateNeedsBarrier) {
AllRegionPrintArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
privateNeedsBarrier,
/*mapIndices=*/nullptr);
printBlockArgRegion(p, op, region, args);
}
static void printPrivateReductionRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
TypeRange privateTypes, ArrayAttr privateSyms,
TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
ReductionModifierAttr reductionMod, ValueRange reductionVars,
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
privateNeedsBarrier,
/*mapIndices=*/nullptr);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
reductionSyms, reductionMod);
@@ -1916,7 +1949,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
clauses.mapVars, clauses.nowait, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
makeArrayAttr(ctx, clauses.privateSyms),
clauses.privateNeedsBarrier, clauses.threadLimit,
/*private_maps=*/nullptr);
}
@@ -2180,7 +2214,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
/*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
/*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
state.addAttributes(attributes);
@@ -2192,8 +2227,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.ifExpr, clauses.numThreads, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
clauses.procBindKind, clauses.reductionMod,
clauses.reductionVars,
clauses.privateNeedsBarrier, clauses.procBindKind,
clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms));
}
@@ -2297,11 +2332,12 @@ static bool opInGlobalImplicitParallelRegion(Operation *op) {
void TeamsOp::build(OpBuilder &builder, OperationState &state,
const TeamsOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms.
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
/*private_vars=*/{}, /*private_syms=*/nullptr,
clauses.reductionMod, clauses.reductionVars,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms),
clauses.threadLimit);
@@ -2358,11 +2394,11 @@ OperandRange SectionOp::getReductionVars() {
void SectionsOp::build(OpBuilder &builder, OperationState &state,
const SectionsOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms.
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.nowait, /*private_vars=*/{},
/*private_syms=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms));
}
@@ -2394,11 +2430,12 @@ LogicalResult SectionsOp::verifyRegions() {
void SingleOp::build(OpBuilder &builder, OperationState &state,
const SingleOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms.
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.copyprivateVars,
makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
/*private_vars=*/{}, /*private_syms=*/nullptr);
/*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/nullptr);
}
LogicalResult SingleOp::verify() {
@@ -2474,8 +2511,9 @@ void LoopOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
makeArrayAttr(ctx, clauses.privateSyms),
clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms));
}
@@ -2503,6 +2541,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
/*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
/*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
/*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/false,
/*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
/*reduction_byref=*/nullptr,
/*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
@@ -2514,18 +2553,17 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
void WsloopOp::build(OpBuilder &builder, OperationState &state,
const WsloopOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
// privateSyms.
WsloopOp::build(builder, state,
/*allocate_vars=*/{}, /*allocator_vars=*/{},
clauses.linearVars, clauses.linearStepVars, clauses.nowait,
clauses.order, clauses.orderMod, clauses.ordered,
clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms),
clauses.scheduleKind, clauses.scheduleChunk,
clauses.scheduleMod, clauses.scheduleSimd);
// TODO: Store clauses in op: allocateVars, allocatorVars
WsloopOp::build(
builder, state,
/*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
clauses.ordered, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
}
LogicalResult WsloopOp::verify() {
@@ -2565,14 +2603,14 @@ LogicalResult WsloopOp::verifyRegions() {
void SimdOp::build(OpBuilder &builder, OperationState &state,
const SimdOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: linearVars, linearStepVars, privateVars,
// privateSyms.
// TODO Store clauses in op: linearVars, linearStepVars
SimdOp::build(builder, state, clauses.alignedVars,
makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
/*linear_vars=*/{}, /*linear_step_vars=*/{},
clauses.nontemporalVars, clauses.order, clauses.orderMod,
clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
clauses.reductionMod, clauses.reductionVars,
clauses.privateNeedsBarrier, clauses.reductionMod,
clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
clauses.simdlen);
@@ -2622,7 +2660,8 @@ void DistributeOp::build(OpBuilder &builder, OperationState &state,
clauses.allocatorVars, clauses.distScheduleStatic,
clauses.distScheduleChunkSize, clauses.order,
clauses.orderMod, clauses.privateVars,
makeArrayAttr(builder.getContext(), clauses.privateSyms));
makeArrayAttr(builder.getContext(), clauses.privateSyms),
clauses.privateNeedsBarrier);
}
LogicalResult DistributeOp::verify() {
@@ -2778,7 +2817,8 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
clauses.priority, /*private_vars=*/clauses.privateVars,
/*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
clauses.untied, clauses.eventHandle);
clauses.privateNeedsBarrier, clauses.untied,
clauses.eventHandle);
}
LogicalResult TaskOp::verify() {
@@ -2817,18 +2857,18 @@ LogicalResult TaskgroupOp::verify() {
void TaskloopOp::build(OpBuilder &builder, OperationState &state,
const TaskloopOperands &clauses) {
MLIRContext *ctx = builder.getContext();
TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.final, clauses.grainsizeMod, clauses.grainsize,
clauses.ifExpr, clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
makeArrayAttr(ctx, clauses.inReductionSyms),
clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
clauses.numTasks, clauses.priority,
/*private_vars=*/clauses.privateVars,
/*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
TaskloopOp::build(
builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
/*private_vars=*/clauses.privateVars,
/*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
}
LogicalResult TaskloopOp::verify() {

View File

@@ -2876,6 +2876,23 @@ func.func @parallel_op_privatizers(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
return
}
// CHECK-LABEL: parallel_op_privatizers_barrier
// CHECK-SAME: (%[[ARG0:[^[:space:]]+]]: !llvm.ptr, %[[ARG1:[^[:space:]]+]]: !llvm.ptr)
func.func @parallel_op_privatizers_barrier(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
// CHECK: omp.parallel private(
// CHECK-SAME: @x.privatizer %[[ARG0]] -> %[[ARG0_PRIV:[^[:space:]]+]],
// CHECK-SAME: @y.privatizer %[[ARG1]] -> %[[ARG1_PRIV:[^[:space:]]+]] : !llvm.ptr, !llvm.ptr)
// CHECK-SAME: private_barrier
omp.parallel private(@x.privatizer %arg0 -> %arg2, @y.privatizer %arg1 -> %arg3 : !llvm.ptr, !llvm.ptr) private_barrier {
// CHECK: llvm.load %[[ARG0_PRIV]]
%0 = llvm.load %arg2 : !llvm.ptr -> i32
// CHECK: llvm.load %[[ARG1_PRIV]]
%1 = llvm.load %arg3 : !llvm.ptr -> i32
omp.terminator
}
return
}
// CHECK-LABEL: omp.private {type = private} @a.privatizer : !llvm.ptr init {
omp.private {type = private} @a.privatizer : !llvm.ptr init {
// CHECK: ^bb0(%{{.*}}: {{.*}}, %{{.*}}: {{.*}}):