[mlir][OpenMP] Added assemblyFormat for ParallelOp

This patch adds an assemblyFormat for the omp.parallel operation.

Some existing functions have been altered to fit the custom directives
used in the assemblyFormat. As a result, their call sites had to be
modified as well; those call sites will be removed in later patches,
when the other operations get their assemblyFormat. Not all operations
were converted in a single patch, for ease of review.
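
For illustration, this is the textual form the new assemblyFormat accepts
(SSA names follow the updated tests below; clauses may be written in any
order):

  omp.parallel if(%if_cond : i1) num_threads(%num_threads : si32)
               allocate(%data_var : memref<i32> -> %data_var : memref<i32>)
               proc_bind(spread) {
    omp.terminator
  }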

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D120157
Author: Shraiysh Vaishay
Date:   2022-02-19 10:00:03 +05:30
Parent: 357b18e282
Commit: 39151717db
4 changed files with 75 additions and 86 deletions

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

@@ -97,7 +97,17 @@ def ParallelOp : OpenMP_Op<"parallel", [
let builders = [
OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
oilist( `if` `(` $if_expr_var `:` type($if_expr_var) `)`
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
| `allocate` `(`
custom<AllocateAndAllocator>(
$allocate_vars, type($allocate_vars),
$allocators_vars, type($allocators_vars)
) `)`
| `proc_bind` `(` custom<ProcBindKind>($proc_bind_val) `)`
) $region attr-dict
}];
let hasVerifier = 1;
}
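
The oilist directive accepts the listed clauses in any order, each at most
once; the generated parser rejects a repeated clause (see the updated
diagnostics in invalid.mlir below). A minimal sketch with hypothetical SSA
names, both forms equivalent:

  omp.parallel num_threads(%n : si32) if(%cond : i1) {
    omp.terminator
  }

  omp.parallel if(%cond : i1) num_threads(%n : si32) {
    omp.terminator
  }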

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

@@ -89,37 +89,55 @@ static ParseResult parseAllocateAndAllocator(
SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
SmallVectorImpl<Type> &typesAllocator) {
return parser.parseCommaSeparatedList(
OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
OpAsmParser::OperandType operand;
Type type;
if (parser.parseOperand(operand) || parser.parseColonType(type))
return failure();
operandsAllocator.push_back(operand);
typesAllocator.push_back(type);
if (parser.parseArrow())
return failure();
if (parser.parseOperand(operand) || parser.parseColonType(type))
return failure();
return parser.parseCommaSeparatedList([&]() -> ParseResult {
OpAsmParser::OperandType operand;
Type type;
if (parser.parseOperand(operand) || parser.parseColonType(type))
return failure();
operandsAllocator.push_back(operand);
typesAllocator.push_back(type);
if (parser.parseArrow())
return failure();
if (parser.parseOperand(operand) || parser.parseColonType(type))
return failure();
operandsAllocate.push_back(operand);
typesAllocate.push_back(type);
return success();
});
operandsAllocate.push_back(operand);
typesAllocate.push_back(type);
return success();
});
}
/// Print allocate clause
static void printAllocateAndAllocator(OpAsmPrinter &p,
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
OperandRange varsAllocate,
OperandRange varsAllocator) {
p << "allocate(";
TypeRange typesAllocate,
OperandRange varsAllocator,
TypeRange typesAllocator) {
for (unsigned i = 0; i < varsAllocate.size(); ++i) {
std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
}
}
ParseResult parseProcBindKind(OpAsmParser &parser,
omp::ClauseProcBindKindAttr &procBindAttr) {
StringRef procBindStr;
if (parser.parseKeyword(&procBindStr))
return failure();
if (auto procBindVal = symbolizeClauseProcBindKind(procBindStr)) {
procBindAttr =
ClauseProcBindKindAttr::get(parser.getContext(), *procBindVal);
return success();
}
return failure();
}
void printProcBindKind(OpAsmPrinter &p, Operation *op,
omp::ClauseProcBindKindAttr procBindAttr) {
p << stringifyClauseProcBindKind(procBindAttr.getValue());
}
LogicalResult ParallelOp::verify() {
if (allocate_vars().size() != allocators_vars().size())
return emitError(
@@ -127,24 +145,6 @@ LogicalResult ParallelOp::verify() {
return success();
}
void ParallelOp::print(OpAsmPrinter &p) {
p << " ";
if (auto ifCond = if_expr_var())
p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
if (auto threads = num_threads_var())
p << "num_threads(" << threads << " : " << threads.getType() << ") ";
if (!allocate_vars().empty())
printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
if (auto bind = proc_bind_val())
p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") ";
p << ' ';
p.printRegion(getRegion());
}
//===----------------------------------------------------------------------===//
// Parser and printer for Linear Clause
//===----------------------------------------------------------------------===//
@@ -626,9 +626,10 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
return failure();
clauseSegments[pos[threadLimitClause]] = 1;
} else if (clauseKeyword == "allocate") {
if (checkAllowed(allocateClause) ||
if (checkAllowed(allocateClause) || parser.parseLParen() ||
parseAllocateAndAllocator(parser, allocates, allocateTypes,
allocators, allocatorTypes))
allocators, allocatorTypes) ||
parser.parseRParen())
return failure();
clauseSegments[pos[allocateClause]] = allocates.size();
clauseSegments[pos[allocateClause] + 1] = allocators.size();
@@ -803,32 +804,6 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
return success();
}
/// Parses a parallel operation.
///
/// operation ::= `omp.parallel` clause-list
/// clause-list ::= clause | clause clause-list
/// clause ::= if | num-threads | allocate | proc-bind
///
ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<ClauseType> clauses = {ifClause, numThreadsClause, allocateClause,
procBindClause};
SmallVector<int> segments;
if (failed(parseClauses(parser, result, clauses, segments)))
return failure();
result.addAttribute("operand_segment_sizes",
parser.getBuilder().getI32VectorAttr(segments));
Region *body = result.addRegion();
SmallVector<OpAsmParser::OperandType> regionArgs;
SmallVector<Type> regionArgTypes;
if (parser.parseRegion(*body, regionArgs, regionArgTypes))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// Parser, printer and verifier for SectionsOp
//===----------------------------------------------------------------------===//
@@ -863,8 +838,12 @@ void SectionsOp::print(OpAsmPrinter &p) {
if (!reduction_vars().empty())
printReductionVarList(p, reductions(), reduction_vars());
if (!allocate_vars().empty())
printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
if (!allocate_vars().empty()) {
printAllocateAndAllocator(p << "allocate(", *this, allocate_vars(),
allocate_vars().getTypes(), allocators_vars(),
allocators_vars().getTypes());
p << ")";
}
if (nowait())
p << "nowait";

mlir/test/Dialect/OpenMP/invalid.mlir

@@ -1,7 +1,7 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
func @unknown_clause() {
// expected-error@+1 {{invalid is not a valid clause}}
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel invalid {
}
@@ -11,7 +11,7 @@ func @unknown_clause() {
// -----
func @if_once(%n : i1) {
// expected-error@+1 {{at most one if clause can appear on the omp.parallel operation}}
// expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}}
omp.parallel if(%n : i1) if(%n : i1) {
}
@@ -21,7 +21,7 @@ func @if_once(%n : i1) {
// -----
func @num_threads_once(%n : si32) {
// expected-error@+1 {{at most one num_threads clause can appear on the omp.parallel operation}}
// expected-error@+1 {{`num_threads` clause can appear at most once in the expansion of the oilist directive}}
omp.parallel num_threads(%n : si32) num_threads(%n : si32) {
}
@@ -31,7 +31,7 @@ func @num_threads_once(%n : si32) {
// -----
func @nowait_not_allowed(%n : memref<i32>) {
// expected-error@+1 {{nowait is not a valid clause for the omp.parallel operation}}
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel nowait {}
return
}
@@ -39,7 +39,7 @@ func @nowait_not_allowed(%n : memref<i32>) {
// -----
func @linear_not_allowed(%data_var : memref<i32>, %linear_var : i32) {
// expected-error@+1 {{linear is not a valid clause for the omp.parallel operation}}
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel linear(%data_var = %linear_var : memref<i32>) {}
return
}
@@ -47,7 +47,7 @@ func @linear_not_allowed(%data_var : memref<i32>, %linear_var : i32) {
// -----
func @schedule_not_allowed() {
// expected-error@+1 {{schedule is not a valid clause for the omp.parallel operation}}
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel schedule(static) {}
return
}
@@ -55,7 +55,7 @@ func @schedule_not_allowed() {
// -----
func @collapse_not_allowed() {
// expected-error@+1 {{collapse is not a valid clause for the omp.parallel operation}}
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel collapse(3) {}
return
}
@@ -63,7 +63,7 @@ func @collapse_not_allowed() {
// -----
func @order_not_allowed() {
// expected-error@+1 {{order is not a valid clause for the omp.parallel operation}}
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel order(concurrent) {}
return
}
@@ -71,14 +71,14 @@ func @order_not_allowed() {
// -----
func @ordered_not_allowed() {
// expected-error@+1 {{ordered is not a valid clause for the omp.parallel operation}}
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel ordered(2) {}
}
// -----
func @proc_bind_once() {
// expected-error@+1 {{at most one proc_bind clause can appear on the omp.parallel operation}}
// expected-error@+1 {{`proc_bind` clause can appear at most once in the expansion of the oilist directive}}
omp.parallel proc_bind(close) proc_bind(spread) {
}

mlir/test/Dialect/OpenMP/ops.mlir

@@ -59,7 +59,7 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
// CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel"(%num_threads, %data_var, %data_var) ({
omp.terminator
}) {operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
}) {num_threads, allocate, operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -68,22 +68,22 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
// CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel"(%if_cond, %data_var, %data_var) ({
omp.terminator
}) {operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
}) {if, allocate, operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
}) {operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
}) {if, num_threads, operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
omp.terminator
}) {operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
}) {if, num_threads, allocate, operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
}) {operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
}) {allocate, operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
return
}