[Flang][OpenMP][Lower] Use clause operand structures (#86802)

This patch updates Flang lowering to use the new set of OpenMP clause
operand structures and their groupings into directive-specific sets of
clause operands.

It simplifies the passing of information from the clause processor and
the creation of operations.

The `DataSharingProcessor` is slightly modified to not hold delayed
privatization state. Instead, optional arguments are added to
`processStep1` which are only passed when delayed privatization is used.
This enables using the clause operand structure for `private` and
removes the need for the ad-hoc `DelayedPrivatizationInfo` structure.

The processing of the `schedule` clause is updated to process the
`chunk` modifier rather than requiring two separate calls to the
`ClauseProcessor`.

Lowering of a block-associated `ordered` construct is updated to emit a
TODO error if the `simd` clause is specified, since it is not currently
supported by the `ClauseProcessor` or later compilation stages.

Removed processing of `schedule` from `omp.simdloop`, as it doesn't
apply to `simd` constructs.
This commit is contained in:
Sergio Afonso
2024-04-12 12:42:41 +01:00
committed by GitHub
parent 8c0f52e9d5
commit 78eac46609
8 changed files with 455 additions and 558 deletions

View File

@@ -162,14 +162,13 @@ getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
ifVal);
}
static void
addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
const omp::ObjectList &objects,
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&useDeviceSymbols) {
static void addUseDeviceClause(
Fortran::lower::AbstractConverter &converter,
const omp::ObjectList &objects,
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) {
genObjectList(objects, converter, operands);
for (mlir::Value &operand : operands) {
checkMapType(operand.getLoc(), operand.getType());
@@ -177,25 +176,24 @@ addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
useDeviceLocs.push_back(operand.getLoc());
}
for (const omp::Object &object : objects)
useDeviceSymbols.push_back(object.id());
useDeviceSyms.push_back(object.id());
}
static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
mlir::Location loc,
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
llvm::SmallVectorImpl<mlir::Value> &upperBound,
llvm::SmallVectorImpl<mlir::Value> &step,
mlir::omp::CollapseClauseOps &result,
std::size_t loopVarTypeSize) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
// The types of lower bound, upper bound, and step are converted into the
// type of the loop variable if necessary.
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
lowerBound[it] =
firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]);
upperBound[it] =
firOpBuilder.createConvert(loc, loopVarType, upperBound[it]);
step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]);
for (unsigned it = 0; it < (unsigned)result.loopLBVar.size(); it++) {
result.loopLBVar[it] =
firOpBuilder.createConvert(loc, loopVarType, result.loopLBVar[it]);
result.loopUBVar[it] =
firOpBuilder.createConvert(loc, loopVarType, result.loopUBVar[it]);
result.loopStepVar[it] =
firOpBuilder.createConvert(loc, loopVarType, result.loopStepVar[it]);
}
}
@@ -205,9 +203,7 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
bool ClauseProcessor::processCollapse(
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
llvm::SmallVectorImpl<mlir::Value> &upperBound,
llvm::SmallVectorImpl<mlir::Value> &step,
mlir::omp::CollapseClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const {
bool found = false;
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -238,15 +234,15 @@ bool ClauseProcessor::processCollapse(
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
assert(bounds && "Expected bounds for worksharing do loop");
Fortran::lower::StatementContext stmtCtx;
lowerBound.push_back(fir::getBase(converter.genExprValue(
result.loopLBVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
upperBound.push_back(fir::getBase(converter.genExprValue(
result.loopUBVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
if (bounds->step) {
step.push_back(fir::getBase(converter.genExprValue(
result.loopStepVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
} else { // If `step` is not present, assume it as `1`.
step.push_back(firOpBuilder.createIntegerConstant(
result.loopStepVar.push_back(firOpBuilder.createIntegerConstant(
currentLocation, firOpBuilder.getIntegerType(32), 1));
}
iv.push_back(bounds->name.thing.symbol);
@@ -257,8 +253,7 @@ bool ClauseProcessor::processCollapse(
&*std::next(doConstructEval->getNestedEvaluations().begin());
} while (collapseValue > 0);
convertLoopBounds(converter, currentLocation, lowerBound, upperBound, step,
loopVarTypeSize);
convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
return found;
}
@@ -286,7 +281,7 @@ bool ClauseProcessor::processDefault() const {
}
bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const {
mlir::omp::DeviceClauseOps &result) const {
const Fortran::parser::CharBlock *source = nullptr;
if (auto *clause = findUniqueClause<omp::clause::Device>(&source)) {
mlir::Location clauseLocation = converter.genLocation(*source);
@@ -298,25 +293,26 @@ bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
}
}
const auto &deviceExpr = std::get<omp::SomeExpr>(clause->t);
result = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
result.deviceVar =
fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processDeviceType(
mlir::omp::DeclareTargetDeviceType &result) const {
mlir::omp::DeviceTypeClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::DeviceType>()) {
// Case: declare target ... device_type(any | host | nohost)
switch (clause->v) {
case omp::clause::DeviceType::DeviceTypeDescription::Nohost:
result = mlir::omp::DeclareTargetDeviceType::nohost;
result.deviceType = mlir::omp::DeclareTargetDeviceType::nohost;
break;
case omp::clause::DeviceType::DeviceTypeDescription::Host:
result = mlir::omp::DeclareTargetDeviceType::host;
result.deviceType = mlir::omp::DeclareTargetDeviceType::host;
break;
case omp::clause::DeviceType::DeviceTypeDescription::Any:
result = mlir::omp::DeclareTargetDeviceType::any;
result.deviceType = mlir::omp::DeclareTargetDeviceType::any;
break;
}
return true;
@@ -325,7 +321,7 @@ bool ClauseProcessor::processDeviceType(
}
bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const {
mlir::omp::FinalClauseOps &result) const {
const Fortran::parser::CharBlock *source = nullptr;
if (auto *clause = findUniqueClause<omp::clause::Final>(&source)) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -333,100 +329,108 @@ bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
mlir::Value finalVal =
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
result = firOpBuilder.createConvert(clauseLocation,
firOpBuilder.getI1Type(), finalVal);
result.finalVar = firOpBuilder.createConvert(
clauseLocation, firOpBuilder.getI1Type(), finalVal);
return true;
}
return false;
}
bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const {
bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Hint>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
int64_t hintValue = *Fortran::evaluate::ToInt64(clause->v);
result = firOpBuilder.getI64IntegerAttr(hintValue);
result.hintAttr = firOpBuilder.getI64IntegerAttr(hintValue);
return true;
}
return false;
}
bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result);
bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeableAttr);
}
bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
return markClauseOccurrence<omp::clause::Nowait>(result);
bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const {
return markClauseOccurrence<omp::clause::Nowait>(result.nowaitAttr);
}
bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const {
bool ClauseProcessor::processNumTeams(
Fortran::lower::StatementContext &stmtCtx,
mlir::omp::NumTeamsClauseOps &result) const {
// TODO Get lower and upper bounds for num_teams when parser is updated to
// accept both.
if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
// auto lowerBound = std::get<std::optional<ExprTy>>(clause->t);
auto &upperBound = std::get<ExprTy>(clause->t);
result = fir::getBase(converter.genExprValue(upperBound, stmtCtx));
result.numTeamsUpperVar =
fir::getBase(converter.genExprValue(upperBound, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processNumThreads(
Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
Fortran::lower::StatementContext &stmtCtx,
mlir::omp::NumThreadsClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
result.numThreadsVar =
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const {
bool ClauseProcessor::processOrdered(
mlir::omp::OrderedClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Ordered>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
int64_t orderedClauseValue = 0l;
if (clause->v.has_value())
orderedClauseValue = *Fortran::evaluate::ToInt64(*clause->v);
result = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
result.orderedAttr = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
return true;
}
return false;
}
bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const {
bool ClauseProcessor::processPriority(
Fortran::lower::StatementContext &stmtCtx,
mlir::omp::PriorityClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Priority>()) {
result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
result.priorityVar =
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processProcBind(
mlir::omp::ClauseProcBindKindAttr &result) const {
mlir::omp::ProcBindClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::ProcBind>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
result = genProcBindKindAttr(firOpBuilder, *clause);
result.procBindKindAttr = genProcBindKindAttr(firOpBuilder, *clause);
return true;
}
return false;
}
bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const {
bool ClauseProcessor::processSafelen(
mlir::omp::SafelenClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
const std::optional<std::int64_t> safelenVal =
Fortran::evaluate::ToInt64(clause->v);
result = firOpBuilder.getI64IntegerAttr(*safelenVal);
result.safelenAttr = firOpBuilder.getI64IntegerAttr(*safelenVal);
return true;
}
return false;
}
bool ClauseProcessor::processSchedule(
mlir::omp::ClauseScheduleKindAttr &valAttr,
mlir::omp::ScheduleModifierAttr &modifierAttr,
mlir::UnitAttr &simdModifierAttr) const {
Fortran::lower::StatementContext &stmtCtx,
mlir::omp::ScheduleClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::MLIRContext *context = firOpBuilder.getContext();
@@ -451,53 +455,51 @@ bool ClauseProcessor::processSchedule(
break;
}
mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
result.scheduleValAttr =
mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
if (scheduleModifier != mlir::omp::ScheduleModifier::none)
modifierAttr =
result.scheduleModAttr =
mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
simdModifierAttr = firOpBuilder.getUnitAttr();
result.scheduleSimdAttr = firOpBuilder.getUnitAttr();
valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
return true;
}
return false;
}
bool ClauseProcessor::processScheduleChunk(
Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
if (const auto &chunkExpr = std::get<omp::MaybeExpr>(clause->t))
result = fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
result.scheduleChunkVar =
fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const {
bool ClauseProcessor::processSimdlen(
mlir::omp::SimdlenClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Simdlen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
const std::optional<std::int64_t> simdlenVal =
Fortran::evaluate::ToInt64(clause->v);
result = firOpBuilder.getI64IntegerAttr(*simdlenVal);
result.simdlenAttr = firOpBuilder.getI64IntegerAttr(*simdlenVal);
return true;
}
return false;
}
bool ClauseProcessor::processThreadLimit(
Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
Fortran::lower::StatementContext &stmtCtx,
mlir::omp::ThreadLimitClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
result.threadLimitVar =
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
return markClauseOccurrence<omp::clause::Untied>(result);
bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
return markClauseOccurrence<omp::clause::Untied>(result.untiedAttr);
}
//===----------------------------------------------------------------------===//
@@ -505,13 +507,12 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
//===----------------------------------------------------------------------===//
bool ClauseProcessor::processAllocate(
llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const {
mlir::omp::AllocateClauseOps &result) const {
return findRepeatableClause<omp::clause::Allocate>(
[&](const omp::clause::Allocate &clause,
const Fortran::parser::CharBlock &) {
genAllocateClause(converter, clause, allocatorOperands,
allocateOperands);
genAllocateClause(converter, clause, result.allocatorVars,
result.allocateVars);
});
}
@@ -660,10 +661,9 @@ createCopyFunc(mlir::Location loc, Fortran::lower::AbstractConverter &converter,
return funcOp;
}
bool ClauseProcessor::processCopyPrivate(
bool ClauseProcessor::processCopyprivate(
mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &copyPrivateVars,
llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const {
mlir::omp::CopyprivateClauseOps &result) const {
auto addCopyPrivateVar = [&](Fortran::semantics::Symbol *sym) {
mlir::Value symVal = converter.getSymbolAddress(*sym);
auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>();
@@ -690,10 +690,10 @@ bool ClauseProcessor::processCopyPrivate(
cpVar = alloca;
}
copyPrivateVars.push_back(cpVar);
result.copyprivateVars.push_back(cpVar);
mlir::func::FuncOp funcOp =
createCopyFunc(currentLocation, converter, cpVar.getType(), attrs);
copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
result.copyprivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
};
bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
@@ -714,9 +714,7 @@ bool ClauseProcessor::processCopyPrivate(
return hasCopyPrivate;
}
bool ClauseProcessor::processDepend(
llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<omp::clause::Depend>(
@@ -731,7 +729,7 @@ bool ClauseProcessor::processDepend(
mlir::omp::ClauseTaskDependAttr dependTypeOperand =
genDependKindAttr(firOpBuilder, kind);
dependTypeOperands.append(objects.size(), dependTypeOperand);
result.dependTypeAttrs.append(objects.size(), dependTypeOperand);
for (const omp::Object &object : objects) {
assert(object.ref() && "Expecting designator");
@@ -746,13 +744,13 @@ bool ClauseProcessor::processDepend(
Fortran::semantics::Symbol *sym = object.id();
const mlir::Value variable = converter.getSymbolAddress(*sym);
dependOperands.push_back(variable);
result.dependVars.push_back(variable);
}
});
}
bool ClauseProcessor::processHasDeviceAddr(
llvm::SmallVectorImpl<mlir::Value> &operands,
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
@@ -760,14 +758,14 @@ bool ClauseProcessor::processHasDeviceAddr(
return findRepeatableClause<omp::clause::HasDeviceAddr>(
[&](const omp::clause::HasDeviceAddr &devAddrClause,
const Fortran::parser::CharBlock &) {
addUseDeviceClause(converter, devAddrClause.v, operands, isDeviceTypes,
isDeviceLocs, isDeviceSymbols);
addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars,
isDeviceTypes, isDeviceLocs, isDeviceSymbols);
});
}
bool ClauseProcessor::processIf(
omp::clause::If::DirectiveNameModifier directiveName,
mlir::Value &result) const {
mlir::omp::IfClauseOps &result) const {
bool found = false;
findRepeatableClause<omp::clause::If>(
[&](const omp::clause::If &clause,
@@ -778,7 +776,7 @@ bool ClauseProcessor::processIf(
// Assume that, at most, a single 'if' clause will be applicable to the
// given directive.
if (operand) {
result = operand;
result.ifVar = operand;
found = true;
}
});
@@ -786,7 +784,7 @@ bool ClauseProcessor::processIf(
}
bool ClauseProcessor::processIsDevicePtr(
llvm::SmallVectorImpl<mlir::Value> &operands,
mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
@@ -794,8 +792,8 @@ bool ClauseProcessor::processIsDevicePtr(
return findRepeatableClause<omp::clause::IsDevicePtr>(
[&](const omp::clause::IsDevicePtr &devPtrClause,
const Fortran::parser::CharBlock &) {
addUseDeviceClause(converter, devPtrClause.v, operands, isDeviceTypes,
isDeviceLocs, isDeviceSymbols);
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
isDeviceTypes, isDeviceLocs, isDeviceSymbols);
});
}
@@ -835,12 +833,10 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
bool ClauseProcessor::processMap(
mlir::Location currentLocation, const llvm::omp::Directive &directive,
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
Fortran::lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
const {
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause,
@@ -915,25 +911,23 @@ bool ClauseProcessor::processMap(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
mapOperands.push_back(mapOp);
if (mapSymTypes)
mapSymTypes->push_back(symAddr.getType());
result.mapVars.push_back(mapOp);
if (mapSyms)
mapSyms->push_back(object.id());
if (mapSymLocs)
mapSymLocs->push_back(symAddr.getLoc());
if (mapSymbols)
mapSymbols->push_back(object.id());
if (mapSymTypes)
mapSymTypes->push_back(symAddr.getType());
}
});
}
bool ClauseProcessor::processReduction(
mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &outReductionVars,
llvm::SmallVectorImpl<mlir::Type> &outReductionTypes,
llvm::SmallVectorImpl<mlir::Attribute> &outReductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*outReductionSymbols) const {
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *outReductionSyms)
const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &clause,
const Fortran::parser::CharBlock &) {
@@ -943,30 +937,31 @@ bool ClauseProcessor::processReduction(
// whether to do the reduction byref.
llvm::SmallVector<mlir::Value> reductionVars;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(currentLocation, converter, clause,
reductionVars, reductionDeclSymbols,
outReductionSymbols ? &reductionSymbols
: nullptr);
outReductionSyms ? &reductionSyms : nullptr);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(outReductionVars));
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reductionDeclSymbols,
std::back_inserter(outReductionDeclSymbols));
if (outReductionSymbols)
llvm::copy(reductionSymbols,
std::back_inserter(*outReductionSymbols));
std::back_inserter(result.reductionDeclSymbols));
outReductionTypes.reserve(outReductionTypes.size() +
reductionVars.size());
llvm::transform(reductionVars, std::back_inserter(outReductionTypes),
[](mlir::Value v) { return v.getType(); });
if (outReductionTypes) {
outReductionTypes->reserve(outReductionTypes->size() +
reductionVars.size());
llvm::transform(reductionVars, std::back_inserter(*outReductionTypes),
[](mlir::Value v) { return v.getType(); });
}
if (outReductionSyms)
llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms));
});
}
bool ClauseProcessor::processSectionsReduction(
mlir::Location currentLocation) const {
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &) const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) {
TODO(currentLocation, "OMPC_Reduction");
@@ -995,30 +990,30 @@ bool ClauseProcessor::processEnter(
}
bool ClauseProcessor::processUseDeviceAddr(
llvm::SmallVectorImpl<mlir::Value> &operands,
mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
const {
return findRepeatableClause<omp::clause::UseDeviceAddr>(
[&](const omp::clause::UseDeviceAddr &clause,
const Fortran::parser::CharBlock &) {
addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
useDeviceLocs, useDeviceSymbols);
addUseDeviceClause(converter, clause.v, result.useDeviceAddrVars,
useDeviceTypes, useDeviceLocs, useDeviceSyms);
});
}
bool ClauseProcessor::processUseDevicePtr(
llvm::SmallVectorImpl<mlir::Value> &operands,
mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
const {
return findRepeatableClause<omp::clause::UseDevicePtr>(
[&](const omp::clause::UseDevicePtr &clause,
const Fortran::parser::CharBlock &) {
addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
useDeviceLocs, useDeviceSymbols);
addUseDeviceClause(converter, clause.v, result.useDevicePtrVars,
useDeviceTypes, useDeviceLocs, useDeviceSyms);
});
}

View File

@@ -37,7 +37,7 @@ namespace omp {
/// corresponding clause if it is present in the clause list. Otherwise, they
/// will return `false` to signal that the clause was not found.
///
/// The intended use is of this class is to move clause processing outside of
/// The intended use of this class is to move clause processing outside of
/// construct processing, since the same clauses can appear attached to
/// different constructs and constructs can be combined, so that code
/// duplication is minimized.
@@ -56,61 +56,51 @@ public:
// 'Unique' clauses: They can appear at most once in the clause list.
bool processCollapse(
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
llvm::SmallVectorImpl<mlir::Value> &upperBound,
llvm::SmallVectorImpl<mlir::Value> &step,
mlir::omp::CollapseClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const;
bool processDefault() const;
bool processDevice(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const;
mlir::omp::DeviceClauseOps &result) const;
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processFinal(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
mlir::omp::FinalClauseOps &result) const;
bool
processHasDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
processHasDeviceAddr(mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&isDeviceSymbols) const;
bool processHint(mlir::IntegerAttr &result) const;
bool processMergeable(mlir::UnitAttr &result) const;
bool processNowait(mlir::UnitAttr &result) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
mlir::omp::NumTeamsClauseOps &result) const;
bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
bool processOrdered(mlir::IntegerAttr &result) const;
mlir::omp::NumThreadsClauseOps &result) const;
bool processOrdered(mlir::omp::OrderedClauseOps &result) const;
bool processPriority(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const;
bool processSafelen(mlir::IntegerAttr &result) const;
bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr,
mlir::omp::ScheduleModifierAttr &modifierAttr,
mlir::UnitAttr &simdModifierAttr) const;
bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
bool processSimdlen(mlir::IntegerAttr &result) const;
mlir::omp::PriorityClauseOps &result) const;
bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
bool processSchedule(Fortran::lower::StatementContext &stmtCtx,
mlir::omp::ScheduleClauseOps &result) const;
bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const;
bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
bool processUntied(mlir::UnitAttr &result) const;
mlir::omp::ThreadLimitClauseOps &result) const;
bool processUntied(mlir::omp::UntiedClauseOps &result) const;
// 'Repeatable' clauses: They can appear multiple times in the clause list.
bool
processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const;
bool processAllocate(mlir::omp::AllocateClauseOps &result) const;
bool processCopyin() const;
bool processCopyPrivate(
mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &copyPrivateVars,
llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const;
bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
bool processCopyprivate(mlir::Location currentLocation,
mlir::omp::CopyprivateClauseOps &result) const;
bool processDepend(mlir::omp::DependClauseOps &result) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
mlir::Value &result) const;
mlir::omp::IfClauseOps &result) const;
bool
processIsDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
processIsDevicePtr(mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
@@ -119,43 +109,42 @@ public:
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
// This method is used to process a map clause.
// The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to
// The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to
// store the original type, location and Fortran symbol for the map operands.
// They may be used later on to create the block_arguments for some of the
// target directives that require it.
bool processMap(mlir::Location currentLocation,
const llvm::omp::Directive &directive,
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*mapSymbols = nullptr) const;
bool
processReduction(mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*reductionSymbols = nullptr) const;
bool processSectionsReduction(mlir::Location currentLocation) const;
bool processMap(
mlir::Location currentLocation, const llvm::omp::Directive &directive,
Fortran::lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms =
nullptr,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
bool processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSyms =
nullptr) const;
bool processSectionsReduction(mlir::Location currentLocation,
mlir::omp::ReductionClauseOps &result) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
processUseDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
processUseDeviceAddr(mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&useDeviceSymbols) const;
&useDeviceSyms) const;
bool
processUseDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
processUseDevicePtr(mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&useDeviceSymbols) const;
&useDeviceSyms) const;
template <typename T>
bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands);
mlir::omp::MapClauseOps &result);
// Call this method for these clauses that should be supported but are not
// implemented yet. It triggers a compilation error if any of the given
@@ -197,7 +186,7 @@ private:
template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
mlir::omp::MapClauseOps &result) {
return findRepeatableClause<T>(
[&](const T &clause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
@@ -239,7 +228,7 @@ bool ClauseProcessor::processMotionClauses(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
mapOperands.push_back(mapOp);
result.mapVars.push_back(mapOp);
}
});
}

View File

@@ -23,11 +23,13 @@ namespace Fortran {
namespace lower {
namespace omp {
void DataSharingProcessor::processStep1() {
void DataSharingProcessor::processStep1(
mlir::omp::PrivateClauseOps *clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
collectSymbolsForPrivatization();
collectDefaultSymbols();
privatize();
defaultPrivatize();
privatize(clauseOps, privateSyms);
defaultPrivatize(clauseOps, privateSyms);
insertBarrier();
}
@@ -299,14 +301,16 @@ void DataSharingProcessor::collectDefaultSymbols() {
}
}
void DataSharingProcessor::privatize() {
void DataSharingProcessor::privatize(
mlir::omp::PrivateClauseOps *clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
if (const auto *commonDet =
sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
for (const auto &mem : commonDet->objects())
doPrivatize(&*mem);
doPrivatize(&*mem, clauseOps, privateSyms);
} else
doPrivatize(sym);
doPrivatize(sym, clauseOps, privateSyms);
}
}
@@ -323,7 +327,9 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
}
}
void DataSharingProcessor::defaultPrivatize() {
void DataSharingProcessor::defaultPrivatize(
mlir::omp::PrivateClauseOps *clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
for (const Fortran::semantics::Symbol *sym : defaultSymbols) {
if (!Fortran::semantics::IsProcedure(*sym) &&
!sym->GetUltimate().has<Fortran::semantics::DerivedTypeDetails>() &&
@@ -331,11 +337,14 @@ void DataSharingProcessor::defaultPrivatize() {
!symbolsInNestedRegions.contains(sym) &&
!symbolsInParentRegions.contains(sym) &&
!privatizedSymbols.contains(sym))
doPrivatize(sym);
doPrivatize(sym, clauseOps, privateSyms);
}
}
void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) {
void DataSharingProcessor::doPrivatize(
const Fortran::semantics::Symbol *sym,
mlir::omp::PrivateClauseOps *clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
if (!useDelayedPrivatization) {
cloneSymbol(sym);
copyFirstPrivateSymbol(sym);
@@ -419,10 +428,13 @@ void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) {
return result;
}();
delayedPrivatizationInfo.privatizers.push_back(
mlir::SymbolRefAttr::get(privatizerOp));
delayedPrivatizationInfo.originalAddresses.push_back(hsb.getAddr());
delayedPrivatizationInfo.symbols.push_back(sym);
if (clauseOps) {
clauseOps->privatizers.push_back(mlir::SymbolRefAttr::get(privatizerOp));
clauseOps->privateVars.push_back(hsb.getAddr());
}
if (privateSyms)
privateSyms->push_back(sym);
}
} // namespace omp

View File

@@ -19,28 +19,17 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
namespace mlir {
namespace omp {
struct PrivateClauseOps;
} // namespace omp
} // namespace mlir
namespace Fortran {
namespace lower {
namespace omp {
class DataSharingProcessor {
public:
/// Collects all the information needed for delayed privatization. This can be
/// used by ops with data-sharing clauses to properly generate their regions
/// (e.g. add region arguments) and map the original SSA values to their
/// corresponding OMP region operands.
struct DelayedPrivatizationInfo {
// The list of symbols referring to delayed privatizer ops (i.e.
// `omp.private` ops).
llvm::SmallVector<mlir::SymbolRefAttr> privatizers;
// SSA values that correspond to "original" values being privatized.
// "Original" here means the SSA value outside the OpenMP region from which
// a clone is created inside the region.
llvm::SmallVector<mlir::Value> originalAddresses;
// Fortran symbols corresponding to the above SSA values.
llvm::SmallVector<const Fortran::semantics::Symbol *> symbols;
};
private:
bool hasLastPrivateOp;
mlir::OpBuilder::InsertPoint lastPrivIP;
@@ -57,7 +46,6 @@ private:
Fortran::lower::pft::Evaluation &eval;
bool useDelayedPrivatization;
Fortran::lower::SymMap *symTable;
DelayedPrivatizationInfo delayedPrivatizationInfo;
bool needBarrier();
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
@@ -67,9 +55,16 @@ private:
void collectSymbolsForPrivatization();
void insertBarrier();
void collectDefaultSymbols();
void privatize();
void defaultPrivatize();
void doPrivatize(const Fortran::semantics::Symbol *sym);
void privatize(
mlir::omp::PrivateClauseOps *clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
void defaultPrivatize(
mlir::omp::PrivateClauseOps *clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
void doPrivatize(
const Fortran::semantics::Symbol *sym,
mlir::omp::PrivateClauseOps *clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
void copyLastPrivatize(mlir::Operation *op);
void insertLastPrivateCompare(mlir::Operation *op);
void cloneSymbol(const Fortran::semantics::Symbol *sym);
@@ -103,17 +98,15 @@ public:
// Step2 performs the copying for lastprivates and requires knowledge of the
// MLIR operation to insert the last private update. Step2 adds
// dealocation code as well.
void processStep1();
void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*privateSyms = nullptr);
void processStep2(mlir::Operation *op, bool isLoop);
void setLoopIV(mlir::Value iv) {
assert(!loopIV && "Loop iteration variable already set");
loopIV = iv;
}
const DelayedPrivatizationInfo &getDelayedPrivatizationInfo() const {
return delayedPrivatizationInfo;
}
};
} // namespace omp

View File

@@ -730,19 +730,25 @@ genMasterOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation) {
return genOpWithBody<mlir::omp::MasterOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested),
/*resultTypes=*/mlir::TypeRange());
.setGenNested(genNested));
}
static mlir::omp::OrderedRegionOp
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
mlir::Location currentLocation) {
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
mlir::omp::OrderedRegionClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processTODO<clause::Simd>(currentLocation,
llvm::omp::Directive::OMPD_ordered);
return genOpWithBody<mlir::omp::OrderedRegionOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested),
/*simd=*/false);
clauseOps);
}
static mlir::omp::ParallelOp
@@ -753,77 +759,62 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList,
bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, numThreadsClauseOperand;
mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
reductionVars;
mlir::omp::ParallelClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms;
llvm::SmallVector<mlir::Type> reductionTypes;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(llvm::omp::Directive::OMPD_parallel, ifClauseOperand);
cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
cp.processProcBind(procBindKindAttr);
cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
cp.processNumThreads(stmtCtx, clauseOps);
cp.processProcBind(clauseOps);
cp.processDefault();
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processAllocate(clauseOps);
if (!outerCombined)
cp.processReduction(currentLocation, reductionVars, reductionTypes,
reductionDeclSymbols, &reductionSymbols);
cp.processReduction(currentLocation, clauseOps, &reductionTypes,
&reductionSyms);
if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
auto reductionCallback = [&](mlir::Operation *op) {
llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
llvm::SmallVector<mlir::Location> locs(clauseOps.reductionVars.size(),
currentLocation);
auto *block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {},
reductionTypes, locs);
auto *block =
firOpBuilder.createBlock(&op->getRegion(0), {}, reductionTypes, locs);
for (auto [arg, prv] :
llvm::zip_equal(reductionSymbols, block->getArguments())) {
llvm::zip_equal(reductionSyms, block->getArguments())) {
converter.bindSymbol(*arg, prv);
}
return reductionSymbols;
return reductionSyms;
};
mlir::UnitAttr byrefAttr;
if (ReductionProcessor::doReductionByRef(reductionVars))
byrefAttr = converter.getFirOpBuilder().getUnitAttr();
OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauseList)
.setReductions(&reductionSymbols, &reductionTypes)
.setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(reductionCallback);
if (!enableDelayedPrivatization) {
return genOpWithBody<mlir::omp::ParallelOp>(
genInfo,
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
numThreadsClauseOperand, allocateOperands, allocatorOperands,
reductionVars,
reductionDeclSymbols.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
/*privatizers=*/nullptr, byrefAttr);
}
if (!enableDelayedPrivatization)
return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
bool privatize = !outerCombined;
DataSharingProcessor dsp(converter, semaCtx, clauseList, eval,
/*useDelayedPrivatization=*/true, &symTable);
if (privatize)
dsp.processStep1();
const auto &delayedPrivatizationInfo = dsp.getDelayedPrivatizationInfo();
dsp.processStep1(&clauseOps, &privateSyms);
auto genRegionEntryCB = [&](mlir::Operation *op) {
auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
llvm::SmallVector<mlir::Location> reductionLocs(reductionVars.size(),
currentLocation);
llvm::SmallVector<mlir::Location> reductionLocs(
clauseOps.reductionVars.size(), currentLocation);
mlir::OperandRange privateVars = parallelOp.getPrivateVars();
mlir::Region &region = parallelOp.getRegion();
@@ -838,12 +829,12 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::transform(privateVars, std::back_inserter(privateVarLocs),
[](mlir::Value v) { return v.getLoc(); });
converter.getFirOpBuilder().createBlock(&region, /*insertPt=*/{},
privateVarTypes, privateVarLocs);
firOpBuilder.createBlock(&region, /*insertPt=*/{}, privateVarTypes,
privateVarLocs);
llvm::SmallVector<const Fortran::semantics::Symbol *> allSymbols =
reductionSymbols;
allSymbols.append(delayedPrivatizationInfo.symbols);
reductionSyms;
allSymbols.append(privateSyms);
for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) {
converter.bindSymbol(*arg, prv);
}
@@ -853,26 +844,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
// TODO Merge with the reduction CB.
genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
llvm::SmallVector<mlir::Attribute> privatizers(
delayedPrivatizationInfo.privatizers.begin(),
delayedPrivatizationInfo.privatizers.end());
return genOpWithBody<mlir::omp::ParallelOp>(
genInfo,
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
numThreadsClauseOperand, allocateOperands, allocatorOperands,
reductionVars,
reductionDeclSymbols.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
procBindKindAttr, delayedPrivatizationInfo.originalAddresses,
delayedPrivatizationInfo.privatizers.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
privatizers),
byrefAttr);
return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
}
static mlir::omp::SectionOp
@@ -896,28 +868,21 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &beginClauseList,
const Fortran::parser::OmpClauseList &endClauseList) {
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
llvm::SmallVector<mlir::Value> copyPrivateVars;
llvm::SmallVector<mlir::Attribute> copyPrivateFuncs;
mlir::UnitAttr nowaitAttr;
mlir::omp::SingleClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, beginClauseList);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processAllocate(clauseOps);
// TODO Support delayed privatization.
ClauseProcessor ecp(converter, semaCtx, endClauseList);
ecp.processNowait(nowaitAttr);
ecp.processCopyPrivate(currentLocation, copyPrivateVars, copyPrivateFuncs);
ecp.processNowait(clauseOps);
ecp.processCopyprivate(currentLocation, clauseOps);
return genOpWithBody<mlir::omp::SingleOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&beginClauseList),
allocateOperands, allocatorOperands, copyPrivateVars,
copyPrivateFuncs.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
copyPrivateFuncs),
nowaitAttr);
clauseOps);
}
static mlir::omp::TaskOp
@@ -927,21 +892,19 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, finalClauseOperand, priorityClauseOperand;
mlir::UnitAttr untiedAttr, mergeableAttr;
llvm::SmallVector<mlir::Attribute> dependTypeOperands;
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
dependOperands;
mlir::omp::TaskClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(llvm::omp::Directive::OMPD_task, ifClauseOperand);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
cp.processAllocate(clauseOps);
cp.processDefault();
cp.processFinal(stmtCtx, finalClauseOperand);
cp.processUntied(untiedAttr);
cp.processMergeable(mergeableAttr);
cp.processPriority(stmtCtx, priorityClauseOperand);
cp.processDepend(dependTypeOperands, dependOperands);
cp.processFinal(stmtCtx, clauseOps);
cp.processUntied(clauseOps);
cp.processMergeable(clauseOps);
cp.processPriority(stmtCtx, clauseOps);
cp.processDepend(clauseOps);
// TODO Support delayed privatization.
cp.processTODO<clause::InReduction, clause::Detach, clause::Affinity>(
currentLocation, llvm::omp::Directive::OMPD_task);
@@ -949,14 +912,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&clauseList),
ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr,
/*in_reduction_vars=*/mlir::ValueRange(),
/*in_reductions=*/nullptr, priorityClauseOperand,
dependTypeOperands.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
dependTypeOperands),
dependOperands, allocateOperands, allocatorOperands);
clauseOps);
}
static mlir::omp::TaskgroupOp
@@ -965,17 +921,18 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval, bool genNested,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
mlir::omp::TaskgroupClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processAllocate(clauseOps);
cp.processTODO<clause::TaskReduction>(currentLocation,
llvm::omp::Directive::OMPD_taskgroup);
return genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&clauseList),
/*task_reduction_vars=*/mlir::ValueRange(),
/*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
clauseOps);
}
// This helper function implements the functionality of "promoting"
@@ -996,8 +953,7 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
// clause. Support for such list items in a use_device_ptr clause
// is deprecated."
static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
llvm::SmallVectorImpl<mlir::Value> &devicePtrOperands,
llvm::SmallVectorImpl<mlir::Value> &deviceAddrOperands,
mlir::omp::UseDeviceClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
@@ -1010,9 +966,10 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
// Iterate over our use_device_ptr list and shift all non-cptr arguments into
// use_device_addr.
for (auto *it = devicePtrOperands.begin(); it != devicePtrOperands.end();) {
for (auto *it = clauseOps.useDevicePtrVars.begin();
it != clauseOps.useDevicePtrVars.end();) {
if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) {
deviceAddrOperands.push_back(*it);
clauseOps.useDeviceAddrVars.push_back(*it);
// We have to shuffle the symbols around as well, to maintain
// the correct Input -> BlockArg for use_device_ptr/use_device_addr.
// NOTE: However, as map's do not seem to be included currently
@@ -1020,11 +977,11 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
// future alterations. I believe the reason they are not currently
// is that the BlockArg assign/lowering needs to be extended
// to a greater set of types.
auto idx = std::distance(devicePtrOperands.begin(), it);
auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it);
moveElementToBack(idx, useDeviceTypes);
moveElementToBack(idx, useDeviceLocs);
moveElementToBack(idx, useDeviceSymbols);
it = devicePtrOperands.erase(it);
it = clauseOps.useDevicePtrVars.erase(it);
continue;
}
++it;
@@ -1038,20 +995,19 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, deviceOperand;
llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
deviceAddrOperands;
mlir::omp::TargetDataClauseOps clauseOps;
llvm::SmallVector<mlir::Type> useDeviceTypes;
llvm::SmallVector<mlir::Location> useDeviceLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms;
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(llvm::omp::Directive::OMPD_target_data, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
useDeviceSymbols);
cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
useDeviceSymbols);
cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs,
useDeviceSyms);
cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs,
useDeviceSyms);
// This function implements the deprecated functionality of use_device_ptr
// that allows users to provide non-CPTR arguments to it with the caveat
// that the compiler will treat them as use_device_addr. A lot of legacy
@@ -1063,17 +1019,16 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
// ordering.
// TODO: Perhaps create a user provideable compiler option that will
// re-introduce a hard-error rather than a warning in these cases.
promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
devicePtrOperands, deviceAddrOperands, useDeviceTypes, useDeviceLocs,
useDeviceSymbols);
promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes,
useDeviceLocs, useDeviceSyms);
cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data,
stmtCtx, mapOperands);
stmtCtx, clauseOps);
auto dataOp = converter.getFirOpBuilder().create<mlir::omp::TargetDataOp>(
currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
deviceAddrOperands, mapOperands);
currentLocation, clauseOps);
genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, dataOp,
useDeviceTypes, useDeviceLocs, useDeviceSymbols,
useDeviceTypes, useDeviceLocs, useDeviceSyms,
currentLocation);
return dataOp;
}
@@ -1086,10 +1041,7 @@ static OpTy genTargetEnterExitDataUpdateOp(
const Fortran::parser::OmpClauseList &clauseList) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, deviceOperand;
mlir::UnitAttr nowaitAttr;
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
llvm::SmallVector<mlir::Attribute> dependTypeOperands;
mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps;
// GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
[[maybe_unused]] llvm::omp::Directive directive;
@@ -1104,25 +1056,19 @@ static OpTy genTargetEnterExitDataUpdateOp(
}
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(directive, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processDepend(dependTypeOperands, dependOperands);
cp.processNowait(nowaitAttr);
cp.processIf(directive, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processDepend(clauseOps);
cp.processNowait(clauseOps);
if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) {
cp.processMotionClauses<clause::To>(stmtCtx, mapOperands);
cp.processMotionClauses<clause::From>(stmtCtx, mapOperands);
cp.processMotionClauses<clause::To>(stmtCtx, clauseOps);
cp.processMotionClauses<clause::From>(stmtCtx, clauseOps);
} else {
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
cp.processMap(currentLocation, directive, stmtCtx, clauseOps);
}
return firOpBuilder.create<OpTy>(
currentLocation, ifClauseOperand, deviceOperand,
dependTypeOperands.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
dependTypeOperands),
dependOperands, nowaitAttr, mapOperands);
return firOpBuilder.create<OpTy>(currentLocation, clauseOps);
}
// This functions creates a block for the body of the targetOp's region. It adds
@@ -1132,9 +1078,9 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
mlir::omp::TargetOp &targetOp,
llvm::ArrayRef<mlir::Type> mapSymTypes,
llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms,
llvm::ArrayRef<mlir::Location> mapSymLocs,
llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSymbols,
llvm::ArrayRef<mlir::Type> mapSymTypes,
const mlir::Location &currentLocation) {
assert(mapSymTypes.size() == mapSymLocs.size());
@@ -1163,7 +1109,7 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
};
// Bind the symbols to their corresponding block arguments.
for (auto [argIndex, argSymbol] : llvm::enumerate(mapSymbols)) {
for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) {
const mlir::BlockArgument &arg = region.getArgument(argIndex);
// Avoid capture of a reference to a structured binding.
const Fortran::semantics::Symbol *sym = argSymbol;
@@ -1287,31 +1233,25 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &clauseList,
llvm::omp::Directive directive, bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
mlir::UnitAttr nowaitAttr;
llvm::SmallVector<mlir::Attribute> dependTypeOperands;
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
llvm::SmallVector<mlir::Type> mapSymTypes;
llvm::SmallVector<mlir::Location> mapSymLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
llvm::SmallVector<mlir::Value> devicePtrOperands, deviceAddrOperands;
llvm::SmallVector<mlir::Type> devicePtrTypes, deviceAddrTypes;
llvm::SmallVector<mlir::Location> devicePtrLocs, deviceAddrLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> devicePtrSymbols,
deviceAddrSymbols;
mlir::omp::TargetClauseOps clauseOps;
llvm::SmallVector<mlir::Type> mapTypes, devicePtrTypes, deviceAddrTypes;
llvm::SmallVector<mlir::Location> mapLocs, devicePtrLocs, deviceAddrLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms, devicePtrSyms,
deviceAddrSyms;
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(llvm::omp::Directive::OMPD_target, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processThreadLimit(stmtCtx, threadLimitOperand);
cp.processDepend(dependTypeOperands, dependOperands);
cp.processNowait(nowaitAttr);
cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
&mapSymLocs, &mapSymbols);
cp.processIsDevicePtr(devicePtrOperands, devicePtrTypes, devicePtrLocs,
devicePtrSymbols);
cp.processHasDeviceAddr(deviceAddrOperands, deviceAddrTypes, deviceAddrLocs,
deviceAddrSymbols);
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processThreadLimit(stmtCtx, clauseOps);
cp.processDepend(clauseOps);
cp.processNowait(clauseOps);
cp.processMap(currentLocation, directive, stmtCtx, clauseOps, &mapSyms,
&mapLocs, &mapTypes);
cp.processIsDevicePtr(clauseOps, devicePtrTypes, devicePtrLocs,
devicePtrSyms);
cp.processHasDeviceAddr(clauseOps, deviceAddrTypes, deviceAddrLocs,
deviceAddrSyms);
// TODO Support delayed privatization.
cp.processTODO<clause::Private, clause::Firstprivate, clause::Reduction,
clause::InReduction, clause::Allocate, clause::UsesAllocators,
@@ -1323,7 +1263,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
// symbols used inside the region that have not been explicitly mapped using
// the map clause.
auto captureImplicitMap = [&](const Fortran::semantics::Symbol &sym) {
if (llvm::find(mapSymbols, &sym) == mapSymbols.end()) {
if (llvm::find(mapSyms, &sym) == mapSyms.end()) {
mlir::Value baseOp = converter.getSymbolAddress(sym);
if (!baseOp)
if (const auto *details = sym.template detailsIf<
@@ -1394,26 +1334,20 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
mapFlag),
captureKind, baseOp.getType());
mapOperands.push_back(mapOp);
mapSymTypes.push_back(baseOp.getType());
mapSymLocs.push_back(baseOp.getLoc());
mapSymbols.push_back(&sym);
clauseOps.mapVars.push_back(mapOp);
mapSyms.push_back(&sym);
mapLocs.push_back(baseOp.getLoc());
mapTypes.push_back(baseOp.getType());
}
}
};
Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
dependTypeOperands.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
dependTypeOperands),
dependOperands, nowaitAttr, devicePtrOperands, deviceAddrOperands,
mapOperands);
currentLocation, clauseOps);
genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
mapSymLocs, mapSymbols, currentLocation);
genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSyms,
mapLocs, mapTypes, currentLocation);
return targetOp;
}
@@ -1426,17 +1360,16 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &clauseList,
bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
mlir::Value numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand;
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
reductionVars;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
mlir::omp::TeamsClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(llvm::omp::Directive::OMPD_teams, ifClauseOperand);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
cp.processAllocate(clauseOps);
cp.processDefault();
cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
cp.processThreadLimit(stmtCtx, threadLimitClauseOperand);
cp.processNumTeams(stmtCtx, clauseOps);
cp.processThreadLimit(stmtCtx, clauseOps);
// TODO Support delayed privatization.
cp.processTODO<clause::Reduction>(currentLocation,
llvm::omp::Directive::OMPD_teams);
@@ -1445,30 +1378,20 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
.setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauseList),
/*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
threadLimitClauseOperand, allocateOperands, allocatorOperands,
reductionVars,
reductionDeclSymbols.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols));
clauseOps);
}
/// Extract the list of function and variable symbols affected by the given
/// 'declare target' directive and return the intended device type for them.
static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
static void getDeclareTargetInfo(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
mlir::omp::DeclareTargetClauseOps &clauseOps,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
// The default capture type
mlir::omp::DeclareTargetDeviceType deviceType =
mlir::omp::DeclareTargetDeviceType::any;
const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
declareTargetConstruct.t);
if (const auto *objectList{
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
ObjectList objects{makeObjects(*objectList, semaCtx)};
@@ -1489,12 +1412,10 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
cp.processTo(symbolAndClause);
cp.processEnter(symbolAndClause);
cp.processLink(symbolAndClause);
cp.processDeviceType(deviceType);
cp.processDeviceType(clauseOps);
cp.processTODO<clause::Indirect>(converter.getCurrentLocation(),
llvm::omp::Directive::OMPD_declare_target);
}
return deviceType;
}
static void collectDeferredDeclareTargets(
@@ -1504,9 +1425,10 @@ static void collectDeferredDeclareTargets(
const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
llvm::SmallVectorImpl<Fortran::lower::OMPDeferredDeclareTargetInfo>
&deferredDeclareTarget) {
mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
mlir::omp::DeclareTargetDeviceType devType = getDeclareTargetInfo(
converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
clauseOps, symbolAndClause);
// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
@@ -1516,8 +1438,9 @@ static void collectDeferredDeclareTargets(
std::get<const Fortran::semantics::Symbol &>(symClause)));
if (!op) {
deferredDeclareTarget.push_back(
{std::get<0>(symClause), devType, std::get<1>(symClause)});
deferredDeclareTarget.push_back({std::get<0>(symClause),
clauseOps.deviceType,
std::get<1>(symClause)});
}
}
}
@@ -1529,9 +1452,10 @@ getDeclareTargetFunctionDevice(
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
clauseOps, symbolAndClause);
// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
@@ -1541,7 +1465,7 @@ getDeclareTargetFunctionDevice(
std::get<const Fortran::semantics::Symbol &>(symClause)));
if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
return deviceType;
return clauseOps.deviceType;
}
return std::nullopt;
@@ -1571,12 +1495,14 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_barrier:
firOpBuilder.create<mlir::omp::BarrierOp>(currentLocation);
break;
case llvm::omp::Directive::OMPD_taskwait:
ClauseProcessor(converter, semaCtx, opClauseList)
.processTODO<clause::Depend, clause::Nowait>(
currentLocation, llvm::omp::Directive::OMPD_taskwait);
firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation);
case llvm::omp::Directive::OMPD_taskwait: {
mlir::omp::TaskwaitClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, opClauseList);
cp.processTODO<clause::Depend, clause::Nowait>(
currentLocation, llvm::omp::Directive::OMPD_taskwait);
firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation, clauseOps);
break;
}
case llvm::omp::Directive::OMPD_taskyield:
firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation);
break;
@@ -1711,32 +1637,21 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
mlir::Value scheduleChunkClauseOperand, ifClauseOperand;
llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
llvm::SmallVector<mlir::Value> alignedVars, nontemporalVars;
mlir::omp::SimdLoopClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
llvm::SmallVector<mlir::Type> reductionTypes;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand;
ClauseProcessor cp(converter, semaCtx, loopOpClauseList);
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols);
cp.processIf(llvm::omp::Directive::OMPD_simd, ifClauseOperand);
cp.processSimdlen(simdlenClauseOperand);
cp.processSafelen(safelenClauseOperand);
cp.processCollapse(loc, eval, clauseOps, iv);
cp.processReduction(loc, clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps);
cp.processSimdlen(clauseOps);
cp.processSafelen(clauseOps);
clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
// TODO Support delayed privatization.
cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
clause::Nontemporal, clause::Order>(loc, ompDirective);
mlir::TypeRange resultType;
auto simdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
loc, resultType, lowerBound, upperBound, step, alignedVars,
/*alignment_values=*/nullptr, ifClauseOperand, nontemporalVars,
orderClauseOperand, simdlenClauseOperand, safelenClauseOperand,
/*inclusive=*/firOpBuilder.getUnitAttr());
auto *nestedEval = getCollapsedLoopEval(
eval, Fortran::lower::getCollapseValue(loopOpClauseList));
@@ -1744,11 +1659,12 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
return genLoopVars(op, converter, loc, iv);
};
createBodyOfOp<mlir::omp::SimdLoopOp>(
simdLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
.setClauses(&loopOpClauseList)
.setDataSharingProcessor(&dsp)
.setGenRegionEntryCb(ivCallback));
genOpWithBody<mlir::omp::SimdLoopOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
.setClauses(&loopOpClauseList)
.setDataSharingProcessor(&dsp)
.setGenRegionEntryCb(ivCallback),
clauseOps);
}
static void createWsloop(Fortran::lower::AbstractConverter &converter,
@@ -1763,77 +1679,50 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
mlir::Value scheduleChunkClauseOperand;
llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
mlir::omp::WsloopClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
llvm::SmallVector<mlir::Type> reductionTypes;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
mlir::IntegerAttr orderedClauseOperand;
mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ClauseProcessor cp(converter, semaCtx, beginClauseList);
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols,
&reductionSymbols);
cp.processTODO<clause::Linear, clause::Order>(loc, ompDirective);
cp.processCollapse(loc, eval, clauseOps, iv);
cp.processSchedule(stmtCtx, clauseOps);
cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
cp.processOrdered(clauseOps);
clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
// TODO Support delayed privatization.
if (ReductionProcessor::doReductionByRef(reductionVars))
byrefOperand = firOpBuilder.getUnitAttr();
if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
auto wsLoopOp = firOpBuilder.create<mlir::omp::WsloopOp>(
loc, lowerBound, upperBound, step, linearVars, linearStepVars,
reductionVars,
reductionDeclSymbols.empty()
? nullptr
: mlir::ArrayAttr::get(firOpBuilder.getContext(),
reductionDeclSymbols),
scheduleValClauseOperand, scheduleChunkClauseOperand,
/*schedule_modifiers=*/nullptr,
/*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
orderedClauseOperand, orderClauseOperand,
/*inclusive=*/firOpBuilder.getUnitAttr());
cp.processTODO<clause::Allocate, clause::Linear, clause::Order>(loc,
ompDirective);
// Handle attribute based clauses.
if (cp.processOrdered(orderedClauseOperand))
wsLoopOp.setOrderedValAttr(orderedClauseOperand);
if (cp.processSchedule(scheduleValClauseOperand, scheduleModClauseOperand,
scheduleSimdClauseOperand)) {
wsLoopOp.setScheduleValAttr(scheduleValClauseOperand);
wsLoopOp.setScheduleModifierAttr(scheduleModClauseOperand);
wsLoopOp.setSimdModifierAttr(scheduleSimdClauseOperand);
}
// In FORTRAN `nowait` clause occur at the end of `omp do` directive.
// i.e
// !$omp do
// <...>
// !$omp end do nowait
if (endClauseList) {
if (ClauseProcessor(converter, semaCtx, *endClauseList)
.processNowait(nowaitClauseOperand))
wsLoopOp.setNowaitAttr(nowaitClauseOperand);
ClauseProcessor ecp(converter, semaCtx, *endClauseList);
ecp.processNowait(clauseOps);
}
auto *nestedEval = getCollapsedLoopEval(
eval, Fortran::lower::getCollapseValue(beginClauseList));
auto ivCallback = [&](mlir::Operation *op) {
return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols,
return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms,
reductionTypes);
};
createBodyOfOp<mlir::omp::WsloopOp>(
wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
.setClauses(&beginClauseList)
.setDataSharingProcessor(&dsp)
.setReductions(&reductionSymbols, &reductionTypes)
.setGenRegionEntryCb(ivCallback));
genOpWithBody<mlir::omp::WsloopOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
.setClauses(&beginClauseList)
.setDataSharingProcessor(&dsp)
.setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(ivCallback),
clauseOps);
}
static void createSimdWsloop(
@@ -1921,10 +1810,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
clauseOps, symbolAndClause);
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
mlir::Operation *op = mod.lookupSymbol(converter.mangleName(
@@ -1938,7 +1828,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
markDeclareTarget(
op, converter,
std::get<mlir::omp::DeclareTargetCaptureClause>(symClause), deviceType);
std::get<mlir::omp::DeclareTargetCaptureClause>(symClause),
clauseOps.deviceType);
}
}
@@ -2072,7 +1963,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
!std::get_if<Fortran::parser::OmpClause::IsDevicePtr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::HasDeviceAddr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
!std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::Simd>(&clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
}
}
@@ -2092,7 +1984,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
break;
case llvm::omp::Directive::OMPD_ordered:
genOrderedRegionOp(converter, semaCtx, eval, /*genNested=*/true,
currentLocation);
currentLocation, beginClauseList);
break;
case llvm::omp::Directive::OMPD_parallel:
genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true,
@@ -2183,7 +2075,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
mlir::IntegerAttr hintClauseOp;
std::string name;
const Fortran::parser::OmpCriticalDirective &cd =
std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
@@ -2192,21 +2083,28 @@ genOMP(Fortran::lower::AbstractConverter &converter,
std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString();
}
const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
ClauseProcessor(converter, semaCtx, clauseList).processHint(hintClauseOp);
mlir::omp::CriticalOp criticalOp = [&]() {
if (name.empty()) {
return firOpBuilder.create<mlir::omp::CriticalOp>(
currentLocation, mlir::FlatSymbolRefAttr());
}
mlir::ModuleOp module = firOpBuilder.getModule();
mlir::OpBuilder modBuilder(module.getBodyRegion());
auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name);
if (!global)
global = modBuilder.create<mlir::omp::CriticalDeclareOp>(
currentLocation,
mlir::StringAttr::get(firOpBuilder.getContext(), name), hintClauseOp);
if (!global) {
mlir::omp::CriticalClauseOps clauseOps;
const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processHint(clauseOps);
clauseOps.nameAttr =
mlir::StringAttr::get(firOpBuilder.getContext(), name);
global = modBuilder.create<mlir::omp::CriticalDeclareOp>(currentLocation,
clauseOps);
}
return firOpBuilder.create<mlir::omp::CriticalOp>(
currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(),
global.getSymName()));
@@ -2323,8 +2221,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
mlir::Location currentLocation = converter.getCurrentLocation();
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
mlir::UnitAttr nowaitClauseOperand;
mlir::omp::SectionsClauseOps clauseOps;
const auto &beginSectionsDirective =
std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t);
const auto &sectionsClauseList =
@@ -2333,8 +2230,9 @@ genOMP(Fortran::lower::AbstractConverter &converter,
// Process clauses before optional omp.parallel, so that new variables are
// allocated outside of the parallel region
ClauseProcessor cp(converter, semaCtx, sectionsClauseList);
cp.processSectionsReduction(currentLocation);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processSectionsReduction(currentLocation, clauseOps);
cp.processAllocate(clauseOps);
// TODO Support delayed privatization.
llvm::omp::Directive dir =
std::get<Fortran::parser::OmpSectionsDirective>(beginSectionsDirective.t)
@@ -2351,16 +2249,14 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const auto &endSectionsClauseList =
std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t);
ClauseProcessor(converter, semaCtx, endSectionsClauseList)
.processNowait(nowaitClauseOperand);
.processNowait(clauseOps);
}
// SECTIONS construct
genOpWithBody<mlir::omp::SectionsOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(false),
/*reduction_vars=*/mlir::ValueRange(),
/*reductions=*/nullptr, allocateOperands, allocatorOperands,
nowaitClauseOperand);
clauseOps);
const auto &sectionBlocks =
std::get<Fortran::parser::OmpSectionBlocks>(sectionsConstruct.t);

View File

@@ -81,7 +81,7 @@ struct GrainsizeClauseOps {
Value grainsizeVar;
};
struct HasDeviceAddrOps {
struct HasDeviceAddrClauseOps {
llvm::SmallVector<Value> hasDeviceAddrVars;
};
struct HintClauseOps {
@@ -97,7 +97,7 @@ struct InReductionClauseOps {
llvm::SmallVector<Attribute> inReductionDeclSymbols;
};
struct IsDevicePtrOps {
struct IsDevicePtrClauseOps {
llvm::SmallVector<Value> isDevicePtrVars;
};
@@ -234,6 +234,8 @@ using DistributeClauseOps =
detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps,
PrivateClauseOps>;
using LoopNestClauseOps = detail::Clauses<CollapseClauseOps, LoopRelatedOps>;
// TODO `filter` clause.
using MaskedClauseOps = detail::Clauses<>;
@@ -261,8 +263,8 @@ using SingleClauseOps = detail::Clauses<AllocateClauseOps, CopyprivateClauseOps,
// TODO `defaultmap`, `uses_allocators` clauses.
using TargetClauseOps =
detail::Clauses<AllocateClauseOps, DependClauseOps, DeviceClauseOps,
HasDeviceAddrOps, IfClauseOps, InReductionClauseOps,
IsDevicePtrOps, MapClauseOps, NowaitClauseOps,
HasDeviceAddrClauseOps, IfClauseOps, InReductionClauseOps,
IsDevicePtrClauseOps, MapClauseOps, NowaitClauseOps,
PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;
using TargetDataClauseOps = detail::Clauses<DeviceClauseOps, IfClauseOps,

View File

@@ -574,6 +574,10 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
Variadic<IntLikeType>:$step,
UnitAttr:$inclusive);
let builders = [
OpBuilder<(ins CArg<"const LoopNestClauseOps &">:$clauses)>
];
let regions = (region AnyRegion:$region);
let extraClassDeclaration = [{

View File

@@ -1920,6 +1920,12 @@ void LoopNestOp::print(OpAsmPrinter &p) {
p.printRegion(region, /*printEntryBlockArgs=*/false);
}
void LoopNestOp::build(OpBuilder &builder, OperationState &state,
const LoopNestClauseOps &clauses) {
LoopNestOp::build(builder, state, clauses.loopLBVar, clauses.loopUBVar,
clauses.loopStepVar, clauses.loopInclusiveAttr);
}
LogicalResult LoopNestOp::verify() {
if (getLowerBound().size() != getIVs().size())
return emitOpError() << "number of range arguments and IVs do not match";