[OpenMP][flang] Lowering of OpenMP custom reductions to MLIR (#168417)

This patch adds support for lowering custom reductions to MLIR. It
also enhances the pass that automatically marks functions as
"declare target" so that it traverses custom reduction initializers and
combiners.
This commit is contained in:
Jan Leyonberg
2025-11-24 16:00:46 -05:00
committed by GitHub
parent f581d8ad8f
commit 3e86f05621
16 changed files with 772 additions and 106 deletions

View File

@@ -40,6 +40,13 @@ namespace omp {
class ReductionProcessor {
public:
using GenInitValueCBTy =
std::function<mlir::Value(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value ompOrig)>;
using GenCombinerCBTy = std::function<void(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
mlir::Value op1, mlir::Value op2, bool isByRef)>;
// TODO: Move this enumeration to the OpenMP dialect
enum ReductionIdentifier {
ID,
@@ -58,6 +65,9 @@ public:
IEOR
};
static bool doReductionByRef(mlir::Type reductionType);
static bool doReductionByRef(mlir::Value reductionVar);
static ReductionIdentifier
getReductionType(const omp::clause::ProcedureDesignator &pd);
@@ -109,6 +119,14 @@ public:
ReductionIdentifier redId,
mlir::Type type, mlir::Value op1,
mlir::Value op2);
/// Creates an OpenMP reduction declaration and inserts it into the provided
/// symbol table. The init and combiner regions are generated by the callback
/// functions genCombinerCB and genInitValueCB.
template <typename DeclareRedType>
static DeclareRedType createDeclareReductionHelper(
AbstractConverter &converter, llvm::StringRef reductionOpName,
mlir::Type type, mlir::Location loc, bool isByRef,
GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB);
/// Creates an OpenMP reduction declaration and inserts it into the provided
/// symbol table. The declaration has a constant initializer with the neutral

View File

@@ -13,6 +13,7 @@
#include "ClauseProcessor.h"
#include "Utils.h"
#include "flang/Lower/ConvertCall.h"
#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/PFTBuilder.h"
@@ -402,6 +403,65 @@ bool ClauseProcessor::processInclusive(
return false;
}
/// Prepares the callback that emits the value of a custom reduction's
/// `initializer` clause.
///
/// \param symMap symbol map that receives the bindings of the stylized
///        variables declared by the initializer while the callback runs.
/// \param inp parse-tree form of the initializer clause; supplies the
///        stylized declarations.
/// \param genInitValueCB [out] set to the generated callback when a unique
///        initializer clause is present, left untouched otherwise.
/// \returns true if a unique initializer clause was found.
///
/// The callback captures `symMap`, `inp` and this processor by reference, so
/// it must only be invoked while all three are still alive.
bool ClauseProcessor::processInitializer(
    lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
    ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
  auto *clause = findUniqueClause<omp::clause::Initializer>();
  if (!clause)
    return false;

  genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
                               mlir::Type /*type*/, mlir::Value ompOrig) {
    lower::SymMapScope scope(symMap);
    const parser::OmpInitializerExpression &iexpr = inp.v.v;
    const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
    const std::list<parser::OmpStylizedDeclaration> &declList =
        std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t);

    // Back every stylized variable with its own temporary, seeded with the
    // incoming original value, and make it visible through the symbol map.
    mlir::Value ompPrivVar;
    for (const parser::OmpStylizedDeclaration &decl : declList) {
      auto &name = std::get<parser::ObjectName>(decl.var.t);
      assert(name.symbol && "Name does not have a symbol");
      const std::string varName = name.ToString();
      mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
      fir::StoreOp::create(builder, loc, ompOrig, addr);
      fir::FortranVariableFlagsEnum extraFlags = {};
      fir::FortranVariableFlagsAttr attributes =
          Fortran::lower::translateSymbolAttributes(builder.getContext(),
                                                    *name.symbol, extraFlags);
      auto declareOp =
          hlfir::DeclareOp::create(builder, loc, addr, varName, nullptr, {},
                                   nullptr, nullptr, 0, attributes);
      if (varName == "omp_priv")
        ompPrivVar = declareOp.getResult(0);
      symMap.addVariableDefinition(*name.symbol, declareOp);
    }

    // Lower the expression/function call
    lower::StatementContext stmtCtx;
    mlir::Value result = common::visit(
        common::visitors{
            [&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
              convertCallToHLFIR(loc, converter, procRef, std::nullopt,
                                 symMap, stmtCtx);
              // The call initializes omp_priv in place; hand back its value.
              assert(ompPrivVar &&
                     "initializer call without an omp_priv declaration");
              auto privVal = fir::LoadOp::create(builder, loc, ompPrivVar);
              return privVal;
            },
            [&](const auto &) -> mlir::Value {
              mlir::Value exprResult = fir::getBase(convertExprToValue(
                  loc, converter, clause->v, symMap, stmtCtx));
              // Conversion can either give a value or a reference to a value,
              // we need to return the reduction type, so an optional load may
              // be generated.
              if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
                      exprResult.getType())) {
                assert(ompPrivVar &&
                       "initializer expression without an omp_priv "
                       "declaration");
                if (ompPrivVar.getType() == refType)
                  exprResult = fir::LoadOp::create(builder, loc, exprResult);
              }
              return exprResult;
            }},
        clause->v.u);
    stmtCtx.finalizeAndPop();
    return result;
  };
  return true;
}
bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);

View File

@@ -18,6 +18,7 @@
#include "flang/Lower/Bridge.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
@@ -88,6 +89,9 @@ public:
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processInclusive(mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const;
bool processInitializer(
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNogroup(mlir::omp::NogroupClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;

View File

@@ -981,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp,
/// Builds the lowering-level `initializer` clause from its parse-tree form.
///
/// Only the first stylized instance of the initializer expression is
/// inspected:
///  - an assignment statement yields the converted expression operand of the
///    assignment;
///  - a call statement that carries a typed call yields that procedure
///    reference wrapped in a `SomeExpr`.
/// Any other shape is a hard error.
Initializer make(const parser::OmpClause::Initializer &inp,
                 semantics::SemanticsContext &semaCtx) {
  const parser::OmpInitializerExpression &iexpr = inp.v.v;
  const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
  const parser::OmpStylizedInstance::Instance &instance =
      std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t);

  if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
    const auto &expr = std::get<parser::Expr>(as->t);
    return Initializer{makeExpr(expr, semaCtx)};
  }

  if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
    if (call->typedCall) {
      const auto &procRef = *call->typedCall;
      semantics::SomeExpr evalProcRef{procRef};
      return Initializer{evalProcRef};
    }
  }

  llvm_unreachable("Unexpected initializer");
}
InReduction make(const parser::OmpClause::InReduction &inp,

View File

@@ -18,12 +18,15 @@
#include "Decomposer.h"
#include "Utils.h"
#include "flang/Common/idioms.h"
#include "flang/Evaluate/type.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// TODO: Add private syms and vars.
args.reduction.syms = reductionSyms;
args.reduction.vars = clauseOps.reductionVars;
return genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_teams)
@@ -3570,12 +3572,156 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
}
/// Builds the callback that emits the combiner region of a user-defined
/// reduction.
///
/// Only combiners written as an assignment are handled; any other form is
/// reported through TODO. When invoked, the callback binds each stylized
/// variable of the combiner to one of the region arguments (`omp_in` to `rhs`,
/// every other name to `lhs`), lowers the assignment's expression and yields
/// the outcome.
///
/// The callback keeps references to `converter`, `symTable`, `semaCtx` and to
/// nodes of `specifier`, so it must not be called once any of them is gone.
static ReductionProcessor::GenCombinerCBTy
processReductionCombiner(lower::AbstractConverter &converter,
                         lower::SymMap &symTable,
                         semantics::SemanticsContext &semaCtx,
                         const parser::OmpReductionSpecifier &specifier) {
  const auto &optCombiner =
      std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t);
  const parser::OmpStylizedInstance &combinerInstance =
      optCombiner.value().v.front();
  const parser::OmpStylizedInstance::Instance &instance =
      std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);

  const auto *assignment = std::get_if<parser::AssignmentStmt>(&instance.u);
  if (!assignment) {
    TODO(converter.getCurrentLocation(),
         "A combiner that is a subroutine call is not yet supported");
  }
  auto &combinerExpr = std::get<parser::Expr>(assignment->t);

  return [&converter, &symTable, &semaCtx, &combinerExpr, &combinerInstance](
             fir::FirOpBuilder &builder, mlir::Location loc,
             mlir::Type /*type*/, mlir::Value lhs, mlir::Value rhs,
             bool isByRef) {
    const auto &evalExpr = makeExpr(combinerExpr, semaCtx);
    lower::SymMapScope scope(symTable);

    // Bind every stylized variable to the matching region argument.
    const auto &stylizedDecls =
        std::get<std::list<parser::OmpStylizedDeclaration>>(
            combinerInstance.t);
    for (const parser::OmpStylizedDeclaration &decl : stylizedDecls) {
      auto &name = std::get<parser::ObjectName>(decl.var.t);
      const bool bindsInput = name.ToString() == std::string("omp_in");
      mlir::Value operand = bindsInput ? rhs : lhs;
      mlir::Type operandTy = operand.getType();
      assert(name.symbol && "Reduction object name does not have a symbol");

      // Arguments that cannot be passed by reference as-is are spilled to a
      // temporary so that the variable has an address to be declared on.
      mlir::Value storage = operand;
      if (!fir::conformsWithPassByRef(operandTy)) {
        storage = builder.createTemporary(loc, operandTy);
        fir::StoreOp::create(builder, loc, operand, storage);
      }

      fir::FortranVariableFlagsEnum extraFlags = {};
      fir::FortranVariableFlagsAttr attributes =
          Fortran::lower::translateSymbolAttributes(builder.getContext(),
                                                    *name.symbol, extraFlags);
      auto declareOp = hlfir::DeclareOp::create(
          builder, loc, storage, name.ToString(), nullptr, {}, nullptr,
          nullptr, 0, attributes);
      symTable.addVariableDefinition(*name.symbol, declareOp);
    }

    lower::StatementContext stmtCtx;
    mlir::Value combined = fir::getBase(
        convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
    // The lowered expression may be a reference to the value; load it when
    // the region argument carries the value itself.
    if (auto refType = llvm::dyn_cast<fir::ReferenceType>(combined.getType()))
      if (lhs.getType() == refType.getElementType())
        combined = fir::LoadOp::create(builder, loc, combined);
    stmtCtx.finalizeAndPop();

    if (!isByRef) {
      mlir::omp::YieldOp::create(builder, loc, combined);
      return;
    }
    fir::StoreOp::create(builder, loc, combined, lhs);
    mlir::omp::YieldOp::create(builder, loc, lhs);
  };
}
// Checks that the reduction type is either a trivial type or a derived type
// whose components are all trivial types. Every other type (for example a
// non-record aggregate) is rejected, so that callers can report it as
// unsupported instead of lowering it as if it were simple.
static bool isSimpleReductionType(mlir::Type reductionType) {
  if (fir::isa_trivial(reductionType))
    return true;

  auto recordTy = mlir::dyn_cast<fir::RecordType>(reductionType);
  if (!recordTy)
    return false;

  for (auto [_, fieldType] : recordTy.getTypeList()) {
    if (!fir::isa_trivial(fieldType))
      return false;
  }
  return true;
}
// Derives the MLIR type of a declare reduction from the symbol attached to the
// first stylized declaration of its combiner. Going through the symbol rather
// than a DeclSpec avoids having to distinguish derived from intrinsic types;
// semantics is guaranteed to generate these symbols. Types that are not
// accepted by isSimpleReductionType are reported through TODO.
static mlir::Type
getReductionType(lower::AbstractConverter &converter,
                 const parser::OmpReductionSpecifier &specifier) {
  const auto &optCombiner =
      std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t);
  const parser::OmpStylizedInstance &firstInstance =
      optCombiner.value().v.front();
  const auto &stylizedDecls =
      std::get<std::list<parser::OmpStylizedDeclaration>>(firstInstance.t);
  const auto &varName =
      std::get<parser::ObjectName>(stylizedDecls.front().var.t);

  const auto &varSymbol = semantics::SymbolRef(*varName.symbol);
  mlir::Type redTy = converter.genType(varSymbol);

  const bool supported = isSimpleReductionType(redTy);
  if (!supported)
    TODO(converter.getCurrentLocation(),
         "declare reduction currently only supports trival types or derived "
         "types containing trivial types");
  return redTy;
}
/// Lowers a `declare reduction` construct to an `omp.declare_reduction`
/// operation.
///
/// Nothing is emitted when only OpenMP SIMD constructs are enabled. The
/// following forms are reported through TODO: more than one type in the type
/// list, and a construct without any clause (i.e. without an initializer).
/// Restrictions on the reduction type and on the combiner form are enforced
/// by getReductionType and processReductionCombiner respectively.
static void genOMP(
    lower::AbstractConverter &converter, lower::SymMap &symTable,
    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
    const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
  if (semaCtx.langOptions().OpenMPSimd)
    return;

  const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()};
  const parser::OmpArgument &arg{args.v.front()};
  const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);

  if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
    TODO(converter.getCurrentLocation(),
         "multiple types in declare reduction is not yet supported");

  mlir::Type reductionType = getReductionType(converter, specifier);
  ReductionProcessor::GenCombinerCBTy genCombinerCB =
      processReductionCombiner(converter, symTable, semaCtx, specifier);

  const parser::OmpClauseList &initializer =
      declareReductionConstruct.v.Clauses();
  if (initializer.v.size() > 0) {
    List<Clause> clauses = makeClauses(initializer, semaCtx);
    ReductionProcessor::GenInitValueCBTy genInitValueCB;
    ClauseProcessor cp(converter, semaCtx, clauses);
    const parser::OmpClause::Initializer &iclause{
        std::get<parser::OmpClause::Initializer>(initializer.v.front().u)};
    // Without the callback the init region cannot be generated, so a missing
    // initializer at this point is a lowering bug rather than user error.
    [[maybe_unused]] bool hasInitializer =
        cp.processInitializer(symTable, iclause, genInitValueCB);
    assert(hasInitializer && "initializer clause was not processed");

    const auto &identifier =
        std::get<parser::OmpReductionIdentifier>(specifier.t);
    const auto &designator =
        std::get<parser::ProcedureDesignator>(identifier.u);
    const auto &reductionName = std::get<parser::Name>(designator.u);
    bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
    ReductionProcessor::createDeclareReductionHelper<
        mlir::omp::DeclareReductionOp>(
        converter, reductionName.ToString(), reductionType,
        converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
  } else {
    TODO(converter.getCurrentLocation(),
         "declare reduction without an initializer clause is not yet "
         "supported");
  }
}
static void

View File

@@ -501,7 +501,7 @@ static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
template <typename OpType>
static void createReductionAllocAndInitRegions(
AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl,
const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
ReductionProcessor::GenInitValueCBTy genInitValueCB, mlir::Type type,
bool isByRef) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
auto yield = [&](mlir::Value ret) { genYield<OpType>(builder, loc, ret); };
@@ -523,9 +523,8 @@ static void createReductionAllocAndInitRegions(
mlir::Type ty = fir::unwrapRefType(type);
builder.setInsertionPointToEnd(initBlock);
mlir::Value initValue = ReductionProcessor::getReductionInitValue(
loc, unwrapSeqOrBoxedType(ty), redId, builder);
mlir::Value initValue =
genInitValueCB(builder, loc, ty, initBlock->getArgument(0));
if (isByRef) {
populateByRefInitAndCleanupRegions(
converter, loc, type, initValue, initBlock,
@@ -536,7 +535,7 @@ static void createReductionAllocAndInitRegions(
/*isDoConcurrent*/ std::is_same_v<OpType, fir::DeclareReductionOp>);
}
if (fir::isa_trivial(ty)) {
if (fir::isa_trivial(ty) || fir::isa_derived(ty)) {
if (isByRef) {
// alloc region
builder.setInsertionPointToEnd(allocBlock);
@@ -556,18 +555,18 @@ static void createReductionAllocAndInitRegions(
yield(boxAlloca);
}
template <typename OpType>
OpType ReductionProcessor::createDeclareReduction(
template <typename DeclareRedType>
DeclareRedType ReductionProcessor::createDeclareReductionHelper(
AbstractConverter &converter, llvm::StringRef reductionOpName,
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
bool isByRef) {
mlir::Type type, mlir::Location loc, bool isByRef,
GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::OpBuilder::InsertionGuard guard(builder);
mlir::ModuleOp module = builder.getModule();
assert(!reductionOpName.empty());
auto decl = module.lookupSymbol<OpType>(reductionOpName);
auto decl = module.lookupSymbol<DeclareRedType>(reductionOpName);
if (decl)
return decl;
@@ -576,23 +575,54 @@ OpType ReductionProcessor::createDeclareReduction(
if (!isByRef)
type = valTy;
decl = OpType::create(modBuilder, loc, reductionOpName, type);
createReductionAllocAndInitRegions(converter, loc, decl, redId, type,
decl = DeclareRedType::create(modBuilder, loc, reductionOpName, type);
createReductionAllocAndInitRegions(converter, loc, decl, genInitValueCB, type,
isByRef);
builder.createBlock(&decl.getReductionRegion(),
decl.getReductionRegion().end(), {type, type},
{loc, loc});
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef);
genCombinerCB(builder, loc, type, op1, op2, isByRef);
return decl;
}
static bool doReductionByRef(mlir::Value reductionVar) {
/// Creates (or looks up) the reduction declaration for the built-in reduction
/// identified by `redId`. This is a thin wrapper over
/// createDeclareReductionHelper: the init value comes from
/// getReductionInitValue and the combiner from genCombiner, both for `redId`.
template <typename OpType>
OpType ReductionProcessor::createDeclareReduction(
    AbstractConverter &converter, llvm::StringRef reductionOpName,
    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
    bool isByRef) {
  // Init region: the value argument handed to the callback is not needed for
  // built-in reductions.
  auto emitInitValue = [&redId](fir::FirOpBuilder &cbBuilder,
                                mlir::Location cbLoc, mlir::Type cbType,
                                mlir::Value /*unused*/) -> mlir::Value {
    mlir::Type unwrappedTy = fir::unwrapRefType(cbType);
    return ReductionProcessor::getReductionInitValue(
        cbLoc, unwrapSeqOrBoxedType(unwrappedTy), redId, cbBuilder);
  };

  // Combiner region: defer to the generic combiner for this identifier.
  auto emitCombiner = [&redId](fir::FirOpBuilder &cbBuilder,
                               mlir::Location cbLoc, mlir::Type cbType,
                               mlir::Value lhs, mlir::Value rhs,
                               bool cbIsByRef) {
    genCombiner<OpType>(cbBuilder, cbLoc, redId, cbType, lhs, rhs, cbIsByRef);
  };

  return createDeclareReductionHelper<OpType>(converter, reductionOpName, type,
                                              loc, isByRef, emitCombiner,
                                              emitInitValue);
}
/// Decides from the type alone whether a reduction is performed by reference.
/// By-reference is used when it is forced through `forceByrefReduction`, or
/// when the type (after stripping a reference wrapper) is neither trivial nor
/// a derived type.
bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) {
  if (forceByrefReduction)
    return true;

  mlir::Type valueTy = fir::unwrapRefType(reductionType);
  const bool passableByValue =
      fir::isa_trivial(valueTy) || fir::isa_derived(valueTy);
  return !passableByValue;
}
bool ReductionProcessor::doReductionByRef(mlir::Value reductionVar) {
if (forceByrefReduction)
return true;
@@ -600,10 +630,7 @@ static bool doReductionByRef(mlir::Value reductionVar) {
mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
reductionVar = declare.getMemref();
if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
return true;
return false;
return doReductionByRef(reductionVar.getType());
}
template <typename OpType, typename RedOperatorListTy>
@@ -614,6 +641,8 @@ bool ReductionProcessor::processReductionArguments(
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
if constexpr (std::is_same_v<RedOperatorListTy,
omp::clause::ReductionOperatorList>) {
// For OpenMP reduction clauses, check if the reduction operator is
@@ -627,7 +656,13 @@ bool ReductionProcessor::processReductionArguments(
std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic)) {
return false;
// If not an intrinsic is has to be a custom reduction op, and should
// be available in the module.
semantics::Symbol *sym = reductionIntrinsic->v.sym();
mlir::ModuleOp module = builder.getModule();
auto decl = module.lookupSymbol<OpType>(getRealName(sym).ToString());
if (!decl)
return false;
}
} else {
return false;
@@ -637,7 +672,6 @@ bool ReductionProcessor::processReductionArguments(
// Reduction variable processing common to both intrinsic operators and
// procedure designators
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::OpBuilder::InsertPoint dcIP;
constexpr bool isDoConcurrent =
std::is_same_v<OpType, fir::DeclareReductionOp>;
@@ -741,7 +775,13 @@ bool ReductionProcessor::processReductionArguments(
&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic)) {
TODO(currentLocation, "Unsupported intrinsic proc reduction");
// Custom reductions we can just add to the symbols without
// generating the declare reduction op.
semantics::Symbol *sym = reductionIntrinsic->v.sym();
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
builder.getContext(), sym->name().ToString()));
++idx;
continue;
}
redId = getReductionType(*reductionIntrinsic);
reductionName =

View File

@@ -21,6 +21,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
namespace flangomp {
#define GEN_PASS_DEF_MARKDECLARETARGETPASS
@@ -31,9 +32,93 @@ namespace {
class MarkDeclareTargetPass
: public flangomp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
mlir::omp::DeclareTargetCaptureClause parentCapClause,
bool parentAutomap, mlir::Operation *currOp,
// Declare-target attributes of the declare-target function or target region
// that the traversal started from. They are applied to functions reached from
// it that are not declare target yet.
struct ParentInfo {
// Device type handed down from the parent.
mlir::omp::DeclareTargetDeviceType devTy;
// Capture clause handed down from the parent.
mlir::omp::DeclareTargetCaptureClause capClause;
// Automap setting handed down from the parent.
bool automap;
};
// Marks the function referenced by `symRef` (if the symbol resolves to a
// func.func in this module) as declare target and then continues the
// traversal inside it. A function that is already declare target with a
// device type different from the parent's is widened to `any`; one that is
// not yet declare target inherits the parent's attributes.
void processSymbolRef(mlir::SymbolRefAttr symRef, ParentInfo parentInfo,
                      llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
  auto callee = getOperation().lookupSymbol<mlir::func::FuncOp>(symRef);
  if (!callee)
    return;

  auto declTarget = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
      callee.getOperation());
  if (!declTarget.isDeclareTarget()) {
    declTarget.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
                                parentInfo.automap);
  } else {
    auto existingDevTy = declTarget.getDeclareTargetDeviceType();
    // Found the same function twice, with different device_types,
    // mark as Any as it belongs to both
    const bool differsFromParent = existingDevTy != parentInfo.devTy;
    const bool alreadyAny =
        existingDevTy == mlir::omp::DeclareTargetDeviceType::any;
    if (differsFromParent && !alreadyAny) {
      declTarget.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
                                  declTarget.getDeclareTargetCaptureClause(),
                                  declTarget.getDeclareTargetAutomap());
    }
  }

  markNestedFuncs(parentInfo, callee, visited);
}
// Continues the traversal inside every omp.declare_reduction referenced by
// `symRefs`. Symbols that do not resolve to a declare reduction in this
// module are ignored, as is an absent attribute.
void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
                          ParentInfo parentInfo,
                          llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
  if (!symRefs.has_value())
    return;

  for (mlir::SymbolRefAttr reductionRef :
       symRefs->getAsRange<mlir::SymbolRefAttr>()) {
    auto reductionDecl =
        getOperation().lookupSymbol<mlir::omp::DeclareReductionOp>(
            reductionRef);
    if (!reductionDecl)
      continue;
    markNestedFuncs(parentInfo, reductionDecl, visited);
  }
}
// Dispatches on the kind of OpenMP operation and forwards the symbols of its
// reduction-like clauses (reduction, in_reduction, task_reduction) to
// processReductionRefs. Operations not listed here are ignored.
void
processReductionClauses(mlir::Operation *op, ParentInfo parentInfo,
                        llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
  llvm::TypeSwitch<mlir::Operation &>(*op)
      // Operations whose only relevant clause is `reduction`.
      .Case<mlir::omp::LoopOp, mlir::omp::ParallelOp, mlir::omp::SectionsOp,
            mlir::omp::SimdOp, mlir::omp::TeamsOp, mlir::omp::WsloopOp>(
          [&](auto redOp) {
            processReductionRefs(redOp.getReductionSyms(), parentInfo,
                                 visited);
          })
      // Operations for which only `in_reduction` is followed.
      .Case<mlir::omp::TargetOp, mlir::omp::TaskOp>([&](auto inRedOp) {
        processReductionRefs(inRedOp.getInReductionSyms(), parentInfo,
                             visited);
      })
      .Case([&](mlir::omp::TaskgroupOp taskgroupOp) {
        processReductionRefs(taskgroupOp.getTaskReductionSyms(), parentInfo,
                             visited);
      })
      // Taskloop carries both `reduction` and `in_reduction`.
      .Case([&](mlir::omp::TaskloopOp taskloopOp) {
        processReductionRefs(taskloopOp.getReductionSyms(), parentInfo,
                             visited);
        processReductionRefs(taskloopOp.getInReductionSyms(), parentInfo,
                             visited);
      })
      .Default([](mlir::Operation &) {});
}
void markNestedFuncs(ParentInfo parentInfo, mlir::Operation *currOp,
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
if (visited.contains(currOp))
return;
@@ -43,33 +128,10 @@ class MarkDeclareTargetPass
if (auto callOp = llvm::dyn_cast<mlir::CallOpInterface>(op)) {
if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
callOp.getCallableForCallee())) {
if (auto currFOp =
getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) {
auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
currFOp.getOperation());
if (current.isDeclareTarget()) {
auto currentDt = current.getDeclareTargetDeviceType();
// Found the same function twice, with different device_types,
// mark as Any as it belongs to both
if (currentDt != parentDevTy &&
currentDt != mlir::omp::DeclareTargetDeviceType::any) {
current.setDeclareTarget(
mlir::omp::DeclareTargetDeviceType::any,
current.getDeclareTargetCaptureClause(),
current.getDeclareTargetAutomap());
}
} else {
current.setDeclareTarget(parentDevTy, parentCapClause,
parentAutomap);
}
markNestedFuncs(parentDevTy, parentCapClause, parentAutomap,
currFOp, visited);
}
processSymbolRef(symRef, parentInfo, visited);
}
}
processReductionClauses(op, parentInfo, visited);
});
}
@@ -82,10 +144,10 @@ class MarkDeclareTargetPass
functionOp.getOperation());
if (declareTargetOp.isDeclareTarget()) {
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(),
declareTargetOp.getDeclareTargetCaptureClause(),
declareTargetOp.getDeclareTargetAutomap(), functionOp,
visited);
ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
declareTargetOp.getDeclareTargetCaptureClause(),
declareTargetOp.getDeclareTargetAutomap()};
markNestedFuncs(parentInfo, functionOp, visited);
}
}
@@ -96,12 +158,13 @@ class MarkDeclareTargetPass
// the contents of the device clause
getOperation()->walk([&](mlir::omp::TargetOp tarOp) {
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
markNestedFuncs(
/*parentDevTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
/*parentCapClause=*/mlir::omp::DeclareTargetCaptureClause::to,
/*parentAutomap=*/false, tarOp, visited);
ParentInfo parentInfo = {
/*devTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
/*capClause=*/mlir::omp::DeclareTargetCaptureClause::to,
/*automap=*/false,
};
markNestedFuncs(parentInfo, tarOp, visited);
});
}
};
} // namespace

View File

@@ -0,0 +1,19 @@
! This test checks lowering of OpenMP declare reduction with non-trivial types
! RUN: not %flang_fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
module mymod
type advancedtype
integer(4)::myarray(10)
integer(4)::val
integer(4)::otherval
end type advancedtype
!CHECK: not yet implemented: declare reduction currently only supports trival types or derived types containing trivial types
!$omp declare reduction(myreduction: advancedtype: omp_out = omp_in) initializer(omp_priv = omp_orig)
end module mymod
program mymaxtest
use mymod
end program

View File

@@ -1,28 +0,0 @@
! This test checks lowering of OpenMP declare reduction Directive, with initialization
! via a subroutine. This functionality is currently not implemented.
! RUN: not %flang_fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
!CHECK: not yet implemented: OpenMPDeclareReductionConstruct
subroutine initme(x,n)
integer x,n
x=n
end subroutine initme
function func(x, n, init)
integer func
integer x(n)
integer res
interface
subroutine initme(x,n)
integer x,n
end subroutine initme
end interface
!$omp declare reduction(red_add:integer(4):omp_out=omp_out+omp_in) initializer(initme(omp_priv,0))
res=init
!$omp simd reduction(red_add:res)
do i=1,n
res=res+x(i)
enddo
func=res
end function func

View File

@@ -1,10 +0,0 @@
! This test checks lowering of OpenMP declare reduction Directive.
! RUN: not %flang_fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
subroutine declare_red()
integer :: my_var
!CHECK: not yet implemented: OpenMPDeclareReductionConstruct
!$omp declare reduction (my_red : integer : omp_out = omp_in) initializer (omp_priv = 0)
my_var = 0
end subroutine declare_red

View File

@@ -0,0 +1,37 @@
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -fopenmp-is-device %s -o - | FileCheck %s
program main
use, intrinsic :: iso_c_binding
implicit none
interface
subroutine myinit(priv, orig) bind(c,name="myinit")
use, intrinsic :: iso_c_binding
implicit none
integer::priv, orig
end subroutine myinit
function mycombine(lhs, rhs) bind(c,name="mycombine")
use, intrinsic :: iso_c_binding
implicit none
integer::lhs, rhs, mycombine
end function mycombine
end interface
!$omp declare reduction(myreduction:integer:omp_out = mycombine(omp_out, omp_in)) initializer(myinit(omp_priv, omp_orig))
integer :: i, s, a(10)
!$omp target
s = 0
!$omp do reduction(myreduction:s)
do i = 1, 10
s = mycombine(s, a(i))
enddo
!$omp end do
!$omp end target
end program main
!CHECK: func.func {{.*}} @myinit(!fir.ref<i32>, !fir.ref<i32>)
!CHECK-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}
!CHECK-LABEL: func.func {{.*}} @mycombine(!fir.ref<i32>, !fir.ref<i32>)
!CHECK-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}

View File

@@ -0,0 +1,112 @@
! This test checks lowering of the OpenMP declare reduction directive for a
! derived type, with initialization via a subroutine call.
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
! Module with a derived type, the initializer and combiner procedures for a
! user-defined reduction on that type, and a function that uses the reduction.
module maxtype_mod
implicit none
! Two integer components; mycombine adds sumval and takes the max of maxval.
type maxtype
integer::sumval
integer::maxval
end type maxtype
contains
! Reduction initializer: zeroes both components of x. n is not referenced.
subroutine initme(x,n)
type(maxtype) :: x,n
x%sumval=0
x%maxval=0
end subroutine initme
! Reduction combiner: returns lhs and rhs merged component-wise
! (sum of sumval, maximum of maxval).
function mycombine(lhs, rhs)
type(maxtype) :: lhs, rhs
type(maxtype) :: mycombine
mycombine%sumval = lhs%sumval + rhs%sumval
mycombine%maxval = max(lhs%maxval, rhs%maxval)
end function mycombine
! Folds x(1:n) into res, starting from init, using the red_add_max reduction
! on a simd loop, and returns res.
function func(x, n, init)
type(maxtype) :: func
integer :: n, i
type(maxtype) :: x(n)
type(maxtype) :: init
type(maxtype) :: res
! red_add_max: the combiner assigns mycombine(omp_out, omp_in) to omp_out and
! the initializer calls initme(omp_priv, omp_orig).
!$omp declare reduction(red_add_max:maxtype:omp_out=mycombine(omp_out,omp_in)) initializer(initme(omp_priv,omp_orig))
res=init
!$omp simd reduction(red_add_max:res)
do i=1,n
res=mycombine(res,x(i))
enddo
func=res
end function func
end module maxtype_mod
!CHECK: omp.declare_reduction @red_add_max : [[MAXTYPE:.*]] init {
!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: [[MAXTYPE]]):
!CHECK: %[[OMP_PRIV:.*]] = fir.alloca [[MAXTYPE]]
!CHECK: %[[OMP_ORIG:.*]] = fir.alloca [[MAXTYPE]]
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<[[MAXTYPE]]>
!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<[[MAXTYPE]]>
!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : [[MAXTYPE]])
!CHECK: } combiner {
!CHECK: ^bb0(%[[LHS_ARG:.*]]: [[MAXTYPE]], %[[RHS_ARG:.*]]: [[MAXTYPE]]):
!CHECK: %[[RESULT:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = ".result"}
!CHECK: %[[OMP_OUT:.*]] = fir.alloca [[MAXTYPE]]
!CHECK: %[[OMP_IN:.*]] = fir.alloca [[MAXTYPE]]
!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<[[MAXTYPE]]>
!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<[[MAXTYPE]]>
!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]]
!CHECK: fir.save_result %[[COMBINE_RESULT]] to %[[RESULT]] : [[MAXTYPE]], !fir.ref<[[MAXTYPE]]>
!CHECK: %[[TMPRESULT:.*]]:2 = hlfir.declare %[[RESULT]] {uniq_name = ".tmp.func_result"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: %false = arith.constant false
!CHECK: %[[EXPRRESULT:.*]] = hlfir.as_expr %[[TMPRESULT]]#0 move %false : (!fir.ref<[[MAXTYPE]]>, i1) -> !hlfir.expr<[[MAXTYPE]]>
!CHECK: %[[ASSOCIATE:.*]]:3 = hlfir.associate %[[EXPRRESULT]] {adapt.valuebyref} : (!hlfir.expr<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>, i1)
!CHECK: %[[RESULT_VAL:.*]] = fir.load %[[ASSOCIATE]]#0 : !fir.ref<[[MAXTYPE]]>
!CHECK: hlfir.end_associate %[[ASSOCIATE]]#1, %[[ASSOCIATE]]#2 : !fir.ref<[[MAXTYPE]]>, i1
!CHECK: omp.yield(%[[RESULT_VAL]] : [[MAXTYPE]])
!CHECK: }
!CHECK: func.func @_QMmaxtype_modPinitme(%[[X_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "x"}, %[[N_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "n"}) {
!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
!CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %[[N_ARG]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QMmaxtype_modFinitmeEn"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_ARG]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QMmaxtype_modFinitmeEx"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: %[[ZERO_0:.*]] = arith.constant 0 : i32
!CHECK: %[[X_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[X_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
!CHECK: hlfir.assign %[[ZERO_0]] to %[[X_DESIGNATE_SUMVAL]] : i32, !fir.ref<i32>
!CHECK: %[[ZERO_1:.*]] = arith.constant 0 : i32
!CHECK: %[[X_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[X_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
!CHECK: hlfir.assign %[[ZERO_1]] to %[[X_DESIGNATE_MAXVAL]] : i32, !fir.ref<i32>
!CHECK: return
!CHECK: }
!CHECK: func.func @_QMmaxtype_modPmycombine(%[[LHS:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "lhs"}, %[[RHS:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "rhs"}) -> [[MAXTYPE]] {
!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
!CHECK: %[[LHS_DECL:.*]]:2 = hlfir.declare %[[LHS]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QMmaxtype_modFmycombineElhs"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: %[[RESULT_ALLOC:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = "mycombine", uniq_name = "_QMmaxtype_modFmycombineEmycombine"}
!CHECK: %[[RESULT_DECL:.*]]:2 = hlfir.declare %[[RESULT_ALLOC]] {uniq_name = "_QMmaxtype_modFmycombineEmycombine"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: %[[RHS_DECL:.*]]:2 = hlfir.declare %[[RHS]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QMmaxtype_modFmycombineErhs"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
!CHECK: %[[LHS_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[LHS_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
!CHECK: %[[LHS_SUMVAL:.*]] = fir.load %[[LHS_DESIGNATE_SUMVAL]] : !fir.ref<i32>
!CHECK: %[[RHS_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[RHS_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
!CHECK: %[[RHS_SUMVAL:.*]] = fir.load %[[RHS_DESIGNATE_SUMVAL]] : !fir.ref<i32>
!CHECK: %[[SUM:.*]] = arith.addi %[[LHS_SUMVAL]], %[[RHS_SUMVAL]] : i32
!CHECK: %[[RESULT_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[RESULT_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
!CHECK: hlfir.assign %[[SUM]] to %[[RESULT_DESIGNATE_SUMVAL]] : i32, !fir.ref<i32>
!CHECK: %[[LHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[LHS_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
!CHECK: %[[LHS_MAXVAL:.*]] = fir.load %[[LHS_DESIGNATE_MAXVAL]] : !fir.ref<i32>
!CHECK: %[[RHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[RHS_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
!CHECK: %[[RHS_MAXVAL:.*]] = fir.load %[[RHS_DESIGNATE_MAXVAL]] : !fir.ref<i32>
!CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS_MAXVAL]], %[[RHS_MAXVAL]] : i32
!CHECK: %[[MAX_VAL:.*]] = arith.select %[[CMP]], %[[LHS_MAXVAL]], %[[RHS_MAXVAL]] : i32
!CHECK: %[[RESULT_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[RESULT_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
!CHECK: hlfir.assign %[[MAX_VAL]] to %[[RESULT_DESIGNATE_MAXVAL]] : i32, !fir.ref<i32>
!CHECK: %[[RESULT:.*]] = fir.load %[[RESULT_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
!CHECK: return %[[RESULT]] : [[MAXTYPE]]
!CHECK: }

View File

@@ -0,0 +1,59 @@
! This test checks lowering of the OpenMP declare reduction directive, with
! initialization via a subroutine.
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
! Reduction initializer used by the red_add reduction: sets x to zero.
! n is not referenced.
subroutine initme(x,n)
integer x,n
x=0
end subroutine initme
! Sums x(1:n) into res, starting from init, using the user-defined red_add
! reduction on a simd loop, and returns the result. init and i are not
! declared and rely on implicit typing.
function func(x, n, init)
integer func
integer x(n)
integer res
interface
subroutine initme(x,n)
integer x,n
end subroutine initme
end interface
! Expected lowering: the omp.declare_reduction op for red_add (init region
! calling initme, combiner region adding omp_out and omp_in), followed by the
! body of initme itself.
!CHECK: omp.declare_reduction @red_add : i32 init {
!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: i32):
!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32
!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<i32>
!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<i32>
!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: fir.call @_QPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> ()
!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<i32>
!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : i32)
!CHECK: } combiner {
!CHECK: ^bb0(%[[LHS_ARG:.*]]: i32, %[[RHS_ARG:.*]]: i32):
!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32
!CHECK: %[[OMP_IN:.*]] = fir.alloca i32
!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<i32>
!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<i32>
!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32
!CHECK: omp.yield(%[[SUM]] : i32)
!CHECK: }
!CHECK: func.func @_QPinitme(%[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
!CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %[[N]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QFinitmeEn"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QFinitmeEx"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[CONST_0]] to %[[X_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: return
!CHECK: }
! red_add: the combiner is omp_out + omp_in and the initializer calls
! initme(omp_priv, omp_orig).
!$omp declare reduction(red_add:integer(4):omp_out=omp_out+omp_in) initializer(initme(omp_priv,omp_orig))
res=init
!$omp simd reduction(red_add:res)
do i=1,n
res=res+x(i)
enddo
func=res
end function func

View File

@@ -0,0 +1,33 @@
! This test checks lowering of OpenMP declare reduction Directive.
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
! Minimal subroutine carrying a declare reduction directive with an assignment
! initializer. The reduction is declared but not used in any reduction clause.
subroutine declare_red()
integer :: my_var
! Expected lowering of my_red: the init region yields the constant 0 and the
! combiner region yields the sum of omp_out and omp_in.
!CHECK: omp.declare_reduction @my_red : i32 init {
!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: i32):
!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32
!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<i32>
!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<i32>
!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32
!CHECK: omp.yield(%[[CONST_0]] : i32)
!CHECK: } combiner {
!CHECK: ^bb0(%[[LHS_ARG:.*]]: i32, %[[RHS_ARG:.*]]: i32):
!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32
!CHECK: %[[OMP_IN:.*]] = fir.alloca i32
!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<i32>
!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<i32>
!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32
!CHECK: omp.yield(%[[SUM]] : i32)
!CHECK: }
!$omp declare reduction (my_red : integer : omp_out = omp_out + omp_in) initializer (omp_priv = 0)
my_var = 0
end subroutine declare_red

View File

@@ -1170,6 +1170,7 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
template <typename T>
static void
mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
llvm::IRBuilderBase &builder,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
DenseMap<Value, llvm::Value *> &reductionVariableMap,
unsigned i) {
@@ -1180,8 +1181,17 @@ mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
mlir::Value mlirSource = loop.getReductionVars()[i];
llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
assert(llvmSource && "lookup reduction var");
moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
llvm::Value *origVal = llvmSource;
// If a non-pointer value is expected, load the value from the source pointer.
if (!isa<LLVM::LLVMPointerType>(
reduction.getInitializerMoldArg().getType()) &&
isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
origVal =
builder.CreateLoad(moduleTranslation.convertType(
reduction.getInitializerMoldArg().getType()),
llvmSource, "omp_orig");
}
moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
if (entry.getNumArguments() > 1) {
llvm::Value *allocation =
@@ -1254,7 +1264,7 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
SmallVector<llvm::Value *, 1> phis;
// map block argument to initializer region
mapInitializationArgs(op, moduleTranslation, reductionDecls,
mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
reductionVariableMap, i);
// TODO In some cases (specially on the GPU), the init regions may

View File

@@ -0,0 +1,88 @@
! Basic offloading test with custom OpenMP reduction on derived type
! REQUIRES: flang, amdgpu
!
! RUN: %libomptarget-compile-fortran-generic
! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic
! Module with the derived type and the initializer/combiner procedures used by
! the red_add_max reduction declared in the main program.
module maxtype_mod
implicit none
! Two integer components; mycombine adds sumval and takes the max of maxval.
type maxtype
integer::sumval
integer::maxval
end type maxtype
contains
! Reduction initializer: zeroes both components of x. n is not referenced.
subroutine initme(x,n)
type(maxtype) :: x,n
x%sumval=0
x%maxval=0
end subroutine initme
! Reduction combiner: returns lhs and rhs merged component-wise
! (sum of sumval, maximum of maxval).
function mycombine(lhs, rhs)
type(maxtype) :: lhs, rhs
type(maxtype) :: mycombine
mycombine%sumval = lhs%sumval + rhs%sumval
mycombine%maxval = max(lhs%maxval, rhs%maxval)
end function mycombine
end module maxtype_mod
! Runs the red_add_max reduction over x(1:n) in a target parallel do, then
! compares the result with values computed by a plain loop outside the target
! region and prints PASSED or FAILED.
program main
use maxtype_mod
implicit none
integer :: n = 100
integer :: i
integer :: error = 0
type(maxtype) :: x(100)
type(maxtype) :: res
integer :: expected_sum, expected_max
! red_add_max: the combiner assigns mycombine(omp_out, omp_in) to omp_out and
! the initializer calls initme(omp_priv, omp_orig).
!$omp declare reduction(red_add_max:maxtype:omp_out=mycombine(omp_out,omp_in)) initializer(initme(omp_priv,omp_orig))
! Initialize array with test data
do i = 1, n
x(i)%sumval = i
x(i)%maxval = i
end do
! Initialize reduction variable
res%sumval = 0
res%maxval = 0
! Perform reduction in target region
!$omp target parallel do map(to:x) reduction(red_add_max:res)
do i = 1, n
res = mycombine(res, x(i))
end do
!$omp end target parallel do
! Compute expected values
expected_sum = 0
expected_max = 0
do i = 1, n
expected_sum = expected_sum + i
expected_max = max(expected_max, i)
end do
! Check results
if (res%sumval /= expected_sum) then
error = 1
endif
if (res%maxval /= expected_max) then
error = 1
endif
! Report the outcome; the directives after the program match on this output.
if (error == 0) then
print *,"PASSED"
else
print *,"FAILED"
endif
end program main
! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}}
! CHECK: PASSED