[MLIR][OpenMP]Add prescriptiveness-modifier support to granularity clauses of taskloop construct (#128477)

Added modifier(strict) support to the granularity(grainsize and num_tasks) clauses of taskloop construct.
This commit is contained in:
Kaviya Rajendiran
2025-02-27 21:56:32 +05:30
committed by GitHub
parent 4af2e36b1d
commit d39f4a1980
4 changed files with 150 additions and 17 deletions

View File

@@ -436,12 +436,11 @@ class OpenMP_GrainsizeClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
Optional<IntLikeType>:$grainsize
);
let arguments = (ins OptionalAttr<GrainsizeTypeAttr>:$grainsize_mod,
Optional<IntLikeType>:$grainsize);
let optAssemblyFormat = [{
`grainsize` `(` $grainsize `:` type($grainsize) `)`
`grainsize` `(` custom<GrainsizeClause>($grainsize_mod , $grainsize, type($grainsize)) `)`
}];
let description = [{
@@ -895,12 +894,11 @@ class OpenMP_NumTasksClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
Optional<IntLikeType>:$num_tasks
);
let arguments = (ins OptionalAttr<NumTasksTypeAttr>:$num_tasks_mod,
Optional<IntLikeType>:$num_tasks);
let optAssemblyFormat = [{
`num_tasks` `(` $num_tasks `:` type($num_tasks) `)`
`num_tasks` `(` custom<NumTasksClause>($num_tasks_mod , $num_tasks, type($num_tasks)) `)`
}];
let description = [{

View File

@@ -472,6 +472,99 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
p << stringifyClauseOrderKind(order.getValue());
}
template <typename ClauseTypeAttr, typename ClauseType>
static ParseResult
parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
std::optional<OpAsmParser::UnresolvedOperand> &operand,
Type &operandType,
std::optional<ClauseType> (*symbolizeClause)(StringRef),
StringRef clauseName) {
StringRef enumStr;
if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
if (parser.parseComma())
return failure();
} else {
return parser.emitError(parser.getCurrentLocation())
<< "invalid " << clauseName << " modifier : '" << enumStr << "'";
;
}
}
OpAsmParser::UnresolvedOperand var;
if (succeeded(parser.parseOperand(var))) {
operand = var;
} else {
return parser.emitError(parser.getCurrentLocation())
<< "expected " << clauseName << " operand";
}
if (operand.has_value()) {
if (parser.parseColonType(operandType))
return failure();
}
return success();
}
template <typename ClauseTypeAttr, typename ClauseType>
static void
printGranularityClause(OpAsmPrinter &p, Operation *op,
ClauseTypeAttr prescriptiveness, Value operand,
mlir::Type operandType,
StringRef (*stringifyClauseType)(ClauseType)) {
if (prescriptiveness)
p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
if (operand)
p << operand << ": " << operandType;
}
//===----------------------------------------------------------------------===//
// Parser and printer for grainsize Clause
//===----------------------------------------------------------------------===//
// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
static ParseResult
parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
Type &grainsizeType) {
return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
parser, grainsizeMod, grainsize, grainsizeType,
&symbolizeClauseGrainsizeType, "grainsize");
}
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op,
ClauseGrainsizeTypeAttr grainsizeMod,
Value grainsize, mlir::Type grainsizeType) {
printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
p, op, grainsizeMod, grainsize, grainsizeType,
&stringifyClauseGrainsizeType);
}
//===----------------------------------------------------------------------===//
// Parser and printer for num_tasks Clause
//===----------------------------------------------------------------------===//
// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
static ParseResult
parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
Type &numTasksType) {
return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
"num_tasks");
}
static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
ClauseNumTasksTypeAttr numTasksMod,
Value numTasks, mlir::Type numTasksType) {
printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
}
//===----------------------------------------------------------------------===//
// Parsers for operations including clauses that define entry block arguments.
//===----------------------------------------------------------------------===//
@@ -2593,15 +2686,17 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
const TaskloopOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms.
TaskloopOp::build(
builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
/*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.final, clauses.grainsizeMod, clauses.grainsize,
clauses.ifExpr, clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
makeArrayAttr(ctx, clauses.inReductionSyms),
clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
clauses.numTasks, clauses.priority, /*private_vars=*/{},
/*private_syms=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
}
SmallVector<Value> TaskloopOp::getAllReductionVars() {

View File

@@ -2064,6 +2064,30 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
// -----
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testi64 = "test.i64"() : () -> (i64)
// expected-error @below {{invalid grainsize modifier : 'strict1'}}
omp.taskloop grainsize(strict1, %testi64: i64) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
omp.yield
}
}
return
}
// -----
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testi64 = "test.i64"() : () -> (i64)
// expected-error @below {{invalid num_tasks modifier : 'default'}}
omp.taskloop num_tasks(default, %testi64: i64) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
omp.yield
}
}
return
}
// -----
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
// expected-error @below {{op nested in loop wrapper is not another loop wrapper or `omp.loop_nest`}}
omp.taskloop {

View File

@@ -2417,6 +2417,22 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
}
}
// CHECK: omp.taskloop grainsize(strict, %{{[^:]+}}: i64) {
omp.taskloop grainsize(strict, %testi64: i64) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
// CHECK: omp.yield
omp.yield
}
}
// CHECK: omp.taskloop num_tasks(strict, %{{[^:]+}}: i64) {
omp.taskloop num_tasks(strict, %testi64: i64) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
// CHECK: omp.yield
omp.yield
}
}
// CHECK: omp.taskloop nogroup {
omp.taskloop nogroup {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {