mirror of
https://github.com/intel/llvm.git
synced 2026-01-14 11:57:39 +08:00
This patch add support for lowering of custom reductions to MLIR. It also enhances the capability of the pass to automatically mark functions as "declare target" by traversing custom reduction initializers and combiners.
1725 lines
70 KiB
C++
1725 lines
70 KiB
C++
//===-- ClauseProcessor.cpp -------------------------------------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "ClauseProcessor.h"
|
|
#include "Utils.h"
|
|
|
|
#include "flang/Lower/ConvertCall.h"
|
|
#include "flang/Lower/ConvertExprToHLFIR.h"
|
|
#include "flang/Lower/OpenMP/Clauses.h"
|
|
#include "flang/Lower/PFTBuilder.h"
|
|
#include "flang/Lower/Support/ReductionProcessor.h"
|
|
#include "flang/Optimizer/Dialect/FIRType.h"
|
|
#include "flang/Parser/tools.h"
|
|
#include "flang/Semantics/tools.h"
|
|
#include "flang/Utils/OpenMP.h"
|
|
#include "llvm/Frontend/OpenMP/OMP.h.inc"
|
|
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
|
|
|
|
namespace Fortran {
|
|
namespace lower {
|
|
namespace omp {
|
|
|
|
using ReductionModifier =
|
|
Fortran::lower::omp::clause::Reduction::ReductionModifier;
|
|
|
|
mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
|
|
switch (mod) {
|
|
case ReductionModifier::Default:
|
|
return mlir::omp::ReductionModifier::defaultmod;
|
|
case ReductionModifier::Inscan:
|
|
return mlir::omp::ReductionModifier::inscan;
|
|
case ReductionModifier::Task:
|
|
return mlir::omp::ReductionModifier::task;
|
|
}
|
|
return mlir::omp::ReductionModifier::defaultmod;
|
|
}
|
|
|
|
/// Check for unsupported map operand types.
|
|
static void checkMapType(mlir::Location location, mlir::Type type) {
|
|
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(type))
|
|
type = refType.getElementType();
|
|
if (auto boxType = mlir::dyn_cast_or_null<fir::BoxType>(type))
|
|
if (!mlir::isa<fir::PointerType>(boxType.getElementType()))
|
|
TODO(location, "OMPD_target_data MapOperand BoxType");
|
|
}
|
|
|
|
static mlir::omp::ScheduleModifier
|
|
translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) {
|
|
switch (m) {
|
|
case omp::clause::Schedule::OrderingModifier::Monotonic:
|
|
return mlir::omp::ScheduleModifier::monotonic;
|
|
case omp::clause::Schedule::OrderingModifier::Nonmonotonic:
|
|
return mlir::omp::ScheduleModifier::nonmonotonic;
|
|
}
|
|
return mlir::omp::ScheduleModifier::none;
|
|
}
|
|
|
|
static mlir::omp::ScheduleModifier
|
|
getScheduleModifier(const omp::clause::Schedule &clause) {
|
|
using Schedule = omp::clause::Schedule;
|
|
const auto &modifier =
|
|
std::get<std::optional<Schedule::OrderingModifier>>(clause.t);
|
|
if (modifier)
|
|
return translateScheduleModifier(*modifier);
|
|
return mlir::omp::ScheduleModifier::none;
|
|
}
|
|
|
|
static mlir::omp::ScheduleModifier
|
|
getSimdModifier(const omp::clause::Schedule &clause) {
|
|
using Schedule = omp::clause::Schedule;
|
|
const auto &modifier =
|
|
std::get<std::optional<Schedule::ChunkModifier>>(clause.t);
|
|
if (modifier && *modifier == Schedule::ChunkModifier::Simd)
|
|
return mlir::omp::ScheduleModifier::simd;
|
|
return mlir::omp::ScheduleModifier::none;
|
|
}
|
|
|
|
static void
|
|
genAllocateClause(lower::AbstractConverter &converter,
|
|
const omp::clause::Allocate &clause,
|
|
llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
|
|
llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
mlir::Location currentLocation = converter.getCurrentLocation();
|
|
lower::StatementContext stmtCtx;
|
|
|
|
auto &objects = std::get<omp::ObjectList>(clause.t);
|
|
|
|
using Allocate = omp::clause::Allocate;
|
|
// ALIGN in this context is unimplemented
|
|
if (std::get<std::optional<Allocate::AlignModifier>>(clause.t))
|
|
TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
|
|
|
|
// Check if allocate clause has allocator specified. If so, add it
|
|
// to list of allocators, otherwise, add default allocator to
|
|
// list of allocators.
|
|
using ComplexModifier = Allocate::AllocatorComplexModifier;
|
|
if (auto &mod = std::get<std::optional<ComplexModifier>>(clause.t)) {
|
|
mlir::Value operand = fir::getBase(converter.genExprValue(mod->v, stmtCtx));
|
|
allocatorOperands.append(objects.size(), operand);
|
|
} else {
|
|
mlir::Value operand = firOpBuilder.createIntegerConstant(
|
|
currentLocation, firOpBuilder.getI32Type(), 1);
|
|
allocatorOperands.append(objects.size(), operand);
|
|
}
|
|
|
|
genObjectList(objects, converter, allocateOperands);
|
|
}
|
|
|
|
static mlir::omp::ClauseBindKindAttr
|
|
genBindKindAttr(fir::FirOpBuilder &firOpBuilder,
|
|
const omp::clause::Bind &clause) {
|
|
mlir::omp::ClauseBindKind bindKind;
|
|
switch (clause.v) {
|
|
case omp::clause::Bind::Binding::Teams:
|
|
bindKind = mlir::omp::ClauseBindKind::Teams;
|
|
break;
|
|
case omp::clause::Bind::Binding::Parallel:
|
|
bindKind = mlir::omp::ClauseBindKind::Parallel;
|
|
break;
|
|
case omp::clause::Bind::Binding::Thread:
|
|
bindKind = mlir::omp::ClauseBindKind::Thread;
|
|
break;
|
|
}
|
|
return mlir::omp::ClauseBindKindAttr::get(firOpBuilder.getContext(),
|
|
bindKind);
|
|
}
|
|
|
|
static mlir::omp::ClauseProcBindKindAttr
|
|
genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder,
|
|
const omp::clause::ProcBind &clause) {
|
|
mlir::omp::ClauseProcBindKind procBindKind;
|
|
switch (clause.v) {
|
|
case omp::clause::ProcBind::AffinityPolicy::Master:
|
|
procBindKind = mlir::omp::ClauseProcBindKind::Master;
|
|
break;
|
|
case omp::clause::ProcBind::AffinityPolicy::Close:
|
|
procBindKind = mlir::omp::ClauseProcBindKind::Close;
|
|
break;
|
|
case omp::clause::ProcBind::AffinityPolicy::Spread:
|
|
procBindKind = mlir::omp::ClauseProcBindKind::Spread;
|
|
break;
|
|
case omp::clause::ProcBind::AffinityPolicy::Primary:
|
|
procBindKind = mlir::omp::ClauseProcBindKind::Primary;
|
|
break;
|
|
}
|
|
return mlir::omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(),
|
|
procBindKind);
|
|
}
|
|
|
|
static mlir::omp::ClauseTaskDependAttr
|
|
genDependKindAttr(lower::AbstractConverter &converter,
|
|
const omp::clause::DependenceType kind) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
mlir::Location currentLocation = converter.getCurrentLocation();
|
|
|
|
mlir::omp::ClauseTaskDepend pbKind;
|
|
switch (kind) {
|
|
case omp::clause::DependenceType::In:
|
|
pbKind = mlir::omp::ClauseTaskDepend::taskdependin;
|
|
break;
|
|
case omp::clause::DependenceType::Out:
|
|
pbKind = mlir::omp::ClauseTaskDepend::taskdependout;
|
|
break;
|
|
case omp::clause::DependenceType::Inout:
|
|
pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
|
|
break;
|
|
case omp::clause::DependenceType::Mutexinoutset:
|
|
pbKind = mlir::omp::ClauseTaskDepend::taskdependmutexinoutset;
|
|
break;
|
|
case omp::clause::DependenceType::Inoutset:
|
|
pbKind = mlir::omp::ClauseTaskDepend::taskdependinoutset;
|
|
break;
|
|
case omp::clause::DependenceType::Depobj:
|
|
TODO(currentLocation, "DEPOBJ dependence-type");
|
|
break;
|
|
case omp::clause::DependenceType::Sink:
|
|
case omp::clause::DependenceType::Source:
|
|
llvm_unreachable("unhandled parser task dependence type");
|
|
break;
|
|
}
|
|
return mlir::omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(),
|
|
pbKind);
|
|
}
|
|
|
|
static mlir::Value
|
|
getIfClauseOperand(lower::AbstractConverter &converter,
|
|
const omp::clause::If &clause,
|
|
omp::clause::If::DirectiveNameModifier directiveName,
|
|
mlir::Location clauseLocation) {
|
|
// Only consider the clause if it's intended for the given directive.
|
|
auto &directive =
|
|
std::get<std::optional<omp::clause::If::DirectiveNameModifier>>(clause.t);
|
|
if (directive && directive.value() != directiveName)
|
|
return nullptr;
|
|
|
|
lower::StatementContext stmtCtx;
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
mlir::Value ifVal = fir::getBase(
|
|
converter.genExprValue(std::get<omp::SomeExpr>(clause.t), stmtCtx));
|
|
return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
|
|
ifVal);
|
|
}
|
|
|
|
static void addUseDeviceClause(
|
|
lower::AbstractConverter &converter, const omp::ObjectList &objects,
|
|
llvm::SmallVectorImpl<mlir::Value> &operands,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
|
|
genObjectList(objects, converter, operands);
|
|
for (mlir::Value &operand : operands)
|
|
checkMapType(operand.getLoc(), operand.getType());
|
|
|
|
for (const omp::Object &object : objects)
|
|
useDeviceSyms.push_back(object.sym());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ClauseProcessor unique clauses
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
bool ClauseProcessor::processBare(mlir::omp::BareClauseOps &result) const {
|
|
return markClauseOccurrence<omp::clause::OmpxBare>(result.bare);
|
|
}
|
|
|
|
bool ClauseProcessor::processBind(mlir::omp::BindClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Bind>()) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
result.bindKind = genBindKindAttr(firOpBuilder, *clause);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processCancelDirectiveName(
|
|
mlir::omp::CancelDirectiveNameClauseOps &result) const {
|
|
using ConstructType = mlir::omp::ClauseCancellationConstructType;
|
|
mlir::MLIRContext *context = &converter.getMLIRContext();
|
|
|
|
ConstructType directive;
|
|
if (auto *clause = findUniqueClause<omp::CancellationConstructType>()) {
|
|
switch (clause->v) {
|
|
case llvm::omp::OMP_CANCELLATION_CONSTRUCT_Parallel:
|
|
directive = mlir::omp::ClauseCancellationConstructType::Parallel;
|
|
break;
|
|
case llvm::omp::OMP_CANCELLATION_CONSTRUCT_Loop:
|
|
directive = mlir::omp::ClauseCancellationConstructType::Loop;
|
|
break;
|
|
case llvm::omp::OMP_CANCELLATION_CONSTRUCT_Sections:
|
|
directive = mlir::omp::ClauseCancellationConstructType::Sections;
|
|
break;
|
|
case llvm::omp::OMP_CANCELLATION_CONSTRUCT_Taskgroup:
|
|
directive = mlir::omp::ClauseCancellationConstructType::Taskgroup;
|
|
break;
|
|
case llvm::omp::OMP_CANCELLATION_CONSTRUCT_None:
|
|
llvm_unreachable("OMP_CANCELLATION_CONSTRUCT_None");
|
|
break;
|
|
}
|
|
} else {
|
|
llvm_unreachable("cancel construct missing cancellation construct type");
|
|
}
|
|
|
|
result.cancelDirective =
|
|
mlir::omp::ClauseCancellationConstructTypeAttr::get(context, directive);
|
|
return true;
|
|
}
|
|
|
|
bool ClauseProcessor::processCollapse(
|
|
mlir::Location currentLocation, lower::pft::Evaluation &eval,
|
|
mlir::omp::LoopRelatedClauseOps &loopResult,
|
|
mlir::omp::CollapseClauseOps &collapseResult,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
|
|
|
|
int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
|
|
clauses, loopResult, iv);
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
collapseResult.collapseNumLoops = firOpBuilder.getI64IntegerAttr(numCollapse);
|
|
return numCollapse > 1;
|
|
}
|
|
|
|
bool ClauseProcessor::processDevice(lower::StatementContext &stmtCtx,
|
|
mlir::omp::DeviceClauseOps &result) const {
|
|
const parser::CharBlock *source = nullptr;
|
|
if (auto *clause = findUniqueClause<omp::clause::Device>(&source)) {
|
|
mlir::Location clauseLocation = converter.genLocation(*source);
|
|
if (auto deviceModifier =
|
|
std::get<std::optional<omp::clause::Device::DeviceModifier>>(
|
|
clause->t)) {
|
|
if (deviceModifier == omp::clause::Device::DeviceModifier::Ancestor) {
|
|
TODO(clauseLocation, "OMPD_target Device Modifier Ancestor");
|
|
}
|
|
}
|
|
const auto &deviceExpr = std::get<omp::SomeExpr>(clause->t);
|
|
result.device = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processDeviceType(
|
|
mlir::omp::DeviceTypeClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::DeviceType>()) {
|
|
// Case: declare target ... device_type(any | host | nohost)
|
|
switch (clause->v) {
|
|
case omp::clause::DeviceType::DeviceTypeDescription::Nohost:
|
|
result.deviceType = mlir::omp::DeclareTargetDeviceType::nohost;
|
|
break;
|
|
case omp::clause::DeviceType::DeviceTypeDescription::Host:
|
|
result.deviceType = mlir::omp::DeclareTargetDeviceType::host;
|
|
break;
|
|
case omp::clause::DeviceType::DeviceTypeDescription::Any:
|
|
result.deviceType = mlir::omp::DeclareTargetDeviceType::any;
|
|
break;
|
|
}
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processDistSchedule(
|
|
lower::StatementContext &stmtCtx,
|
|
mlir::omp::DistScheduleClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::DistSchedule>()) {
|
|
result.distScheduleStatic = converter.getFirOpBuilder().getUnitAttr();
|
|
const auto &chunkSize = std::get<std::optional<ExprTy>>(clause->t);
|
|
if (chunkSize)
|
|
result.distScheduleChunkSize =
|
|
fir::getBase(converter.genExprValue(*chunkSize, stmtCtx));
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processExclusive(
|
|
mlir::Location currentLocation,
|
|
mlir::omp::ExclusiveClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Exclusive>()) {
|
|
for (const Object &object : clause->v) {
|
|
const semantics::Symbol *symbol = object.sym();
|
|
mlir::Value symVal = converter.getSymbolAddress(*symbol);
|
|
result.exclusiveVars.push_back(symVal);
|
|
}
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
|
|
mlir::omp::FilterClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
|
|
result.filteredThreadId =
|
|
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processFinal(lower::StatementContext &stmtCtx,
|
|
mlir::omp::FinalClauseOps &result) const {
|
|
const parser::CharBlock *source = nullptr;
|
|
if (auto *clause = findUniqueClause<omp::clause::Final>(&source)) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
mlir::Location clauseLocation = converter.genLocation(*source);
|
|
|
|
mlir::Value finalVal =
|
|
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
|
|
result.final = firOpBuilder.createConvert(
|
|
clauseLocation, firOpBuilder.getI1Type(), finalVal);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Hint>()) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
int64_t hintValue = *evaluate::ToInt64(clause->v);
|
|
result.hint = firOpBuilder.getI64IntegerAttr(hintValue);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processInclusive(
|
|
mlir::Location currentLocation,
|
|
mlir::omp::InclusiveClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Inclusive>()) {
|
|
for (const Object &object : clause->v) {
|
|
const semantics::Symbol *symbol = object.sym();
|
|
mlir::Value symVal = converter.getSymbolAddress(*symbol);
|
|
result.inclusiveVars.push_back(symVal);
|
|
}
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processInitializer(
|
|
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
|
|
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Initializer>()) {
|
|
genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
|
|
mlir::Type type, mlir::Value ompOrig) {
|
|
lower::SymMapScope scope(symMap);
|
|
const parser::OmpInitializerExpression &iexpr = inp.v.v;
|
|
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
|
|
const std::list<parser::OmpStylizedDeclaration> &declList =
|
|
std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t);
|
|
mlir::Value ompPrivVar;
|
|
for (const parser::OmpStylizedDeclaration &decl : declList) {
|
|
auto &name = std::get<parser::ObjectName>(decl.var.t);
|
|
assert(name.symbol && "Name does not have a symbol");
|
|
mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
|
|
fir::StoreOp::create(builder, loc, ompOrig, addr);
|
|
fir::FortranVariableFlagsEnum extraFlags = {};
|
|
fir::FortranVariableFlagsAttr attributes =
|
|
Fortran::lower::translateSymbolAttributes(builder.getContext(),
|
|
*name.symbol, extraFlags);
|
|
auto declareOp = hlfir::DeclareOp::create(
|
|
builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
|
|
0, attributes);
|
|
if (name.ToString() == "omp_priv")
|
|
ompPrivVar = declareOp.getResult(0);
|
|
symMap.addVariableDefinition(*name.symbol, declareOp);
|
|
}
|
|
// Lower the expression/function call
|
|
lower::StatementContext stmtCtx;
|
|
mlir::Value result = common::visit(
|
|
common::visitors{
|
|
[&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
|
|
convertCallToHLFIR(loc, converter, procRef, std::nullopt,
|
|
symMap, stmtCtx);
|
|
auto privVal = fir::LoadOp::create(builder, loc, ompPrivVar);
|
|
return privVal;
|
|
},
|
|
[&](const auto &expr) -> mlir::Value {
|
|
mlir::Value exprResult = fir::getBase(convertExprToValue(
|
|
loc, converter, clause->v, symMap, stmtCtx));
|
|
// Conversion can either give a value or a refrence to a value,
|
|
// we need to return the reduction type, so an optional load may
|
|
// be generated.
|
|
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
|
|
exprResult.getType()))
|
|
if (ompPrivVar.getType() == refType)
|
|
exprResult = fir::LoadOp::create(builder, loc, exprResult);
|
|
return exprResult;
|
|
}},
|
|
clause->v.u);
|
|
stmtCtx.finalizeAndPop();
|
|
return result;
|
|
};
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processMergeable(
|
|
mlir::omp::MergeableClauseOps &result) const {
|
|
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
|
|
}
|
|
|
|
bool ClauseProcessor::processNogroup(
|
|
mlir::omp::NogroupClauseOps &result) const {
|
|
return markClauseOccurrence<omp::clause::Nogroup>(result.nogroup);
|
|
}
|
|
|
|
bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const {
|
|
return markClauseOccurrence<omp::clause::Nowait>(result.nowait);
|
|
}
|
|
|
|
bool ClauseProcessor::processNumTasks(
|
|
lower::StatementContext &stmtCtx,
|
|
mlir::omp::NumTasksClauseOps &result) const {
|
|
using NumTasks = omp::clause::NumTasks;
|
|
if (auto *clause = findUniqueClause<NumTasks>()) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
mlir::MLIRContext *context = firOpBuilder.getContext();
|
|
const auto &modifier =
|
|
std::get<std::optional<NumTasks::Prescriptiveness>>(clause->t);
|
|
if (modifier && *modifier == NumTasks::Prescriptiveness::Strict) {
|
|
result.numTasksMod = mlir::omp::ClauseNumTasksTypeAttr::get(
|
|
context, mlir::omp::ClauseNumTasksType::Strict);
|
|
}
|
|
const auto &numtasksExpr = std::get<omp::SomeExpr>(clause->t);
|
|
result.numTasks =
|
|
fir::getBase(converter.genExprValue(numtasksExpr, stmtCtx));
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processSizes(StatementContext &stmtCtx,
|
|
mlir::omp::SizesClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Sizes>()) {
|
|
result.sizes.reserve(clause->v.size());
|
|
for (const ExprTy &vv : clause->v)
|
|
result.sizes.push_back(fir::getBase(converter.genExprValue(vv, stmtCtx)));
|
|
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processNumTeams(
|
|
lower::StatementContext &stmtCtx,
|
|
mlir::omp::NumTeamsClauseOps &result) const {
|
|
// TODO Get lower and upper bounds for num_teams when parser is updated to
|
|
// accept both.
|
|
if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
|
|
// The num_teams directive accepts a list of team lower/upper bounds.
|
|
// This is an extension to support grid specification for ompx_bare.
|
|
// Here, only expect a single element in the list.
|
|
assert(clause->v.size() == 1);
|
|
// auto lowerBound = std::get<std::optional<ExprTy>>(clause->v[0]->t);
|
|
auto &upperBound = std::get<ExprTy>(clause->v[0].t);
|
|
result.numTeamsUpper =
|
|
fir::getBase(converter.genExprValue(upperBound, stmtCtx));
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processNumThreads(
|
|
lower::StatementContext &stmtCtx,
|
|
mlir::omp::NumThreadsClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
|
|
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
|
|
result.numThreads =
|
|
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processOrder(mlir::omp::OrderClauseOps &result) const {
|
|
using Order = omp::clause::Order;
|
|
if (auto *clause = findUniqueClause<Order>()) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
result.order = mlir::omp::ClauseOrderKindAttr::get(
|
|
firOpBuilder.getContext(), mlir::omp::ClauseOrderKind::Concurrent);
|
|
const auto &modifier =
|
|
std::get<std::optional<Order::OrderModifier>>(clause->t);
|
|
if (modifier && *modifier == Order::OrderModifier::Unconstrained) {
|
|
result.orderMod = mlir::omp::OrderModifierAttr::get(
|
|
firOpBuilder.getContext(), mlir::omp::OrderModifier::unconstrained);
|
|
} else {
|
|
// "If order-modifier is not unconstrained, the behavior is as if the
|
|
// reproducible modifier is present."
|
|
result.orderMod = mlir::omp::OrderModifierAttr::get(
|
|
firOpBuilder.getContext(), mlir::omp::OrderModifier::reproducible);
|
|
}
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processOrdered(
|
|
mlir::omp::OrderedClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Ordered>()) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
int64_t orderedClauseValue = 0l;
|
|
if (clause->v.has_value())
|
|
orderedClauseValue = *evaluate::ToInt64(*clause->v);
|
|
result.ordered = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processPriority(
|
|
lower::StatementContext &stmtCtx,
|
|
mlir::omp::PriorityClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Priority>()) {
|
|
result.priority = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processDetach(mlir::omp::DetachClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Detach>()) {
|
|
semantics::Symbol *sym = clause->v.sym();
|
|
mlir::Value symVal = converter.getSymbolAddress(*sym);
|
|
result.eventHandle = symVal;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processProcBind(
|
|
mlir::omp::ProcBindClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::ProcBind>()) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
result.procBindKind = genProcBindKindAttr(firOpBuilder, *clause);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processTileSizes(
|
|
lower::pft::Evaluation &eval, mlir::omp::LoopNestOperands &result) const {
|
|
auto *ompCons{eval.getIf<parser::OpenMPConstruct>()};
|
|
collectTileSizesFromOpenMPConstruct(ompCons, result.tileSizes, semaCtx);
|
|
return !result.tileSizes.empty();
|
|
}
|
|
|
|
bool ClauseProcessor::processSafelen(
|
|
mlir::omp::SafelenClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
const std::optional<std::int64_t> safelenVal = evaluate::ToInt64(clause->v);
|
|
result.safelen = firOpBuilder.getI64IntegerAttr(*safelenVal);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processSchedule(
|
|
lower::StatementContext &stmtCtx,
|
|
mlir::omp::ScheduleClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
mlir::MLIRContext *context = firOpBuilder.getContext();
|
|
const auto &scheduleType = std::get<omp::clause::Schedule::Kind>(clause->t);
|
|
|
|
mlir::omp::ClauseScheduleKind scheduleKind;
|
|
switch (scheduleType) {
|
|
case omp::clause::Schedule::Kind::Static:
|
|
scheduleKind = mlir::omp::ClauseScheduleKind::Static;
|
|
break;
|
|
case omp::clause::Schedule::Kind::Dynamic:
|
|
scheduleKind = mlir::omp::ClauseScheduleKind::Dynamic;
|
|
break;
|
|
case omp::clause::Schedule::Kind::Guided:
|
|
scheduleKind = mlir::omp::ClauseScheduleKind::Guided;
|
|
break;
|
|
case omp::clause::Schedule::Kind::Auto:
|
|
scheduleKind = mlir::omp::ClauseScheduleKind::Auto;
|
|
break;
|
|
case omp::clause::Schedule::Kind::Runtime:
|
|
scheduleKind = mlir::omp::ClauseScheduleKind::Runtime;
|
|
break;
|
|
}
|
|
|
|
result.scheduleKind =
|
|
mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
|
|
|
|
mlir::omp::ScheduleModifier scheduleMod = getScheduleModifier(*clause);
|
|
if (scheduleMod != mlir::omp::ScheduleModifier::none)
|
|
result.scheduleMod =
|
|
mlir::omp::ScheduleModifierAttr::get(context, scheduleMod);
|
|
|
|
if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
|
|
result.scheduleSimd = firOpBuilder.getUnitAttr();
|
|
|
|
if (const auto &chunkExpr = std::get<omp::MaybeExpr>(clause->t))
|
|
result.scheduleChunk =
|
|
fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
|
|
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processSimdlen(
|
|
mlir::omp::SimdlenClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::Simdlen>()) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
const std::optional<std::int64_t> simdlenVal = evaluate::ToInt64(clause->v);
|
|
result.simdlen = firOpBuilder.getI64IntegerAttr(*simdlenVal);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processThreadLimit(
|
|
lower::StatementContext &stmtCtx,
|
|
mlir::omp::ThreadLimitClauseOps &result) const {
|
|
if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
|
|
result.threadLimit =
|
|
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
|
|
return markClauseOccurrence<omp::clause::Untied>(result.untied);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ClauseProcessor repeatable clauses
|
|
//===----------------------------------------------------------------------===//
|
|
static llvm::StringMap<bool> getTargetFeatures(mlir::ModuleOp module) {
|
|
llvm::StringMap<bool> featuresMap;
|
|
llvm::SmallVector<llvm::StringRef> targetFeaturesVec;
|
|
if (mlir::LLVM::TargetFeaturesAttr features =
|
|
fir::getTargetFeatures(module)) {
|
|
llvm::ArrayRef<mlir::StringAttr> featureAttrs = features.getFeatures();
|
|
for (auto &featureAttr : featureAttrs) {
|
|
llvm::StringRef featureKeyString = featureAttr.strref();
|
|
featuresMap[featureKeyString.substr(1)] = (featureKeyString[0] == '+');
|
|
}
|
|
}
|
|
return featuresMap;
|
|
}
|
|
|
|
static void
|
|
addAlignedClause(lower::AbstractConverter &converter,
|
|
const omp::clause::Aligned &clause,
|
|
llvm::SmallVectorImpl<mlir::Value> &alignedVars,
|
|
llvm::SmallVectorImpl<mlir::Attribute> &alignments) {
|
|
using Aligned = omp::clause::Aligned;
|
|
lower::StatementContext stmtCtx;
|
|
mlir::IntegerAttr alignmentValueAttr;
|
|
int64_t alignment = 0;
|
|
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
|
|
|
|
if (auto &alignmentValueParserExpr =
|
|
std::get<std::optional<Aligned::Alignment>>(clause.t)) {
|
|
mlir::Value operand = fir::getBase(
|
|
converter.genExprValue(*alignmentValueParserExpr, stmtCtx));
|
|
alignment = *fir::getIntIfConstant(operand);
|
|
} else {
|
|
llvm::StringMap<bool> featuresMap = getTargetFeatures(builder.getModule());
|
|
llvm::Triple triple = fir::getTargetTriple(builder.getModule());
|
|
alignment =
|
|
llvm::OpenMPIRBuilder::getOpenMPDefaultSimdAlign(triple, featuresMap);
|
|
}
|
|
|
|
// The default alignment for some targets is equal to 0.
|
|
// Do not generate alignment assumption if alignment is less than or equal to
|
|
// 0 or not a power of two
|
|
if (alignment > 0 && ((alignment & (alignment - 1)) == 0)) {
|
|
auto &objects = std::get<omp::ObjectList>(clause.t);
|
|
if (!objects.empty())
|
|
genObjectList(objects, converter, alignedVars);
|
|
alignmentValueAttr = builder.getI64IntegerAttr(alignment);
|
|
// All the list items in a aligned clause will have same alignment
|
|
for (std::size_t i = 0; i < objects.size(); i++)
|
|
alignments.push_back(alignmentValueAttr);
|
|
}
|
|
}
|
|
|
|
bool ClauseProcessor::processAligned(
|
|
mlir::omp::AlignedClauseOps &result) const {
|
|
return findRepeatableClause<omp::clause::Aligned>(
|
|
[&](const omp::clause::Aligned &clause, const parser::CharBlock &) {
|
|
addAlignedClause(converter, clause, result.alignedVars,
|
|
result.alignments);
|
|
});
|
|
}
|
|
|
|
bool ClauseProcessor::processAllocate(
|
|
mlir::omp::AllocateClauseOps &result) const {
|
|
return findRepeatableClause<omp::clause::Allocate>(
|
|
[&](const omp::clause::Allocate &clause, const parser::CharBlock &) {
|
|
genAllocateClause(converter, clause, result.allocatorVars,
|
|
result.allocateVars);
|
|
});
|
|
}
|
|
|
|
bool ClauseProcessor::processCopyin() const {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint();
|
|
firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
|
|
auto checkAndCopyHostAssociateVar =
|
|
[&](semantics::Symbol *sym,
|
|
mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) {
|
|
assert(sym->has<semantics::HostAssocDetails>() &&
|
|
"No host-association found");
|
|
if (converter.isPresentShallowLookup(*sym))
|
|
converter.copyHostAssociateVar(*sym, copyAssignIP);
|
|
};
|
|
bool hasCopyin = findRepeatableClause<omp::clause::Copyin>(
|
|
[&](const omp::clause::Copyin &clause, const parser::CharBlock &) {
|
|
for (const omp::Object &object : clause.v) {
|
|
semantics::Symbol *sym = object.sym();
|
|
assert(sym && "Expecting symbol");
|
|
if (const auto *commonDetails =
|
|
sym->detailsIf<semantics::CommonBlockDetails>()) {
|
|
for (const auto &mem : commonDetails->objects())
|
|
checkAndCopyHostAssociateVar(&*mem, &insPt);
|
|
break;
|
|
}
|
|
|
|
assert(sym->has<semantics::HostAssocDetails>() &&
|
|
"No host-association found");
|
|
checkAndCopyHostAssociateVar(sym);
|
|
}
|
|
});
|
|
|
|
// [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to
|
|
// the execution of the associated structured block. Emit implicit barrier to
|
|
// synchronize threads and avoid data races on propagation master's thread
|
|
// values of threadprivate variables to local instances of that variables of
|
|
// all other implicit threads.
|
|
|
|
// All copies are inserted at either "insPt" (i.e. immediately before it),
|
|
// or at some earlier point (as determined by "copyHostAssociateVar").
|
|
// Unless the insertion point is given to "copyHostAssociateVar" explicitly,
|
|
// it will not restore the builder's insertion point. Since the copies may be
|
|
// inserted in any order (not following the execution order), make sure the
|
|
// barrier is inserted following all of them.
|
|
firOpBuilder.restoreInsertionPoint(insPt);
|
|
if (hasCopyin)
|
|
mlir::omp::BarrierOp::create(firOpBuilder, converter.getCurrentLocation());
|
|
return hasCopyin;
|
|
}
|
|
|
|
/// Class that extracts information from the specified type.
|
|
class TypeInfo {
|
|
public:
|
|
TypeInfo(mlir::Type ty) { typeScan(ty); }
|
|
|
|
// Returns the length of character types.
|
|
std::optional<fir::CharacterType::LenType> getCharLength() const {
|
|
return charLen;
|
|
}
|
|
|
|
// Returns the shape of array types.
|
|
llvm::ArrayRef<int64_t> getShape() const { return shape; }
|
|
|
|
// Is the type inside a box?
|
|
bool isBox() const { return inBox; }
|
|
|
|
bool isBoxChar() const { return inBoxChar; }
|
|
|
|
private:
|
|
void typeScan(mlir::Type type);
|
|
|
|
std::optional<fir::CharacterType::LenType> charLen;
|
|
llvm::SmallVector<int64_t> shape;
|
|
bool inBox = false;
|
|
bool inBoxChar = false;
|
|
};
|
|
|
|
void TypeInfo::typeScan(mlir::Type ty) {
|
|
if (auto sty = mlir::dyn_cast<fir::SequenceType>(ty)) {
|
|
assert(shape.empty() && !sty.getShape().empty());
|
|
shape = llvm::SmallVector<int64_t>(sty.getShape());
|
|
typeScan(sty.getEleTy());
|
|
} else if (auto bty = mlir::dyn_cast<fir::BoxType>(ty)) {
|
|
inBox = true;
|
|
typeScan(bty.getEleTy());
|
|
} else if (auto cty = mlir::dyn_cast<fir::ClassType>(ty)) {
|
|
inBox = true;
|
|
typeScan(cty.getEleTy());
|
|
} else if (auto cty = mlir::dyn_cast<fir::CharacterType>(ty)) {
|
|
charLen = cty.getLen();
|
|
} else if (auto cty = mlir::dyn_cast<fir::BoxCharType>(ty)) {
|
|
inBoxChar = true;
|
|
typeScan(cty.getEleTy());
|
|
} else if (auto hty = mlir::dyn_cast<fir::HeapType>(ty)) {
|
|
typeScan(hty.getEleTy());
|
|
} else if (auto pty = mlir::dyn_cast<fir::PointerType>(ty)) {
|
|
typeScan(pty.getEleTy());
|
|
} else {
|
|
// The scan ends when reaching any built-in, record or boxproc type.
|
|
assert(ty.isIntOrIndexOrFloat() || mlir::isa<mlir::ComplexType>(ty) ||
|
|
mlir::isa<fir::LogicalType>(ty) || mlir::isa<fir::RecordType>(ty) ||
|
|
mlir::isa<fir::BoxProcType>(ty));
|
|
}
|
|
}
|
|
|
|
// Create a function that performs a copy between two variables, compatible
|
|
// with their types and attributes.
|
|
static mlir::func::FuncOp
|
|
createCopyFunc(mlir::Location loc, lower::AbstractConverter &converter,
|
|
mlir::Type varType, fir::FortranVariableFlagsEnum varAttrs) {
|
|
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
|
|
mlir::ModuleOp module = builder.getModule();
|
|
mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
|
|
TypeInfo typeInfo(eleTy);
|
|
std::string copyFuncName =
|
|
fir::getTypeAsString(eleTy, builder.getKindMap(), "_copy");
|
|
|
|
if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
|
|
return decl;
|
|
|
|
// create function
|
|
mlir::OpBuilder::InsertionGuard guard(builder);
|
|
mlir::OpBuilder modBuilder(module.getBodyRegion());
|
|
llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
|
|
auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
|
|
mlir::func::FuncOp funcOp =
|
|
mlir::func::FuncOp::create(modBuilder, loc, copyFuncName, funcType);
|
|
funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
|
|
fir::factory::setInternalLinkage(funcOp);
|
|
builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
|
|
{loc, loc});
|
|
builder.setInsertionPointToStart(&funcOp.getRegion().back());
|
|
// generate body
|
|
fir::FortranVariableFlagsAttr attrs;
|
|
if (varAttrs != fir::FortranVariableFlagsEnum::None)
|
|
attrs = fir::FortranVariableFlagsAttr::get(builder.getContext(), varAttrs);
|
|
mlir::Value shape;
|
|
if (!typeInfo.isBox() && !typeInfo.getShape().empty()) {
|
|
llvm::SmallVector<mlir::Value> extents;
|
|
for (auto extent : typeInfo.getShape())
|
|
extents.push_back(
|
|
builder.createIntegerConstant(loc, builder.getIndexType(), extent));
|
|
shape = fir::ShapeOp::create(builder, loc, extents);
|
|
}
|
|
mlir::Value dst = funcOp.getArgument(0);
|
|
mlir::Value src = funcOp.getArgument(1);
|
|
llvm::SmallVector<mlir::Value> typeparams;
|
|
if (typeInfo.isBoxChar()) {
|
|
// fir.boxchar will be passed here as fir.ref<fir.boxchar>
|
|
auto loadDst = fir::LoadOp::create(builder, loc, dst);
|
|
auto loadSrc = fir::LoadOp::create(builder, loc, src);
|
|
// get the actual fir.ref<fir.char> type
|
|
mlir::Type refType =
|
|
fir::ReferenceType::get(mlir::cast<fir::BoxCharType>(eleTy).getEleTy());
|
|
auto unboxedDst = fir::UnboxCharOp::create(builder, loc, refType,
|
|
builder.getIndexType(), loadDst);
|
|
auto unboxedSrc = fir::UnboxCharOp::create(builder, loc, refType,
|
|
builder.getIndexType(), loadSrc);
|
|
// Add length to type parameters
|
|
typeparams.push_back(unboxedDst.getResult(1));
|
|
dst = unboxedDst.getResult(0);
|
|
src = unboxedSrc.getResult(0);
|
|
} else if (typeInfo.getCharLength().has_value()) {
|
|
mlir::Value charLen = builder.createIntegerConstant(
|
|
loc, builder.getCharacterLengthType(), *typeInfo.getCharLength());
|
|
typeparams.push_back(charLen);
|
|
}
|
|
auto declDst = hlfir::DeclareOp::create(
|
|
builder, loc, dst, copyFuncName + "_dst", shape, typeparams,
|
|
/*dummy_scope=*/nullptr, /*storage=*/nullptr,
|
|
/*storage_offset=*/0, attrs);
|
|
auto declSrc = hlfir::DeclareOp::create(
|
|
builder, loc, src, copyFuncName + "_src", shape, typeparams,
|
|
/*dummy_scope=*/nullptr, /*storage=*/nullptr,
|
|
/*storage_offset=*/0, attrs);
|
|
converter.copyVar(loc, declDst.getBase(), declSrc.getBase(), varAttrs);
|
|
mlir::func::ReturnOp::create(builder, loc);
|
|
return funcOp;
|
|
}
|
|
|
|
bool ClauseProcessor::processCopyprivate(
|
|
mlir::Location currentLocation,
|
|
mlir::omp::CopyprivateClauseOps &result) const {
|
|
auto addCopyPrivateVar = [&](semantics::Symbol *sym) {
|
|
mlir::Value symVal = converter.getSymbolAddress(*sym);
|
|
auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>();
|
|
if (!declOp)
|
|
fir::emitFatalError(currentLocation,
|
|
"COPYPRIVATE is supported only in HLFIR mode");
|
|
symVal = declOp.getBase();
|
|
mlir::Type symType = symVal.getType();
|
|
fir::FortranVariableFlagsEnum attrs =
|
|
declOp.getFortranAttrs().has_value()
|
|
? *declOp.getFortranAttrs()
|
|
: fir::FortranVariableFlagsEnum::None;
|
|
mlir::Value cpVar = symVal;
|
|
|
|
// CopyPrivate variables must be passed by reference. However, in the case
|
|
// of assumed shapes/vla the type is not a !fir.ref, but a !fir.box.
|
|
// In the case of character types, the passed in type can also be
|
|
// !fir.boxchar. In these cases to retrieve the appropriate
|
|
// !fir.ref<!fir.box<...>> or !fir.ref<!fir.boxchar<..>> to access the data
|
|
// we need we must perform an alloca and then store to it and retrieve the
|
|
// data from the new alloca.
|
|
if (mlir::isa<fir::BaseBoxType>(symType) ||
|
|
mlir::isa<fir::BoxCharType>(symType)) {
|
|
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
|
|
auto alloca = fir::AllocaOp::create(builder, currentLocation, symType);
|
|
fir::StoreOp::create(builder, currentLocation, symVal, alloca);
|
|
cpVar = alloca;
|
|
}
|
|
|
|
result.copyprivateVars.push_back(cpVar);
|
|
mlir::func::FuncOp funcOp =
|
|
createCopyFunc(currentLocation, converter, cpVar.getType(), attrs);
|
|
result.copyprivateSyms.push_back(mlir::SymbolRefAttr::get(funcOp));
|
|
};
|
|
|
|
bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
|
|
[&](const clause::Copyprivate &clause, const parser::CharBlock &) {
|
|
for (const Object &object : clause.v) {
|
|
semantics::Symbol *sym = object.sym();
|
|
if (const auto *commonDetails =
|
|
sym->detailsIf<semantics::CommonBlockDetails>()) {
|
|
for (const auto &mem : commonDetails->objects())
|
|
addCopyPrivateVar(&*mem);
|
|
break;
|
|
}
|
|
addCopyPrivateVar(sym);
|
|
}
|
|
});
|
|
|
|
return hasCopyPrivate;
|
|
}
|
|
|
|
template <typename T>
|
|
static bool isVectorSubscript(const evaluate::Expr<T> &expr) {
|
|
if (std::optional<evaluate::DataRef> dataRef{evaluate::ExtractDataRef(expr)})
|
|
if (const auto *arrayRef = std::get_if<evaluate::ArrayRef>(&dataRef->u))
|
|
for (const evaluate::Subscript &subscript : arrayRef->subscript())
|
|
if (std::holds_alternative<evaluate::IndirectSubscriptIntegerExpr>(
|
|
subscript.u))
|
|
if (subscript.Rank() > 0)
|
|
return true;
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processDefaultMap(lower::StatementContext &stmtCtx,
|
|
DefaultMapsTy &result) const {
|
|
auto process = [&](const omp::clause::Defaultmap &clause,
|
|
const parser::CharBlock &) {
|
|
using Defmap = omp::clause::Defaultmap;
|
|
clause::Defaultmap::VariableCategory variableCategory =
|
|
Defmap::VariableCategory::All;
|
|
// Variable Category is optional, if not specified defaults to all.
|
|
// Multiples of the same category are illegal as are any other
|
|
// defaultmaps being specified when a user specified all is in place,
|
|
// however, this should be handled earlier during semantics.
|
|
if (auto varCat =
|
|
std::get<std::optional<Defmap::VariableCategory>>(clause.t))
|
|
variableCategory = varCat.value();
|
|
auto behaviour = std::get<Defmap::ImplicitBehavior>(clause.t);
|
|
result[variableCategory] = behaviour;
|
|
};
|
|
return findRepeatableClause<omp::clause::Defaultmap>(process);
|
|
}
|
|
|
|
bool ClauseProcessor::processDepend(lower::SymMap &symMap,
|
|
lower::StatementContext &stmtCtx,
|
|
mlir::omp::DependClauseOps &result) const {
|
|
auto process = [&](const omp::clause::Depend &clause,
|
|
const parser::CharBlock &) {
|
|
using Depend = omp::clause::Depend;
|
|
if (!std::holds_alternative<Depend::TaskDep>(clause.u)) {
|
|
TODO(converter.getCurrentLocation(),
|
|
"DEPEND clause with SINK or SOURCE is not supported yet");
|
|
}
|
|
auto &taskDep = std::get<Depend::TaskDep>(clause.u);
|
|
auto depType = std::get<clause::DependenceType>(taskDep.t);
|
|
auto &objects = std::get<omp::ObjectList>(taskDep.t);
|
|
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
|
|
|
|
if (std::get<std::optional<omp::clause::Iterator>>(taskDep.t)) {
|
|
TODO(converter.getCurrentLocation(),
|
|
"Support for iterator modifiers is not implemented yet");
|
|
}
|
|
mlir::omp::ClauseTaskDependAttr dependTypeOperand =
|
|
genDependKindAttr(converter, depType);
|
|
result.dependKinds.append(objects.size(), dependTypeOperand);
|
|
|
|
for (const omp::Object &object : objects) {
|
|
assert(object.ref() && "Expecting designator");
|
|
mlir::Value dependVar;
|
|
SomeExpr expr = *object.ref();
|
|
|
|
if (evaluate::IsArrayElement(expr) || evaluate::ExtractSubstring(expr)) {
|
|
// Array Section or character (sub)string
|
|
if (isVectorSubscript(expr)) {
|
|
// OpenMP needs the address of the first indexed element (required by
|
|
// the standard to be the lowest index) to identify the dependency. We
|
|
// don't need an accurate length for the array section because the
|
|
// OpenMP standard forbids overlapping array sections.
|
|
dependVar = genVectorSubscriptedDesignatorFirstElementAddress(
|
|
converter.getCurrentLocation(), converter, expr, symMap, stmtCtx);
|
|
} else {
|
|
// Ordinary array section e.g. A(1:512:2)
|
|
hlfir::EntityWithAttributes entity = convertExprToHLFIR(
|
|
converter.getCurrentLocation(), converter, expr, symMap, stmtCtx);
|
|
dependVar = entity.getBase();
|
|
}
|
|
} else if (evaluate::isStructureComponent(expr) ||
|
|
evaluate::ExtractComplexPart(expr)) {
|
|
SomeExpr expr = *object.ref();
|
|
hlfir::EntityWithAttributes entity = convertExprToHLFIR(
|
|
converter.getCurrentLocation(), converter, expr, symMap, stmtCtx);
|
|
dependVar = entity.getBase();
|
|
} else {
|
|
semantics::Symbol *sym = object.sym();
|
|
dependVar = converter.getSymbolAddress(*sym);
|
|
}
|
|
|
|
// If we pass a mutable box e.g. !fir.ref<!fir.box<!fir.heap<...>>> then
|
|
// the runtime will use the address of the box not the address of the
|
|
// data. Flang generates a lot of memcpys between different box
|
|
// allocations so this is not a reliable way to identify the dependency.
|
|
if (auto ref = mlir::dyn_cast<fir::ReferenceType>(dependVar.getType()))
|
|
if (fir::isa_box_type(ref.getElementType()))
|
|
dependVar = fir::LoadOp::create(
|
|
builder, converter.getCurrentLocation(), dependVar);
|
|
|
|
// The openmp dialect doesn't know what to do with boxes (and it would
|
|
// break layering to teach it about them). The dependency variable can be
|
|
// a box because it was an array section or because the original symbol
|
|
// was mapped to a box.
|
|
// Getting the address of the box data is okay because all the runtime
|
|
// ultimately cares about is the base address of the array.
|
|
if (fir::isa_box_type(dependVar.getType()))
|
|
dependVar = fir::BoxAddrOp::create(
|
|
builder, converter.getCurrentLocation(), dependVar);
|
|
|
|
result.dependVars.push_back(dependVar);
|
|
}
|
|
};
|
|
|
|
return findRepeatableClause<omp::clause::Depend>(process);
|
|
}
|
|
|
|
bool ClauseProcessor::processGrainsize(
|
|
lower::StatementContext &stmtCtx,
|
|
mlir::omp::GrainsizeClauseOps &result) const {
|
|
using Grainsize = omp::clause::Grainsize;
|
|
if (auto *clause = findUniqueClause<Grainsize>()) {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
mlir::MLIRContext *context = firOpBuilder.getContext();
|
|
const auto &modifier =
|
|
std::get<std::optional<Grainsize::Prescriptiveness>>(clause->t);
|
|
if (modifier && *modifier == Grainsize::Prescriptiveness::Strict) {
|
|
result.grainsizeMod = mlir::omp::ClauseGrainsizeTypeAttr::get(
|
|
context, mlir::omp::ClauseGrainsizeType::Strict);
|
|
}
|
|
const auto &grainsizeExpr = std::get<omp::SomeExpr>(clause->t);
|
|
result.grainsize =
|
|
fir::getBase(converter.genExprValue(grainsizeExpr, stmtCtx));
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool ClauseProcessor::processHasDeviceAddr(
|
|
lower::StatementContext &stmtCtx, mlir::omp::HasDeviceAddrClauseOps &result,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceSyms) const {
|
|
// For HAS_DEVICE_ADDR objects, implicitly map the top-level entities.
|
|
// Their address (or the whole descriptor, if the entity had one) will be
|
|
// passed to the target region.
|
|
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
|
|
bool clauseFound = findRepeatableClause<omp::clause::HasDeviceAddr>(
|
|
[&](const omp::clause::HasDeviceAddr &clause,
|
|
const parser::CharBlock &source) {
|
|
mlir::Location location = converter.genLocation(source);
|
|
mlir::omp::ClauseMapFlags mapTypeBits =
|
|
mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::implicit;
|
|
omp::ObjectList baseObjects;
|
|
llvm::transform(clause.v, std::back_inserter(baseObjects),
|
|
[&](const omp::Object &object) {
|
|
if (auto maybeBase = getBaseObject(object, semaCtx))
|
|
return *maybeBase;
|
|
return object;
|
|
});
|
|
processMapObjects(stmtCtx, location, baseObjects, mapTypeBits,
|
|
parentMemberIndices, result.hasDeviceAddrVars,
|
|
hasDeviceSyms);
|
|
});
|
|
|
|
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
|
|
result.hasDeviceAddrVars, hasDeviceSyms);
|
|
return clauseFound;
|
|
}
|
|
|
|
bool ClauseProcessor::processIf(
|
|
omp::clause::If::DirectiveNameModifier directiveName,
|
|
mlir::omp::IfClauseOps &result) const {
|
|
bool found = false;
|
|
findRepeatableClause<omp::clause::If>([&](const omp::clause::If &clause,
|
|
const parser::CharBlock &source) {
|
|
mlir::Location clauseLocation = converter.genLocation(source);
|
|
mlir::Value operand =
|
|
getIfClauseOperand(converter, clause, directiveName, clauseLocation);
|
|
// Assume that, at most, a single 'if' clause will be applicable to the
|
|
// given directive.
|
|
if (operand) {
|
|
result.ifExpr = operand;
|
|
found = true;
|
|
}
|
|
});
|
|
return found;
|
|
}
|
|
|
|
template <typename T>
|
|
void collectReductionSyms(
|
|
const T &reduction,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
|
|
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
|
|
for (const Object &object : objectList) {
|
|
const semantics::Symbol *symbol = object.sym();
|
|
reductionSyms.push_back(symbol);
|
|
}
|
|
}
|
|
|
|
bool ClauseProcessor::processInReduction(
|
|
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
|
|
return findRepeatableClause<omp::clause::InReduction>(
|
|
[&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
|
|
llvm::SmallVector<mlir::Value> inReductionVars;
|
|
llvm::SmallVector<bool> inReduceVarByRef;
|
|
llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
|
|
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
|
|
collectReductionSyms(clause, inReductionSyms);
|
|
|
|
ReductionProcessor rp;
|
|
if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
|
|
currentLocation, converter,
|
|
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
|
|
inReductionVars, inReduceVarByRef, inReductionDeclSymbols,
|
|
inReductionSyms))
|
|
TODO(currentLocation, "Lowering unrecognised reduction type");
|
|
|
|
// Copy local lists into the output.
|
|
llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
|
|
llvm::copy(inReduceVarByRef,
|
|
std::back_inserter(result.inReductionByref));
|
|
llvm::copy(inReductionDeclSymbols,
|
|
std::back_inserter(result.inReductionSyms));
|
|
llvm::copy(inReductionSyms, std::back_inserter(outReductionSyms));
|
|
});
|
|
}
|
|
|
|
bool ClauseProcessor::processIsDevicePtr(
|
|
mlir::omp::IsDevicePtrClauseOps &result,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
|
|
return findRepeatableClause<omp::clause::IsDevicePtr>(
|
|
[&](const omp::clause::IsDevicePtr &devPtrClause,
|
|
const parser::CharBlock &) {
|
|
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
|
|
isDeviceSyms);
|
|
});
|
|
}
|
|
|
|
bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
|
|
lower::StatementContext stmtCtx;
|
|
return findRepeatableClause<
|
|
omp::clause::Linear>([&](const omp::clause::Linear &clause,
|
|
const parser::CharBlock &) {
|
|
auto &objects = std::get<omp::ObjectList>(clause.t);
|
|
for (const omp::Object &object : objects) {
|
|
semantics::Symbol *sym = object.sym();
|
|
const mlir::Value variable = converter.getSymbolAddress(*sym);
|
|
result.linearVars.push_back(variable);
|
|
}
|
|
if (objects.size()) {
|
|
if (auto &mod =
|
|
std::get<std::optional<omp::clause::Linear::StepComplexModifier>>(
|
|
clause.t)) {
|
|
mlir::Value operand =
|
|
fir::getBase(converter.genExprValue(toEvExpr(*mod), stmtCtx));
|
|
result.linearStepVars.append(objects.size(), operand);
|
|
} else if (std::get<std::optional<omp::clause::Linear::LinearModifier>>(
|
|
clause.t)) {
|
|
mlir::Location currentLocation = converter.getCurrentLocation();
|
|
TODO(currentLocation, "Linear modifiers not yet implemented");
|
|
} else {
|
|
// If nothing is present, add the default step of 1.
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
mlir::Location currentLocation = converter.getCurrentLocation();
|
|
mlir::Value operand = firOpBuilder.createIntegerConstant(
|
|
currentLocation, firOpBuilder.getI32Type(), 1);
|
|
result.linearStepVars.append(objects.size(), operand);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
bool ClauseProcessor::processLink(
|
|
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
|
|
return findRepeatableClause<omp::clause::Link>(
|
|
[&](const omp::clause::Link &clause, const parser::CharBlock &) {
|
|
// Case: declare target link(var1, var2)...
|
|
gatherFuncAndVarSyms(
|
|
clause.v, mlir::omp::DeclareTargetCaptureClause::link, result,
|
|
/*automap=*/false);
|
|
});
|
|
}
|
|
|
|
void ClauseProcessor::processMapObjects(
|
|
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
|
|
const omp::ObjectList &objects, mlir::omp::ClauseMapFlags mapTypeBits,
|
|
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
|
|
llvm::SmallVectorImpl<mlir::Value> &mapVars,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
|
|
llvm::StringRef mapperIdNameRef) const {
|
|
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
|
|
|
auto getSymbolDerivedType = [](const semantics::Symbol &symbol)
|
|
-> const semantics::DerivedTypeSpec * {
|
|
const semantics::Symbol &ultimate = symbol.GetUltimate();
|
|
if (const semantics::DeclTypeSpec *declType = ultimate.GetType())
|
|
if (const auto *derived = declType->AsDerived())
|
|
return derived;
|
|
return nullptr;
|
|
};
|
|
|
|
auto addImplicitMapper = [&](const omp::Object &object,
|
|
std::string &mapperIdName,
|
|
bool allowGenerate) -> mlir::FlatSymbolRefAttr {
|
|
if (mapperIdName.empty())
|
|
return mlir::FlatSymbolRefAttr();
|
|
|
|
if (converter.getModuleOp().lookupSymbol(mapperIdName))
|
|
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
|
|
mapperIdName);
|
|
|
|
if (!allowGenerate)
|
|
return mlir::FlatSymbolRefAttr();
|
|
|
|
const semantics::DerivedTypeSpec *typeSpec =
|
|
getSymbolDerivedType(*object.sym());
|
|
if (!typeSpec && object.sym()->owner().IsDerivedType())
|
|
typeSpec = object.sym()->owner().derivedTypeSpec();
|
|
|
|
if (!typeSpec)
|
|
return mlir::FlatSymbolRefAttr();
|
|
|
|
mlir::Type type = converter.genType(*typeSpec);
|
|
auto recordType = mlir::dyn_cast<fir::RecordType>(type);
|
|
if (!recordType)
|
|
return mlir::FlatSymbolRefAttr();
|
|
|
|
return getOrGenImplicitDefaultDeclareMapper(converter, clauseLocation,
|
|
recordType, mapperIdName);
|
|
};
|
|
|
|
auto getDefaultMapperID =
|
|
[&](const semantics::DerivedTypeSpec *typeSpec) -> std::string {
|
|
if (mlir::isa<mlir::omp::DeclareMapperOp>(
|
|
firOpBuilder.getRegion().getParentOp()) ||
|
|
!typeSpec)
|
|
return {};
|
|
|
|
std::string mapperIdName =
|
|
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
|
|
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) {
|
|
mapperIdName =
|
|
converter.mangleName(mapperIdName, sym->GetUltimate().owner());
|
|
} else {
|
|
mapperIdName = converter.mangleName(mapperIdName, *typeSpec->GetScope());
|
|
}
|
|
|
|
// Make sure we don't return a mapper to self.
|
|
if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(
|
|
firOpBuilder.getRegion().getParentOp()))
|
|
if (mapperIdName == declMapOp.getSymName())
|
|
return {};
|
|
return mapperIdName;
|
|
};
|
|
|
|
// Create the mapper symbol from its name, if specified.
|
|
mlir::FlatSymbolRefAttr mapperId;
|
|
if (!mapperIdNameRef.empty() && !objects.empty() &&
|
|
mapperIdNameRef != "__implicit_mapper") {
|
|
std::string mapperIdName = mapperIdNameRef.str();
|
|
const omp::Object &object = objects.front();
|
|
if (mapperIdNameRef == "default") {
|
|
const semantics::DerivedTypeSpec *typeSpec =
|
|
getSymbolDerivedType(*object.sym());
|
|
if (!typeSpec && object.sym()->owner().IsDerivedType())
|
|
typeSpec = object.sym()->owner().derivedTypeSpec();
|
|
mapperIdName = getDefaultMapperID(typeSpec);
|
|
}
|
|
assert(converter.getModuleOp().lookupSymbol(mapperIdName) &&
|
|
"mapper not found");
|
|
mapperId =
|
|
mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), mapperIdName);
|
|
}
|
|
|
|
for (const omp::Object &object : objects) {
|
|
llvm::SmallVector<mlir::Value> bounds;
|
|
std::stringstream asFortran;
|
|
std::optional<omp::Object> parentObj;
|
|
|
|
fir::factory::AddrAndBoundsInfo info =
|
|
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
|
|
mlir::omp::MapBoundsType>(
|
|
converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
|
|
object.ref(), clauseLocation, asFortran, bounds,
|
|
treatIndexAsSection);
|
|
|
|
mlir::Value baseOp = info.rawInput;
|
|
if (object.sym()->owner().IsDerivedType()) {
|
|
omp::ObjectList objectList = gatherObjectsOf(object, semaCtx);
|
|
assert(!objectList.empty() &&
|
|
"could not find parent objects of derived type member");
|
|
parentObj = objectList[0];
|
|
parentMemberIndices.emplace(parentObj.value(),
|
|
OmpMapParentAndMemberData{});
|
|
|
|
if (isMemberOrParentAllocatableOrPointer(object, semaCtx)) {
|
|
llvm::SmallVector<int64_t> indices;
|
|
generateMemberPlacementIndices(object, indices, semaCtx);
|
|
baseOp = createParentSymAndGenIntermediateMaps(
|
|
clauseLocation, converter, semaCtx, stmtCtx, objectList, indices,
|
|
parentMemberIndices[parentObj.value()], asFortran.str(),
|
|
mapTypeBits);
|
|
}
|
|
}
|
|
|
|
const semantics::DerivedTypeSpec *objectTypeSpec =
|
|
getSymbolDerivedType(*object.sym());
|
|
|
|
if (mapperIdNameRef == "__implicit_mapper") {
|
|
if (parentObj.has_value()) {
|
|
mapperId = mlir::FlatSymbolRefAttr();
|
|
} else if (objectTypeSpec) {
|
|
std::string mapperIdName = getDefaultMapperID(objectTypeSpec);
|
|
bool needsDefaultMapper =
|
|
semantics::IsAllocatableOrObjectPointer(object.sym()) ||
|
|
requiresImplicitDefaultDeclareMapper(*objectTypeSpec);
|
|
if (!mapperIdName.empty())
|
|
mapperId = addImplicitMapper(object, mapperIdName,
|
|
/*allowGenerate=*/needsDefaultMapper);
|
|
else
|
|
mapperId = mlir::FlatSymbolRefAttr();
|
|
} else {
|
|
mapperId = mlir::FlatSymbolRefAttr();
|
|
}
|
|
}
|
|
|
|
// Explicit map captures are captured ByRef by default,
|
|
// optimisation passes may alter this to ByCopy or other capture
|
|
// types to optimise
|
|
auto location = mlir::NameLoc::get(
|
|
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
|
|
baseOp.getLoc());
|
|
mlir::omp::MapInfoOp mapOp = utils::openmp::createMapInfoOp(
|
|
firOpBuilder, location, baseOp,
|
|
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
|
|
/*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{}, mapTypeBits,
|
|
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType(),
|
|
/*partialMap=*/false, mapperId);
|
|
|
|
if (parentObj.has_value()) {
|
|
parentMemberIndices[parentObj.value()].addChildIndexAndMapToParent(
|
|
object, mapOp, semaCtx);
|
|
} else {
|
|
mapVars.push_back(mapOp);
|
|
mapSyms.push_back(object.sym());
|
|
}
|
|
}
|
|
}
|
|
|
|
bool ClauseProcessor::processMap(
|
|
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
|
|
mlir::omp::MapClauseOps &result, llvm::omp::Directive directive,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms) const {
|
|
// We always require tracking of symbols, even if the caller does not,
|
|
// so we create an optionally used local set of symbols when the mapSyms
|
|
// argument is not present.
|
|
llvm::SmallVector<const semantics::Symbol *> localMapSyms;
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> *ptrMapSyms =
|
|
mapSyms ? mapSyms : &localMapSyms;
|
|
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
|
|
|
|
auto process = [&](const omp::clause::Map &clause,
|
|
const parser::CharBlock &source) {
|
|
using Map = omp::clause::Map;
|
|
mlir::Location clauseLocation = converter.genLocation(source);
|
|
const auto &[mapType, typeMods, attachMod, refMod, mappers, iterator,
|
|
objects] = clause.t;
|
|
if (attachMod)
|
|
TODO(currentLocation, "ATTACH modifier is not implemented yet");
|
|
mlir::omp::ClauseMapFlags mapTypeBits = mlir::omp::ClauseMapFlags::none;
|
|
std::string mapperIdName = "__implicit_mapper";
|
|
// If the map type is specified, then process it else set the appropriate
|
|
// default value
|
|
Map::MapType type;
|
|
if (directive == llvm::omp::Directive::OMPD_target_enter_data &&
|
|
semaCtx.langOptions().OpenMPVersion >= 52)
|
|
type = mapType.value_or(Map::MapType::To);
|
|
else if (directive == llvm::omp::Directive::OMPD_target_exit_data &&
|
|
semaCtx.langOptions().OpenMPVersion >= 52)
|
|
type = mapType.value_or(Map::MapType::From);
|
|
else
|
|
type = mapType.value_or(Map::MapType::Tofrom);
|
|
|
|
switch (type) {
|
|
case Map::MapType::To:
|
|
mapTypeBits |= mlir::omp::ClauseMapFlags::to;
|
|
break;
|
|
case Map::MapType::From:
|
|
mapTypeBits |= mlir::omp::ClauseMapFlags::from;
|
|
break;
|
|
case Map::MapType::Tofrom:
|
|
mapTypeBits |=
|
|
mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::from;
|
|
break;
|
|
case Map::MapType::Storage:
|
|
mapTypeBits |= mlir::omp::ClauseMapFlags::storage;
|
|
break;
|
|
}
|
|
|
|
if (typeMods) {
|
|
// TODO: Still requires "self" modifier, an OpenMP 6.0+ feature
|
|
if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Always))
|
|
mapTypeBits |= mlir::omp::ClauseMapFlags::always;
|
|
if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Present))
|
|
mapTypeBits |= mlir::omp::ClauseMapFlags::present;
|
|
if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Close))
|
|
mapTypeBits |= mlir::omp::ClauseMapFlags::close;
|
|
if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Delete))
|
|
mapTypeBits |= mlir::omp::ClauseMapFlags::del;
|
|
if (llvm::is_contained(*typeMods, Map::MapTypeModifier::OmpxHold))
|
|
mapTypeBits |= mlir::omp::ClauseMapFlags::ompx_hold;
|
|
}
|
|
|
|
if (iterator) {
|
|
TODO(currentLocation,
|
|
"Support for iterator modifiers is not implemented yet");
|
|
}
|
|
if (mappers) {
|
|
assert(mappers->size() == 1 && "more than one mapper");
|
|
const semantics::Symbol *mapperSym = mappers->front().v.id().symbol;
|
|
mapperIdName = mapperSym->name().ToString();
|
|
if (mapperIdName != "default") {
|
|
// Mangle with the ultimate owner so that use-associated mapper
|
|
// identifiers resolve to the same symbol as their defining scope.
|
|
const semantics::Symbol &ultimate = mapperSym->GetUltimate();
|
|
mapperIdName = converter.mangleName(mapperIdName, ultimate.owner());
|
|
}
|
|
}
|
|
|
|
processMapObjects(stmtCtx, clauseLocation,
|
|
std::get<omp::ObjectList>(clause.t), mapTypeBits,
|
|
parentMemberIndices, result.mapVars, *ptrMapSyms,
|
|
mapperIdName);
|
|
};
|
|
|
|
bool clauseFound = findRepeatableClause<omp::clause::Map>(process);
|
|
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
|
|
result.mapVars, *ptrMapSyms);
|
|
|
|
return clauseFound;
|
|
}
|
|
|
|
bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
|
|
mlir::omp::MapClauseOps &result) {
|
|
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
|
|
llvm::SmallVector<const semantics::Symbol *> mapSymbols;
|
|
|
|
auto callbackFn = [&](const auto &clause, const parser::CharBlock &source) {
|
|
mlir::Location clauseLocation = converter.genLocation(source);
|
|
const auto &[expectation, mapper, iterator, objects] = clause.t;
|
|
|
|
// TODO Support motion modifiers: mapper, iterator.
|
|
if (mapper) {
|
|
TODO(clauseLocation, "Mapper modifier is not supported yet");
|
|
} else if (iterator) {
|
|
TODO(clauseLocation, "Iterator modifier is not supported yet");
|
|
}
|
|
|
|
mlir::omp::ClauseMapFlags mapTypeBits =
|
|
std::is_same_v<llvm::remove_cvref_t<decltype(clause)>, omp::clause::To>
|
|
? mlir::omp::ClauseMapFlags::to
|
|
: mlir::omp::ClauseMapFlags::from;
|
|
if (expectation && *expectation == omp::clause::To::Expectation::Present)
|
|
mapTypeBits |= mlir::omp::ClauseMapFlags::present;
|
|
processMapObjects(stmtCtx, clauseLocation, objects, mapTypeBits,
|
|
parentMemberIndices, result.mapVars, mapSymbols);
|
|
};
|
|
|
|
bool clauseFound = findRepeatableClause<omp::clause::To>(callbackFn);
|
|
clauseFound =
|
|
findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound;
|
|
|
|
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
|
|
result.mapVars, mapSymbols);
|
|
|
|
return clauseFound;
|
|
}
|
|
|
|
bool ClauseProcessor::processNontemporal(
|
|
mlir::omp::NontemporalClauseOps &result) const {
|
|
return findRepeatableClause<omp::clause::Nontemporal>(
|
|
[&](const omp::clause::Nontemporal &clause, const parser::CharBlock &) {
|
|
for (const Object &object : clause.v) {
|
|
semantics::Symbol *sym = object.sym();
|
|
mlir::Value symVal = converter.getSymbolAddress(*sym);
|
|
result.nontemporalVars.push_back(symVal);
|
|
}
|
|
});
|
|
}
|
|
|
|
bool ClauseProcessor::processReduction(
|
|
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
|
|
return findRepeatableClause<omp::clause::Reduction>(
|
|
[&](const omp::clause::Reduction &clause, const parser::CharBlock &) {
|
|
llvm::SmallVector<mlir::Value> reductionVars;
|
|
llvm::SmallVector<bool> reduceVarByRef;
|
|
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
|
|
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
|
|
collectReductionSyms(clause, reductionSyms);
|
|
|
|
auto mod = std::get<std::optional<ReductionModifier>>(clause.t);
|
|
if (mod.has_value()) {
|
|
if (mod.value() == ReductionModifier::Task)
|
|
TODO(currentLocation, "Reduction modifier `task` is not supported");
|
|
else
|
|
result.reductionMod = mlir::omp::ReductionModifierAttr::get(
|
|
converter.getFirOpBuilder().getContext(),
|
|
translateReductionModifier(mod.value()));
|
|
}
|
|
|
|
ReductionProcessor rp;
|
|
if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
|
|
currentLocation, converter,
|
|
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
|
|
reductionVars, reduceVarByRef, reductionDeclSymbols,
|
|
reductionSyms))
|
|
TODO(currentLocation, "Lowering unrecognised reduction type");
|
|
// Copy local lists into the output.
|
|
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
|
|
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
|
|
llvm::copy(reductionDeclSymbols,
|
|
std::back_inserter(result.reductionSyms));
|
|
llvm::copy(reductionSyms, std::back_inserter(outReductionSyms));
|
|
});
|
|
}
|
|
|
|
bool ClauseProcessor::processTaskReduction(
|
|
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
|
|
return findRepeatableClause<omp::clause::TaskReduction>(
|
|
[&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
|
|
llvm::SmallVector<mlir::Value> taskReductionVars;
|
|
llvm::SmallVector<bool> taskReduceVarByRef;
|
|
llvm::SmallVector<mlir::Attribute> taskReductionDeclSymbols;
|
|
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
|
|
collectReductionSyms(clause, taskReductionSyms);
|
|
|
|
ReductionProcessor rp;
|
|
if (!rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
|
|
currentLocation, converter,
|
|
std::get<typename omp::clause::ReductionOperatorList>(clause.t),
|
|
taskReductionVars, taskReduceVarByRef, taskReductionDeclSymbols,
|
|
taskReductionSyms))
|
|
TODO(currentLocation, "Lowering unrecognised reduction type");
|
|
// Copy local lists into the output.
|
|
llvm::copy(taskReductionVars,
|
|
std::back_inserter(result.taskReductionVars));
|
|
llvm::copy(taskReduceVarByRef,
|
|
std::back_inserter(result.taskReductionByref));
|
|
llvm::copy(taskReductionDeclSymbols,
|
|
std::back_inserter(result.taskReductionSyms));
|
|
llvm::copy(taskReductionSyms, std::back_inserter(outReductionSyms));
|
|
});
|
|
}
|
|
|
|
bool ClauseProcessor::processTo(
|
|
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
|
|
return findRepeatableClause<omp::clause::To>(
|
|
[&](const omp::clause::To &clause, const parser::CharBlock &) {
|
|
// Case: declare target to(func, var1, var2)...
|
|
gatherFuncAndVarSyms(std::get<ObjectList>(clause.t),
|
|
mlir::omp::DeclareTargetCaptureClause::to, result,
|
|
/*automap=*/false);
|
|
});
|
|
}
|
|
|
|
bool ClauseProcessor::processEnter(
|
|
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
|
|
return findRepeatableClause<omp::clause::Enter>(
|
|
[&](const omp::clause::Enter &clause, const parser::CharBlock &source) {
|
|
bool automap =
|
|
std::get<std::optional<omp::clause::Enter::Modifier>>(clause.t)
|
|
.has_value();
|
|
// Case: declare target enter(func, var1, var2)...
|
|
gatherFuncAndVarSyms(std::get<ObjectList>(clause.t),
|
|
mlir::omp::DeclareTargetCaptureClause::enter,
|
|
result, automap);
|
|
});
|
|
}
|
|
|
|
bool ClauseProcessor::processUseDeviceAddr(
|
|
lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
|
|
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
|
|
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
|
|
[&](const omp::clause::UseDeviceAddr &clause,
|
|
const parser::CharBlock &source) {
|
|
mlir::Location location = converter.genLocation(source);
|
|
mlir::omp::ClauseMapFlags mapTypeBits =
|
|
mlir::omp::ClauseMapFlags::return_param;
|
|
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
|
|
parentMemberIndices, result.useDeviceAddrVars,
|
|
useDeviceSyms);
|
|
});
|
|
|
|
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
|
|
result.useDeviceAddrVars, useDeviceSyms);
|
|
return clauseFound;
|
|
}
|
|
|
|
bool ClauseProcessor::processUseDevicePtr(
|
|
lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result,
|
|
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
|
|
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
|
|
|
|
bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
|
|
[&](const omp::clause::UseDevicePtr &clause,
|
|
const parser::CharBlock &source) {
|
|
mlir::Location location = converter.genLocation(source);
|
|
mlir::omp::ClauseMapFlags mapTypeBits =
|
|
mlir::omp::ClauseMapFlags::return_param;
|
|
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
|
|
parentMemberIndices, result.useDevicePtrVars,
|
|
useDeviceSyms);
|
|
});
|
|
|
|
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
|
|
result.useDevicePtrVars, useDeviceSyms);
|
|
return clauseFound;
|
|
}
|
|
|
|
} // namespace omp
|
|
} // namespace lower
|
|
} // namespace Fortran
|