mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[OpenMP][flang] Lowering of OpenMP custom reductions to MLIR (#168417)
This patch add support for lowering of custom reductions to MLIR. It also enhances the capability of the pass to automatically mark functions as "declare target" by traversing custom reduction initializers and combiners.
This commit is contained in:
@@ -40,6 +40,13 @@ namespace omp {
|
||||
|
||||
class ReductionProcessor {
|
||||
public:
|
||||
using GenInitValueCBTy =
|
||||
std::function<mlir::Value(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
mlir::Type type, mlir::Value ompOrig)>;
|
||||
using GenCombinerCBTy = std::function<void(
|
||||
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
|
||||
mlir::Value op1, mlir::Value op2, bool isByRef)>;
|
||||
|
||||
// TODO: Move this enumeration to the OpenMP dialect
|
||||
enum ReductionIdentifier {
|
||||
ID,
|
||||
@@ -58,6 +65,9 @@ public:
|
||||
IEOR
|
||||
};
|
||||
|
||||
static bool doReductionByRef(mlir::Type reductionType);
|
||||
static bool doReductionByRef(mlir::Value reductionVar);
|
||||
|
||||
static ReductionIdentifier
|
||||
getReductionType(const omp::clause::ProcedureDesignator &pd);
|
||||
|
||||
@@ -109,6 +119,14 @@ public:
|
||||
ReductionIdentifier redId,
|
||||
mlir::Type type, mlir::Value op1,
|
||||
mlir::Value op2);
|
||||
/// Creates an OpenMP reduction declaration and inserts it into the provided
|
||||
/// symbol table. The init and combiner regions are generated by the callback
|
||||
/// functions genCombinerCB and genInitValueCB.
|
||||
template <typename DeclareRedType>
|
||||
static DeclareRedType createDeclareReductionHelper(
|
||||
AbstractConverter &converter, llvm::StringRef reductionOpName,
|
||||
mlir::Type type, mlir::Location loc, bool isByRef,
|
||||
GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB);
|
||||
|
||||
/// Creates an OpenMP reduction declaration and inserts it into the provided
|
||||
/// symbol table. The declaration has a constant initializer with the neutral
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ClauseProcessor.h"
|
||||
#include "Utils.h"
|
||||
|
||||
#include "flang/Lower/ConvertCall.h"
|
||||
#include "flang/Lower/ConvertExprToHLFIR.h"
|
||||
#include "flang/Lower/OpenMP/Clauses.h"
|
||||
#include "flang/Lower/PFTBuilder.h"
|
||||
@@ -402,6 +403,65 @@ bool ClauseProcessor::processInclusive(
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ClauseProcessor::processInitializer(
|
||||
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
|
||||
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
|
||||
if (auto *clause = findUniqueClause<omp::clause::Initializer>()) {
|
||||
genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
mlir::Type type, mlir::Value ompOrig) {
|
||||
lower::SymMapScope scope(symMap);
|
||||
const parser::OmpInitializerExpression &iexpr = inp.v.v;
|
||||
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
|
||||
const std::list<parser::OmpStylizedDeclaration> &declList =
|
||||
std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t);
|
||||
mlir::Value ompPrivVar;
|
||||
for (const parser::OmpStylizedDeclaration &decl : declList) {
|
||||
auto &name = std::get<parser::ObjectName>(decl.var.t);
|
||||
assert(name.symbol && "Name does not have a symbol");
|
||||
mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
|
||||
fir::StoreOp::create(builder, loc, ompOrig, addr);
|
||||
fir::FortranVariableFlagsEnum extraFlags = {};
|
||||
fir::FortranVariableFlagsAttr attributes =
|
||||
Fortran::lower::translateSymbolAttributes(builder.getContext(),
|
||||
*name.symbol, extraFlags);
|
||||
auto declareOp = hlfir::DeclareOp::create(
|
||||
builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
|
||||
0, attributes);
|
||||
if (name.ToString() == "omp_priv")
|
||||
ompPrivVar = declareOp.getResult(0);
|
||||
symMap.addVariableDefinition(*name.symbol, declareOp);
|
||||
}
|
||||
// Lower the expression/function call
|
||||
lower::StatementContext stmtCtx;
|
||||
mlir::Value result = common::visit(
|
||||
common::visitors{
|
||||
[&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
|
||||
convertCallToHLFIR(loc, converter, procRef, std::nullopt,
|
||||
symMap, stmtCtx);
|
||||
auto privVal = fir::LoadOp::create(builder, loc, ompPrivVar);
|
||||
return privVal;
|
||||
},
|
||||
[&](const auto &expr) -> mlir::Value {
|
||||
mlir::Value exprResult = fir::getBase(convertExprToValue(
|
||||
loc, converter, clause->v, symMap, stmtCtx));
|
||||
// Conversion can either give a value or a refrence to a value,
|
||||
// we need to return the reduction type, so an optional load may
|
||||
// be generated.
|
||||
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
|
||||
exprResult.getType()))
|
||||
if (ompPrivVar.getType() == refType)
|
||||
exprResult = fir::LoadOp::create(builder, loc, exprResult);
|
||||
return exprResult;
|
||||
}},
|
||||
clause->v.u);
|
||||
stmtCtx.finalizeAndPop();
|
||||
return result;
|
||||
};
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ClauseProcessor::processMergeable(
|
||||
mlir::omp::MergeableClauseOps &result) const {
|
||||
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "flang/Lower/Bridge.h"
|
||||
#include "flang/Lower/DirectivesCommon.h"
|
||||
#include "flang/Lower/OpenMP/Clauses.h"
|
||||
#include "flang/Lower/Support/ReductionProcessor.h"
|
||||
#include "flang/Optimizer/Builder/Todo.h"
|
||||
#include "flang/Parser/dump-parse-tree.h"
|
||||
#include "flang/Parser/parse-tree.h"
|
||||
@@ -88,6 +89,9 @@ public:
|
||||
bool processHint(mlir::omp::HintClauseOps &result) const;
|
||||
bool processInclusive(mlir::Location currentLocation,
|
||||
mlir::omp::InclusiveClauseOps &result) const;
|
||||
bool processInitializer(
|
||||
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
|
||||
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const;
|
||||
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
|
||||
bool processNogroup(mlir::omp::NogroupClauseOps &result) const;
|
||||
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
|
||||
|
||||
@@ -981,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp,
|
||||
|
||||
Initializer make(const parser::OmpClause::Initializer &inp,
|
||||
semantics::SemanticsContext &semaCtx) {
|
||||
llvm_unreachable("Empty: initializer");
|
||||
const parser::OmpInitializerExpression &iexpr = inp.v.v;
|
||||
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
|
||||
const parser::OmpStylizedInstance::Instance &instance =
|
||||
std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t);
|
||||
if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
|
||||
auto &expr = std::get<parser::Expr>(as->t);
|
||||
return Initializer{makeExpr(expr, semaCtx)};
|
||||
} else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
|
||||
if (call->typedCall) {
|
||||
const auto &procRef = *call->typedCall;
|
||||
semantics::SomeExpr evalProcRef{procRef};
|
||||
return Initializer{evalProcRef};
|
||||
}
|
||||
}
|
||||
|
||||
llvm_unreachable("Unexpected initializer");
|
||||
}
|
||||
|
||||
InReduction make(const parser::OmpClause::InReduction &inp,
|
||||
|
||||
@@ -18,12 +18,15 @@
|
||||
#include "Decomposer.h"
|
||||
#include "Utils.h"
|
||||
#include "flang/Common/idioms.h"
|
||||
#include "flang/Evaluate/type.h"
|
||||
#include "flang/Lower/Bridge.h"
|
||||
#include "flang/Lower/ConvertExpr.h"
|
||||
#include "flang/Lower/ConvertExprToHLFIR.h"
|
||||
#include "flang/Lower/ConvertVariable.h"
|
||||
#include "flang/Lower/DirectivesCommon.h"
|
||||
#include "flang/Lower/OpenMP/Clauses.h"
|
||||
#include "flang/Lower/StatementContext.h"
|
||||
#include "flang/Lower/Support/ReductionProcessor.h"
|
||||
#include "flang/Lower/SymbolMap.h"
|
||||
#include "flang/Optimizer/Builder/BoxValue.h"
|
||||
#include "flang/Optimizer/Builder/FIRBuilder.h"
|
||||
@@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
|
||||
// TODO: Add private syms and vars.
|
||||
args.reduction.syms = reductionSyms;
|
||||
args.reduction.vars = clauseOps.reductionVars;
|
||||
|
||||
return genOpWithBody<mlir::omp::TeamsOp>(
|
||||
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
|
||||
llvm::omp::Directive::OMPD_teams)
|
||||
@@ -3570,12 +3572,156 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
|
||||
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
|
||||
}
|
||||
|
||||
static ReductionProcessor::GenCombinerCBTy
|
||||
processReductionCombiner(lower::AbstractConverter &converter,
|
||||
lower::SymMap &symTable,
|
||||
semantics::SemanticsContext &semaCtx,
|
||||
const parser::OmpReductionSpecifier &specifier) {
|
||||
ReductionProcessor::GenCombinerCBTy genCombinerCB;
|
||||
const auto &combinerExpression =
|
||||
std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
|
||||
.value();
|
||||
const parser::OmpStylizedInstance &combinerInstance =
|
||||
combinerExpression.v.front();
|
||||
const parser::OmpStylizedInstance::Instance &instance =
|
||||
std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
|
||||
|
||||
const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u);
|
||||
if (!as) {
|
||||
TODO(converter.getCurrentLocation(),
|
||||
"A combiner that is a subroutine call is not yet supported");
|
||||
}
|
||||
auto &expr = std::get<parser::Expr>(as->t);
|
||||
genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
mlir::Type type, mlir::Value lhs, mlir::Value rhs,
|
||||
bool isByRef) {
|
||||
const auto &evalExpr = makeExpr(expr, semaCtx);
|
||||
lower::SymMapScope scope(symTable);
|
||||
const std::list<parser::OmpStylizedDeclaration> &declList =
|
||||
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
|
||||
for (const parser::OmpStylizedDeclaration &decl : declList) {
|
||||
auto &name = std::get<parser::ObjectName>(decl.var.t);
|
||||
mlir::Value addr = lhs;
|
||||
mlir::Type type = lhs.getType();
|
||||
bool isRhs = name.ToString() == std::string("omp_in");
|
||||
if (isRhs) {
|
||||
addr = rhs;
|
||||
type = rhs.getType();
|
||||
}
|
||||
|
||||
assert(name.symbol && "Reduction object name does not have a symbol");
|
||||
if (!fir::conformsWithPassByRef(type)) {
|
||||
addr = builder.createTemporary(loc, type);
|
||||
fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
|
||||
}
|
||||
fir::FortranVariableFlagsEnum extraFlags = {};
|
||||
fir::FortranVariableFlagsAttr attributes =
|
||||
Fortran::lower::translateSymbolAttributes(builder.getContext(),
|
||||
*name.symbol, extraFlags);
|
||||
auto declareOp =
|
||||
hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr,
|
||||
{}, nullptr, nullptr, 0, attributes);
|
||||
symTable.addVariableDefinition(*name.symbol, declareOp);
|
||||
}
|
||||
|
||||
lower::StatementContext stmtCtx;
|
||||
mlir::Value result = fir::getBase(
|
||||
convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
|
||||
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
|
||||
if (lhs.getType() == refType.getElementType())
|
||||
result = fir::LoadOp::create(builder, loc, result);
|
||||
stmtCtx.finalizeAndPop();
|
||||
if (isByRef) {
|
||||
fir::StoreOp::create(builder, loc, result, lhs);
|
||||
mlir::omp::YieldOp::create(builder, loc, lhs);
|
||||
} else {
|
||||
mlir::omp::YieldOp::create(builder, loc, result);
|
||||
}
|
||||
};
|
||||
return genCombinerCB;
|
||||
}
|
||||
|
||||
// Checks that the reduction type is either a trivial type or a derived type of
|
||||
// trivial types.
|
||||
static bool isSimpleReductionType(mlir::Type reductionType) {
|
||||
if (fir::isa_trivial(reductionType))
|
||||
return true;
|
||||
if (auto recordTy = mlir::dyn_cast<fir::RecordType>(reductionType)) {
|
||||
for (auto [_, fieldType] : recordTy.getTypeList()) {
|
||||
if (!fir::isa_trivial(fieldType))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Getting the type from a symbol compared to a DeclSpec is simpler since we do
|
||||
// not need to consider derived vs intrinsic types. Semantics is guaranteed to
|
||||
// generate these symbols.
|
||||
static mlir::Type
|
||||
getReductionType(lower::AbstractConverter &converter,
|
||||
const parser::OmpReductionSpecifier &specifier) {
|
||||
const auto &combinerExpression =
|
||||
std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
|
||||
.value();
|
||||
const parser::OmpStylizedInstance &combinerInstance =
|
||||
combinerExpression.v.front();
|
||||
const std::list<parser::OmpStylizedDeclaration> &declList =
|
||||
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
|
||||
const parser::OmpStylizedDeclaration &decl = declList.front();
|
||||
const auto &name = std::get<parser::ObjectName>(decl.var.t);
|
||||
const auto &symbol = semantics::SymbolRef(*name.symbol);
|
||||
mlir::Type reductionType = converter.genType(symbol);
|
||||
|
||||
if (!isSimpleReductionType(reductionType))
|
||||
TODO(converter.getCurrentLocation(),
|
||||
"declare reduction currently only supports trival types or derived "
|
||||
"types containing trivial types");
|
||||
return reductionType;
|
||||
}
|
||||
|
||||
static void genOMP(
|
||||
lower::AbstractConverter &converter, lower::SymMap &symTable,
|
||||
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
|
||||
const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
|
||||
if (!semaCtx.langOptions().OpenMPSimd)
|
||||
TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
|
||||
if (semaCtx.langOptions().OpenMPSimd)
|
||||
return;
|
||||
|
||||
const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()};
|
||||
const parser::OmpArgument &arg{args.v.front()};
|
||||
const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
|
||||
|
||||
if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
|
||||
TODO(converter.getCurrentLocation(),
|
||||
"multiple types in declare reduction is not yet supported");
|
||||
|
||||
mlir::Type reductionType = getReductionType(converter, specifier);
|
||||
ReductionProcessor::GenCombinerCBTy genCombinerCB =
|
||||
processReductionCombiner(converter, symTable, semaCtx, specifier);
|
||||
const parser::OmpClauseList &initializer =
|
||||
declareReductionConstruct.v.Clauses();
|
||||
if (initializer.v.size() > 0) {
|
||||
List<Clause> clauses = makeClauses(initializer, semaCtx);
|
||||
ReductionProcessor::GenInitValueCBTy genInitValueCB;
|
||||
ClauseProcessor cp(converter, semaCtx, clauses);
|
||||
const parser::OmpClause::Initializer &iclause{
|
||||
std::get<parser::OmpClause::Initializer>(initializer.v.front().u)};
|
||||
cp.processInitializer(symTable, iclause, genInitValueCB);
|
||||
const auto &identifier =
|
||||
std::get<parser::OmpReductionIdentifier>(specifier.t);
|
||||
const auto &designator =
|
||||
std::get<parser::ProcedureDesignator>(identifier.u);
|
||||
const auto &reductionName = std::get<parser::Name>(designator.u);
|
||||
bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
|
||||
ReductionProcessor::createDeclareReductionHelper<
|
||||
mlir::omp::DeclareReductionOp>(
|
||||
converter, reductionName.ToString(), reductionType,
|
||||
converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
|
||||
} else {
|
||||
TODO(converter.getCurrentLocation(),
|
||||
"declare reduction without an initializer clause is not yet "
|
||||
"supported");
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
|
||||
@@ -501,7 +501,7 @@ static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
|
||||
template <typename OpType>
|
||||
static void createReductionAllocAndInitRegions(
|
||||
AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl,
|
||||
const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
|
||||
ReductionProcessor::GenInitValueCBTy genInitValueCB, mlir::Type type,
|
||||
bool isByRef) {
|
||||
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
|
||||
auto yield = [&](mlir::Value ret) { genYield<OpType>(builder, loc, ret); };
|
||||
@@ -523,9 +523,8 @@ static void createReductionAllocAndInitRegions(
|
||||
|
||||
mlir::Type ty = fir::unwrapRefType(type);
|
||||
builder.setInsertionPointToEnd(initBlock);
|
||||
mlir::Value initValue = ReductionProcessor::getReductionInitValue(
|
||||
loc, unwrapSeqOrBoxedType(ty), redId, builder);
|
||||
|
||||
mlir::Value initValue =
|
||||
genInitValueCB(builder, loc, ty, initBlock->getArgument(0));
|
||||
if (isByRef) {
|
||||
populateByRefInitAndCleanupRegions(
|
||||
converter, loc, type, initValue, initBlock,
|
||||
@@ -536,7 +535,7 @@ static void createReductionAllocAndInitRegions(
|
||||
/*isDoConcurrent*/ std::is_same_v<OpType, fir::DeclareReductionOp>);
|
||||
}
|
||||
|
||||
if (fir::isa_trivial(ty)) {
|
||||
if (fir::isa_trivial(ty) || fir::isa_derived(ty)) {
|
||||
if (isByRef) {
|
||||
// alloc region
|
||||
builder.setInsertionPointToEnd(allocBlock);
|
||||
@@ -556,18 +555,18 @@ static void createReductionAllocAndInitRegions(
|
||||
yield(boxAlloca);
|
||||
}
|
||||
|
||||
template <typename OpType>
|
||||
OpType ReductionProcessor::createDeclareReduction(
|
||||
template <typename DeclareRedType>
|
||||
DeclareRedType ReductionProcessor::createDeclareReductionHelper(
|
||||
AbstractConverter &converter, llvm::StringRef reductionOpName,
|
||||
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
|
||||
bool isByRef) {
|
||||
mlir::Type type, mlir::Location loc, bool isByRef,
|
||||
GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB) {
|
||||
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
|
||||
mlir::OpBuilder::InsertionGuard guard(builder);
|
||||
mlir::ModuleOp module = builder.getModule();
|
||||
|
||||
assert(!reductionOpName.empty());
|
||||
|
||||
auto decl = module.lookupSymbol<OpType>(reductionOpName);
|
||||
auto decl = module.lookupSymbol<DeclareRedType>(reductionOpName);
|
||||
if (decl)
|
||||
return decl;
|
||||
|
||||
@@ -576,23 +575,54 @@ OpType ReductionProcessor::createDeclareReduction(
|
||||
if (!isByRef)
|
||||
type = valTy;
|
||||
|
||||
decl = OpType::create(modBuilder, loc, reductionOpName, type);
|
||||
createReductionAllocAndInitRegions(converter, loc, decl, redId, type,
|
||||
decl = DeclareRedType::create(modBuilder, loc, reductionOpName, type);
|
||||
createReductionAllocAndInitRegions(converter, loc, decl, genInitValueCB, type,
|
||||
isByRef);
|
||||
|
||||
builder.createBlock(&decl.getReductionRegion(),
|
||||
decl.getReductionRegion().end(), {type, type},
|
||||
{loc, loc});
|
||||
|
||||
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
|
||||
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
|
||||
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
|
||||
genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef);
|
||||
|
||||
genCombinerCB(builder, loc, type, op1, op2, isByRef);
|
||||
return decl;
|
||||
}
|
||||
|
||||
static bool doReductionByRef(mlir::Value reductionVar) {
|
||||
template <typename OpType>
|
||||
OpType ReductionProcessor::createDeclareReduction(
|
||||
AbstractConverter &converter, llvm::StringRef reductionOpName,
|
||||
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
|
||||
bool isByRef) {
|
||||
auto genInitValueCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
mlir::Type type, mlir::Value val) {
|
||||
mlir::Type ty = fir::unwrapRefType(type);
|
||||
mlir::Value initValue = ReductionProcessor::getReductionInitValue(
|
||||
loc, unwrapSeqOrBoxedType(ty), redId, builder);
|
||||
return initValue;
|
||||
};
|
||||
auto genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
mlir::Type type, mlir::Value op1, mlir::Value op2,
|
||||
bool isByRef) {
|
||||
genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef);
|
||||
};
|
||||
|
||||
return createDeclareReductionHelper<OpType>(converter, reductionOpName, type,
|
||||
loc, isByRef, genCombinerCB,
|
||||
genInitValueCB);
|
||||
}
|
||||
|
||||
bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) {
|
||||
if (forceByrefReduction)
|
||||
return true;
|
||||
|
||||
if (!fir::isa_trivial(fir::unwrapRefType(reductionType)) &&
|
||||
!fir::isa_derived(fir::unwrapRefType(reductionType)))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ReductionProcessor::doReductionByRef(mlir::Value reductionVar) {
|
||||
if (forceByrefReduction)
|
||||
return true;
|
||||
|
||||
@@ -600,10 +630,7 @@ static bool doReductionByRef(mlir::Value reductionVar) {
|
||||
mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
|
||||
reductionVar = declare.getMemref();
|
||||
|
||||
if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
return doReductionByRef(reductionVar.getType());
|
||||
}
|
||||
|
||||
template <typename OpType, typename RedOperatorListTy>
|
||||
@@ -614,6 +641,8 @@ bool ReductionProcessor::processReductionArguments(
|
||||
llvm::SmallVectorImpl<bool> &reduceVarByRef,
|
||||
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
|
||||
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
|
||||
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
|
||||
|
||||
if constexpr (std::is_same_v<RedOperatorListTy,
|
||||
omp::clause::ReductionOperatorList>) {
|
||||
// For OpenMP reduction clauses, check if the reduction operator is
|
||||
@@ -627,7 +656,13 @@ bool ReductionProcessor::processReductionArguments(
|
||||
std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
|
||||
if (!ReductionProcessor::supportedIntrinsicProcReduction(
|
||||
*reductionIntrinsic)) {
|
||||
return false;
|
||||
// If not an intrinsic is has to be a custom reduction op, and should
|
||||
// be available in the module.
|
||||
semantics::Symbol *sym = reductionIntrinsic->v.sym();
|
||||
mlir::ModuleOp module = builder.getModule();
|
||||
auto decl = module.lookupSymbol<OpType>(getRealName(sym).ToString());
|
||||
if (!decl)
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
@@ -637,7 +672,6 @@ bool ReductionProcessor::processReductionArguments(
|
||||
|
||||
// Reduction variable processing common to both intrinsic operators and
|
||||
// procedure designators
|
||||
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
|
||||
mlir::OpBuilder::InsertPoint dcIP;
|
||||
constexpr bool isDoConcurrent =
|
||||
std::is_same_v<OpType, fir::DeclareReductionOp>;
|
||||
@@ -741,7 +775,13 @@ bool ReductionProcessor::processReductionArguments(
|
||||
&redOperator.u)) {
|
||||
if (!ReductionProcessor::supportedIntrinsicProcReduction(
|
||||
*reductionIntrinsic)) {
|
||||
TODO(currentLocation, "Unsupported intrinsic proc reduction");
|
||||
// Custom reductions we can just add to the symbols without
|
||||
// generating the declare reduction op.
|
||||
semantics::Symbol *sym = reductionIntrinsic->v.sym();
|
||||
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
|
||||
builder.getContext(), sym->name().ToString()));
|
||||
++idx;
|
||||
continue;
|
||||
}
|
||||
redId = getReductionType(*reductionIntrinsic);
|
||||
reductionName =
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
namespace flangomp {
|
||||
#define GEN_PASS_DEF_MARKDECLARETARGETPASS
|
||||
@@ -31,9 +32,93 @@ namespace {
|
||||
class MarkDeclareTargetPass
|
||||
: public flangomp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
|
||||
|
||||
void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
|
||||
mlir::omp::DeclareTargetCaptureClause parentCapClause,
|
||||
bool parentAutomap, mlir::Operation *currOp,
|
||||
struct ParentInfo {
|
||||
mlir::omp::DeclareTargetDeviceType devTy;
|
||||
mlir::omp::DeclareTargetCaptureClause capClause;
|
||||
bool automap;
|
||||
};
|
||||
|
||||
void processSymbolRef(mlir::SymbolRefAttr symRef, ParentInfo parentInfo,
|
||||
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
|
||||
if (auto currFOp =
|
||||
getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) {
|
||||
auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
|
||||
currFOp.getOperation());
|
||||
|
||||
if (current.isDeclareTarget()) {
|
||||
auto currentDt = current.getDeclareTargetDeviceType();
|
||||
|
||||
// Found the same function twice, with different device_types,
|
||||
// mark as Any as it belongs to both
|
||||
if (currentDt != parentInfo.devTy &&
|
||||
currentDt != mlir::omp::DeclareTargetDeviceType::any) {
|
||||
current.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
|
||||
current.getDeclareTargetCaptureClause(),
|
||||
current.getDeclareTargetAutomap());
|
||||
}
|
||||
} else {
|
||||
current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
|
||||
parentInfo.automap);
|
||||
}
|
||||
|
||||
markNestedFuncs(parentInfo, currFOp, visited);
|
||||
}
|
||||
}
|
||||
|
||||
void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
|
||||
ParentInfo parentInfo,
|
||||
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
|
||||
if (!symRefs)
|
||||
return;
|
||||
|
||||
for (auto symRef : symRefs->getAsRange<mlir::SymbolRefAttr>()) {
|
||||
if (auto declareReductionOp =
|
||||
getOperation().lookupSymbol<mlir::omp::DeclareReductionOp>(
|
||||
symRef)) {
|
||||
markNestedFuncs(parentInfo, declareReductionOp, visited);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
processReductionClauses(mlir::Operation *op, ParentInfo parentInfo,
|
||||
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
|
||||
llvm::TypeSwitch<mlir::Operation &>(*op)
|
||||
.Case([&](mlir::omp::LoopOp op) {
|
||||
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
|
||||
})
|
||||
.Case([&](mlir::omp::ParallelOp op) {
|
||||
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
|
||||
})
|
||||
.Case([&](mlir::omp::SectionsOp op) {
|
||||
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
|
||||
})
|
||||
.Case([&](mlir::omp::SimdOp op) {
|
||||
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
|
||||
})
|
||||
.Case([&](mlir::omp::TargetOp op) {
|
||||
processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
|
||||
})
|
||||
.Case([&](mlir::omp::TaskgroupOp op) {
|
||||
processReductionRefs(op.getTaskReductionSyms(), parentInfo, visited);
|
||||
})
|
||||
.Case([&](mlir::omp::TaskloopOp op) {
|
||||
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
|
||||
processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
|
||||
})
|
||||
.Case([&](mlir::omp::TaskOp op) {
|
||||
processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
|
||||
})
|
||||
.Case([&](mlir::omp::TeamsOp op) {
|
||||
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
|
||||
})
|
||||
.Case([&](mlir::omp::WsloopOp op) {
|
||||
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
|
||||
})
|
||||
.Default([](mlir::Operation &) {});
|
||||
}
|
||||
|
||||
void markNestedFuncs(ParentInfo parentInfo, mlir::Operation *currOp,
|
||||
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
|
||||
if (visited.contains(currOp))
|
||||
return;
|
||||
@@ -43,33 +128,10 @@ class MarkDeclareTargetPass
|
||||
if (auto callOp = llvm::dyn_cast<mlir::CallOpInterface>(op)) {
|
||||
if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
|
||||
callOp.getCallableForCallee())) {
|
||||
if (auto currFOp =
|
||||
getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) {
|
||||
auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
|
||||
currFOp.getOperation());
|
||||
|
||||
if (current.isDeclareTarget()) {
|
||||
auto currentDt = current.getDeclareTargetDeviceType();
|
||||
|
||||
// Found the same function twice, with different device_types,
|
||||
// mark as Any as it belongs to both
|
||||
if (currentDt != parentDevTy &&
|
||||
currentDt != mlir::omp::DeclareTargetDeviceType::any) {
|
||||
current.setDeclareTarget(
|
||||
mlir::omp::DeclareTargetDeviceType::any,
|
||||
current.getDeclareTargetCaptureClause(),
|
||||
current.getDeclareTargetAutomap());
|
||||
}
|
||||
} else {
|
||||
current.setDeclareTarget(parentDevTy, parentCapClause,
|
||||
parentAutomap);
|
||||
}
|
||||
|
||||
markNestedFuncs(parentDevTy, parentCapClause, parentAutomap,
|
||||
currFOp, visited);
|
||||
}
|
||||
processSymbolRef(symRef, parentInfo, visited);
|
||||
}
|
||||
}
|
||||
processReductionClauses(op, parentInfo, visited);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -82,10 +144,10 @@ class MarkDeclareTargetPass
|
||||
functionOp.getOperation());
|
||||
if (declareTargetOp.isDeclareTarget()) {
|
||||
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
|
||||
markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(),
|
||||
declareTargetOp.getDeclareTargetCaptureClause(),
|
||||
declareTargetOp.getDeclareTargetAutomap(), functionOp,
|
||||
visited);
|
||||
ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
|
||||
declareTargetOp.getDeclareTargetCaptureClause(),
|
||||
declareTargetOp.getDeclareTargetAutomap()};
|
||||
markNestedFuncs(parentInfo, functionOp, visited);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,12 +158,13 @@ class MarkDeclareTargetPass
|
||||
// the contents of the device clause
|
||||
getOperation()->walk([&](mlir::omp::TargetOp tarOp) {
|
||||
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
|
||||
markNestedFuncs(
|
||||
/*parentDevTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
|
||||
/*parentCapClause=*/mlir::omp::DeclareTargetCaptureClause::to,
|
||||
/*parentAutomap=*/false, tarOp, visited);
|
||||
ParentInfo parentInfo = {
|
||||
/*devTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
|
||||
/*capClause=*/mlir::omp::DeclareTargetCaptureClause::to,
|
||||
/*automap=*/false,
|
||||
};
|
||||
markNestedFuncs(parentInfo, tarOp, visited);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
! This test checks lowering of OpenMP declare reduction with non-trivial types
|
||||
|
||||
! RUN: not %flang_fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
|
||||
|
||||
module mymod
|
||||
type advancedtype
|
||||
integer(4)::myarray(10)
|
||||
integer(4)::val
|
||||
integer(4)::otherval
|
||||
end type advancedtype
|
||||
!CHECK: not yet implemented: declare reduction currently only supports trival types or derived types containing trivial types
|
||||
!$omp declare reduction(myreduction: advancedtype: omp_out = omp_in) initializer(omp_priv = omp_orig)
|
||||
end module mymod
|
||||
|
||||
program mymaxtest
|
||||
use mymod
|
||||
|
||||
end program
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
! This test checks lowering of OpenMP declare reduction Directive, with initialization
|
||||
! via a subroutine. This functionality is currently not implemented.
|
||||
|
||||
! RUN: not %flang_fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
|
||||
|
||||
!CHECK: not yet implemented: OpenMPDeclareReductionConstruct
|
||||
subroutine initme(x,n)
|
||||
integer x,n
|
||||
x=n
|
||||
end subroutine initme
|
||||
|
||||
function func(x, n, init)
|
||||
integer func
|
||||
integer x(n)
|
||||
integer res
|
||||
interface
|
||||
subroutine initme(x,n)
|
||||
integer x,n
|
||||
end subroutine initme
|
||||
end interface
|
||||
!$omp declare reduction(red_add:integer(4):omp_out=omp_out+omp_in) initializer(initme(omp_priv,0))
|
||||
res=init
|
||||
!$omp simd reduction(red_add:res)
|
||||
do i=1,n
|
||||
res=res+x(i)
|
||||
enddo
|
||||
func=res
|
||||
end function func
|
||||
@@ -1,10 +0,0 @@
|
||||
! This test checks lowering of OpenMP declare reduction Directive.
|
||||
|
||||
! RUN: not %flang_fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
|
||||
|
||||
subroutine declare_red()
|
||||
integer :: my_var
|
||||
!CHECK: not yet implemented: OpenMPDeclareReductionConstruct
|
||||
!$omp declare reduction (my_red : integer : omp_out = omp_in) initializer (omp_priv = 0)
|
||||
my_var = 0
|
||||
end subroutine declare_red
|
||||
@@ -0,0 +1,37 @@
|
||||
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
|
||||
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -fopenmp-is-device %s -o - | FileCheck %s
|
||||
|
||||
program main
|
||||
use, intrinsic :: iso_c_binding
|
||||
implicit none
|
||||
interface
|
||||
subroutine myinit(priv, orig) bind(c,name="myinit")
|
||||
use, intrinsic :: iso_c_binding
|
||||
implicit none
|
||||
integer::priv, orig
|
||||
end subroutine myinit
|
||||
|
||||
function mycombine(lhs, rhs) bind(c,name="mycombine")
|
||||
use, intrinsic :: iso_c_binding
|
||||
implicit none
|
||||
integer::lhs, rhs, mycombine
|
||||
end function mycombine
|
||||
end interface
|
||||
!$omp declare reduction(myreduction:integer:omp_out = mycombine(omp_out, omp_in)) initializer(myinit(omp_priv, omp_orig))
|
||||
|
||||
integer :: i, s, a(10)
|
||||
!$omp target
|
||||
s = 0
|
||||
!$omp do reduction(myreduction:s)
|
||||
do i = 1, 10
|
||||
s = mycombine(s, a(i))
|
||||
enddo
|
||||
!$omp end do
|
||||
!$omp end target
|
||||
end program main
|
||||
|
||||
!CHECK: func.func {{.*}} @myinit(!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}
|
||||
!CHECK-LABEL: func.func {{.*}} @mycombine(!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK-SAME: {{.*}}, omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to), automap = false>{{.*}}
|
||||
|
||||
112
flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
Normal file
112
flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90
Normal file
@@ -0,0 +1,112 @@
|
||||
! This test checks lowering of OpenMP declare reduction Directive, with initialization
|
||||
! via a subroutine. This functionality is currently not implemented.
|
||||
|
||||
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
|
||||
module maxtype_mod
|
||||
implicit none
|
||||
|
||||
type maxtype
|
||||
integer::sumval
|
||||
integer::maxval
|
||||
end type maxtype
|
||||
|
||||
contains
|
||||
|
||||
subroutine initme(x,n)
|
||||
type(maxtype) :: x,n
|
||||
x%sumval=0
|
||||
x%maxval=0
|
||||
end subroutine initme
|
||||
|
||||
function mycombine(lhs, rhs)
|
||||
type(maxtype) :: lhs, rhs
|
||||
type(maxtype) :: mycombine
|
||||
mycombine%sumval = lhs%sumval + rhs%sumval
|
||||
mycombine%maxval = max(lhs%maxval, rhs%maxval)
|
||||
end function mycombine
|
||||
|
||||
function func(x, n, init)
|
||||
type(maxtype) :: func
|
||||
integer :: n, i
|
||||
type(maxtype) :: x(n)
|
||||
type(maxtype) :: init
|
||||
type(maxtype) :: res
|
||||
!$omp declare reduction(red_add_max:maxtype:omp_out=mycombine(omp_out,omp_in)) initializer(initme(omp_priv,omp_orig))
|
||||
res=init
|
||||
!$omp simd reduction(red_add_max:res)
|
||||
do i=1,n
|
||||
res=mycombine(res,x(i))
|
||||
enddo
|
||||
func=res
|
||||
end function func
|
||||
|
||||
end module maxtype_mod
|
||||
!CHECK: omp.declare_reduction @red_add_max : [[MAXTYPE:.*]] init {
|
||||
!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: [[MAXTYPE]]):
|
||||
!CHECK: %[[OMP_PRIV:.*]] = fir.alloca [[MAXTYPE]]
|
||||
!CHECK: %[[OMP_ORIG:.*]] = fir.alloca [[MAXTYPE]]
|
||||
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<[[MAXTYPE]]>
|
||||
!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
|
||||
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<[[MAXTYPE]]>
|
||||
!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
|
||||
!CHECK: fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> ()
|
||||
!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
|
||||
!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : [[MAXTYPE]])
|
||||
!CHECK: } combiner {
|
||||
!CHECK: ^bb0(%[[LHS_ARG:.*]]: [[MAXTYPE]], %[[RHS_ARG:.*]]: [[MAXTYPE]]):
|
||||
!CHECK: %[[RESULT:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = ".result"}
|
||||
!CHECK: %[[OMP_OUT:.*]] = fir.alloca [[MAXTYPE]]
|
||||
!CHECK: %[[OMP_IN:.*]] = fir.alloca [[MAXTYPE]]
|
||||
!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<[[MAXTYPE]]>
|
||||
!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
|
||||
!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<[[MAXTYPE]]>
|
||||
!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
|
||||
!CHECK: %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath<contract> : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]]
|
||||
!CHECK: fir.save_result %[[COMBINE_RESULT]] to %[[RESULT]] : [[MAXTYPE]], !fir.ref<[[MAXTYPE]]>
|
||||
!CHECK: %[[TMPRESULT:.*]]:2 = hlfir.declare %[[RESULT]] {uniq_name = ".tmp.func_result"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
|
||||
!CHECK: %false = arith.constant false
|
||||
!CHECK: %[[EXPRRESULT:.*]] = hlfir.as_expr %[[TMPRESULT]]#0 move %false : (!fir.ref<[[MAXTYPE]]>, i1) -> !hlfir.expr<[[MAXTYPE]]>
|
||||
!CHECK: %[[ASSOCIATE:.*]]:3 = hlfir.associate %[[EXPRRESULT]] {adapt.valuebyref} : (!hlfir.expr<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>, i1)
|
||||
!CHECK: %[[RESULT_VAL:.*]] = fir.load %[[ASSOCIATE]]#0 : !fir.ref<[[MAXTYPE]]>
|
||||
!CHECK: hlfir.end_associate %[[ASSOCIATE]]#1, %[[ASSOCIATE]]#2 : !fir.ref<[[MAXTYPE]]>, i1
|
||||
!CHECK: omp.yield(%[[RESULT_VAL]] : [[MAXTYPE]])
|
||||
!CHECK: }
|
||||
|
||||
!CHECK: func.func @_QMmaxtype_modPinitme(%[[X_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "x"}, %[[N_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "n"}) {
|
||||
!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
|
||||
!CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %[[N_ARG]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QMmaxtype_modFinitmeEn"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
|
||||
!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_ARG]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QMmaxtype_modFinitmeEx"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
|
||||
!CHECK: %[[ZERO_0:.*]] = arith.constant 0 : i32
|
||||
!CHECK: %[[X_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[X_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
|
||||
!CHECK: hlfir.assign %[[ZERO_0]] to %[[X_DESIGNATE_SUMVAL]] : i32, !fir.ref<i32>
|
||||
!CHECK: %[[ZERO_1:.*]] = arith.constant 0 : i32
|
||||
!CHECK: %[[X_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[X_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
|
||||
!CHECK: hlfir.assign %[[ZERO_1]] to %[[X_DESIGNATE_MAXVAL]] : i32, !fir.ref<i32>
|
||||
!CHECK: return
|
||||
!CHECK: }
|
||||
|
||||
|
||||
!CHECK: func.func @_QMmaxtype_modPmycombine(%[[LHS:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "lhs"}, %[[RHS:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "rhs"}) -> [[MAXTYPE]] {
|
||||
!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
|
||||
!CHECK: %[[LHS_DECL:.*]]:2 = hlfir.declare %[[LHS]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QMmaxtype_modFmycombineElhs"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
|
||||
!CHECK: %[[RESULT_ALLOC:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = "mycombine", uniq_name = "_QMmaxtype_modFmycombineEmycombine"}
|
||||
!CHECK: %[[RESULT_DECL:.*]]:2 = hlfir.declare %[[RESULT_ALLOC]] {uniq_name = "_QMmaxtype_modFmycombineEmycombine"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
|
||||
!CHECK: %[[RHS_DECL:.*]]:2 = hlfir.declare %[[RHS]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QMmaxtype_modFmycombineErhs"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>)
|
||||
!CHECK: %[[LHS_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[LHS_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
|
||||
!CHECK: %[[LHS_SUMVAL:.*]] = fir.load %[[LHS_DESIGNATE_SUMVAL]] : !fir.ref<i32>
|
||||
!CHECK: %[[RHS_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[RHS_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
|
||||
!CHECK: %[[RHS_SUMVAL:.*]] = fir.load %[[RHS_DESIGNATE_SUMVAL]] : !fir.ref<i32>
|
||||
!CHECK: %[[SUM:.*]] = arith.addi %[[LHS_SUMVAL]], %[[RHS_SUMVAL]] : i32
|
||||
!CHECK: %[[RESULT_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[RESULT_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
|
||||
!CHECK: hlfir.assign %[[SUM]] to %[[RESULT_DESIGNATE_SUMVAL]] : i32, !fir.ref<i32>
|
||||
!CHECK: %[[LHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[LHS_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
|
||||
!CHECK: %[[LHS_MAXVAL:.*]] = fir.load %[[LHS_DESIGNATE_MAXVAL]] : !fir.ref<i32>
|
||||
!CHECK: %[[RHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[RHS_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
|
||||
!CHECK: %[[RHS_MAXVAL:.*]] = fir.load %[[RHS_DESIGNATE_MAXVAL]] : !fir.ref<i32>
|
||||
!CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS_MAXVAL]], %[[RHS_MAXVAL]] : i32
|
||||
!CHECK: %[[MAX_VAL:.*]] = arith.select %[[CMP]], %[[LHS_MAXVAL]], %[[RHS_MAXVAL]] : i32
|
||||
!CHECK: %[[RESULT_DESIGNAGE_MAXVAL:.*]] = hlfir.designate %[[RESULT_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref<i32>
|
||||
!CHECK: hlfir.assign %[[MAX_VAL]] to %[[RESULT_DESIGNAGE_MAXVAL]] : i32, !fir.ref<i32>
|
||||
!CHECK: %[[RESULT:.*]] = fir.load %[[RESULT_DECL]]#0 : !fir.ref<[[MAXTYPE]]>
|
||||
!CHECK: return %[[RESULT]] : [[MAXTYPE]]
|
||||
!CHECK: }
|
||||
59
flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90
Normal file
59
flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90
Normal file
@@ -0,0 +1,59 @@
|
||||
! This test checks lowering of OpenMP declare reduction Directive, with initialization
|
||||
! via a subroutine. This functionality is currently not implemented.
|
||||
|
||||
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
|
||||
|
||||
subroutine initme(x,n)
|
||||
integer x,n
|
||||
x=0
|
||||
end subroutine initme
|
||||
|
||||
function func(x, n, init)
|
||||
integer func
|
||||
integer x(n)
|
||||
integer res
|
||||
interface
|
||||
subroutine initme(x,n)
|
||||
integer x,n
|
||||
end subroutine initme
|
||||
end interface
|
||||
!CHECK: omp.declare_reduction @red_add : i32 init {
|
||||
!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: i32):
|
||||
!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32
|
||||
!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32
|
||||
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<i32>
|
||||
!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<i32>
|
||||
!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: fir.call @_QPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> ()
|
||||
!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<i32>
|
||||
!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : i32)
|
||||
!CHECK: } combiner {
|
||||
!CHECK: ^bb0(%[[LHS_ARG:.*]]: i32, %[[RHS_ARG:.*]]: i32):
|
||||
!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32
|
||||
!CHECK: %[[OMP_IN:.*]] = fir.alloca i32
|
||||
!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<i32>
|
||||
!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<i32>
|
||||
!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
|
||||
!CHECK: %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref<i32>
|
||||
!CHECK: %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32
|
||||
!CHECK: omp.yield(%[[SUM]] : i32)
|
||||
!CHECK: }
|
||||
!CHECK: func.func @_QPinitme(%[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
|
||||
!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
|
||||
!CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %[[N]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QFinitmeEn"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] dummy_scope %[[OMP_OUT]] arg 1 {uniq_name = "_QFinitmeEx"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32
|
||||
!CHECK: hlfir.assign %[[CONST_0]] to %[[X_DECL]]#0 : i32, !fir.ref<i32>
|
||||
!CHECK: return
|
||||
!CHECK: }
|
||||
!$omp declare reduction(red_add:integer(4):omp_out=omp_out+omp_in) initializer(initme(omp_priv,omp_orig))
|
||||
res=init
|
||||
!$omp simd reduction(red_add:res)
|
||||
do i=1,n
|
||||
res=res+x(i)
|
||||
enddo
|
||||
func=res
|
||||
end function func
|
||||
33
flang/test/Lower/OpenMP/omp-declare-reduction.f90
Normal file
33
flang/test/Lower/OpenMP/omp-declare-reduction.f90
Normal file
@@ -0,0 +1,33 @@
|
||||
! This test checks lowering of OpenMP declare reduction Directive.
|
||||
|
||||
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
|
||||
|
||||
subroutine declare_red()
|
||||
integer :: my_var
|
||||
!CHECK: omp.declare_reduction @my_red : i32 init {
|
||||
!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: i32):
|
||||
!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32
|
||||
!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32
|
||||
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<i32>
|
||||
!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<i32>
|
||||
!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32
|
||||
!CHECK: omp.yield(%[[CONST_0]] : i32)
|
||||
!CHECK: } combiner {
|
||||
!CHECK: ^bb0(%[[LHS_ARG:.*]]: i32, %[[RHS_ARG:.*]]: i32):
|
||||
!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32
|
||||
!CHECK: %[[OMP_IN:.*]] = fir.alloca i32
|
||||
!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<i32>
|
||||
!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<i32>
|
||||
!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
|
||||
!CHECK: %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref<i32>
|
||||
!CHECK: %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32
|
||||
!CHECK: omp.yield(%[[SUM]] : i32)
|
||||
!CHECK: }
|
||||
|
||||
!$omp declare reduction (my_red : integer : omp_out = omp_out + omp_in) initializer (omp_priv = 0)
|
||||
my_var = 0
|
||||
end subroutine declare_red
|
||||
@@ -1170,6 +1170,7 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
|
||||
template <typename T>
|
||||
static void
|
||||
mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
|
||||
llvm::IRBuilderBase &builder,
|
||||
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
|
||||
DenseMap<Value, llvm::Value *> &reductionVariableMap,
|
||||
unsigned i) {
|
||||
@@ -1180,8 +1181,17 @@ mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
|
||||
|
||||
mlir::Value mlirSource = loop.getReductionVars()[i];
|
||||
llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
|
||||
assert(llvmSource && "lookup reduction var");
|
||||
moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
|
||||
llvm::Value *origVal = llvmSource;
|
||||
// If a non-pointer value is expected, load the value from the source pointer.
|
||||
if (!isa<LLVM::LLVMPointerType>(
|
||||
reduction.getInitializerMoldArg().getType()) &&
|
||||
isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
|
||||
origVal =
|
||||
builder.CreateLoad(moduleTranslation.convertType(
|
||||
reduction.getInitializerMoldArg().getType()),
|
||||
llvmSource, "omp_orig");
|
||||
}
|
||||
moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
|
||||
|
||||
if (entry.getNumArguments() > 1) {
|
||||
llvm::Value *allocation =
|
||||
@@ -1254,7 +1264,7 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
|
||||
SmallVector<llvm::Value *, 1> phis;
|
||||
|
||||
// map block argument to initializer region
|
||||
mapInitializationArgs(op, moduleTranslation, reductionDecls,
|
||||
mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
|
||||
reductionVariableMap, i);
|
||||
|
||||
// TODO In some cases (specially on the GPU), the init regions may
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
! Basic offloading test with custom OpenMP reduction on derived type
|
||||
! REQUIRES: flang, amdgpu
|
||||
!
|
||||
! RUN: %libomptarget-compile-fortran-generic
|
||||
! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic
|
||||
module maxtype_mod
|
||||
implicit none
|
||||
|
||||
type maxtype
|
||||
integer::sumval
|
||||
integer::maxval
|
||||
end type maxtype
|
||||
|
||||
contains
|
||||
|
||||
subroutine initme(x,n)
|
||||
type(maxtype) :: x,n
|
||||
x%sumval=0
|
||||
x%maxval=0
|
||||
end subroutine initme
|
||||
|
||||
function mycombine(lhs, rhs)
|
||||
type(maxtype) :: lhs, rhs
|
||||
type(maxtype) :: mycombine
|
||||
mycombine%sumval = lhs%sumval + rhs%sumval
|
||||
mycombine%maxval = max(lhs%maxval, rhs%maxval)
|
||||
end function mycombine
|
||||
|
||||
end module maxtype_mod
|
||||
|
||||
program main
|
||||
use maxtype_mod
|
||||
implicit none
|
||||
|
||||
integer :: n = 100
|
||||
integer :: i
|
||||
integer :: error = 0
|
||||
type(maxtype) :: x(100)
|
||||
type(maxtype) :: res
|
||||
integer :: expected_sum, expected_max
|
||||
|
||||
!$omp declare reduction(red_add_max:maxtype:omp_out=mycombine(omp_out,omp_in)) initializer(initme(omp_priv,omp_orig))
|
||||
|
||||
! Initialize array with test data
|
||||
do i = 1, n
|
||||
x(i)%sumval = i
|
||||
x(i)%maxval = i
|
||||
end do
|
||||
|
||||
! Initialize reduction variable
|
||||
res%sumval = 0
|
||||
res%maxval = 0
|
||||
|
||||
! Perform reduction in target region
|
||||
!$omp target parallel do map(to:x) reduction(red_add_max:res)
|
||||
do i = 1, n
|
||||
res = mycombine(res, x(i))
|
||||
end do
|
||||
!$omp end target parallel do
|
||||
|
||||
! Compute expected values
|
||||
expected_sum = 0
|
||||
expected_max = 0
|
||||
do i = 1, n
|
||||
expected_sum = expected_sum + i
|
||||
expected_max = max(expected_max, i)
|
||||
end do
|
||||
|
||||
! Check results
|
||||
if (res%sumval /= expected_sum) then
|
||||
error = 1
|
||||
endif
|
||||
|
||||
if (res%maxval /= expected_max) then
|
||||
error = 1
|
||||
endif
|
||||
|
||||
if (error == 0) then
|
||||
print *,"PASSED"
|
||||
else
|
||||
print *,"FAILED"
|
||||
endif
|
||||
|
||||
end program main
|
||||
|
||||
! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}}
|
||||
! CHECK: PASSED
|
||||
|
||||
Reference in New Issue
Block a user