mirror of
https://github.com/intel/llvm.git
synced 2026-01-21 04:14:03 +08:00
[mlir][OpenMP] Added assemblyFormat for ParallelOp
This patch adds assemblyFormat for omp.parallel operation. Some existing functions have been altered to fit the custom directive in assemblyFormat. This has led to their callsites to get modified too, but those will be removed in later patches, when other operations get their assemblyFormat. All operations were not changed in one patch for ease of review. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D120157
This commit is contained in:
@@ -97,7 +97,17 @@ def ParallelOp : OpenMP_Op<"parallel", [
|
||||
let builders = [
|
||||
OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
|
||||
];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
oilist( `if` `(` $if_expr_var `:` type($if_expr_var) `)`
|
||||
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
|
||||
| `allocate` `(`
|
||||
custom<AllocateAndAllocator>(
|
||||
$allocate_vars, type($allocate_vars),
|
||||
$allocators_vars, type($allocators_vars)
|
||||
) `)`
|
||||
| `proc_bind` `(` custom<ProcBindKind>($proc_bind_val) `)`
|
||||
) $region attr-dict
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -89,37 +89,55 @@ static ParseResult parseAllocateAndAllocator(
|
||||
SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
|
||||
SmallVectorImpl<Type> &typesAllocator) {
|
||||
|
||||
return parser.parseCommaSeparatedList(
|
||||
OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
|
||||
OpAsmParser::OperandType operand;
|
||||
Type type;
|
||||
if (parser.parseOperand(operand) || parser.parseColonType(type))
|
||||
return failure();
|
||||
operandsAllocator.push_back(operand);
|
||||
typesAllocator.push_back(type);
|
||||
if (parser.parseArrow())
|
||||
return failure();
|
||||
if (parser.parseOperand(operand) || parser.parseColonType(type))
|
||||
return failure();
|
||||
return parser.parseCommaSeparatedList([&]() -> ParseResult {
|
||||
OpAsmParser::OperandType operand;
|
||||
Type type;
|
||||
if (parser.parseOperand(operand) || parser.parseColonType(type))
|
||||
return failure();
|
||||
operandsAllocator.push_back(operand);
|
||||
typesAllocator.push_back(type);
|
||||
if (parser.parseArrow())
|
||||
return failure();
|
||||
if (parser.parseOperand(operand) || parser.parseColonType(type))
|
||||
return failure();
|
||||
|
||||
operandsAllocate.push_back(operand);
|
||||
typesAllocate.push_back(type);
|
||||
return success();
|
||||
});
|
||||
operandsAllocate.push_back(operand);
|
||||
typesAllocate.push_back(type);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
/// Print allocate clause
|
||||
static void printAllocateAndAllocator(OpAsmPrinter &p,
|
||||
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
|
||||
OperandRange varsAllocate,
|
||||
OperandRange varsAllocator) {
|
||||
p << "allocate(";
|
||||
TypeRange typesAllocate,
|
||||
OperandRange varsAllocator,
|
||||
TypeRange typesAllocator) {
|
||||
for (unsigned i = 0; i < varsAllocate.size(); ++i) {
|
||||
std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
|
||||
p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
|
||||
p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
|
||||
std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
|
||||
p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
|
||||
p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
|
||||
}
|
||||
}
|
||||
|
||||
ParseResult parseProcBindKind(OpAsmParser &parser,
|
||||
omp::ClauseProcBindKindAttr &procBindAttr) {
|
||||
StringRef procBindStr;
|
||||
if (parser.parseKeyword(&procBindStr))
|
||||
return failure();
|
||||
if (auto procBindVal = symbolizeClauseProcBindKind(procBindStr)) {
|
||||
procBindAttr =
|
||||
ClauseProcBindKindAttr::get(parser.getContext(), *procBindVal);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
void printProcBindKind(OpAsmPrinter &p, Operation *op,
|
||||
omp::ClauseProcBindKindAttr procBindAttr) {
|
||||
p << stringifyClauseProcBindKind(procBindAttr.getValue());
|
||||
}
|
||||
|
||||
LogicalResult ParallelOp::verify() {
|
||||
if (allocate_vars().size() != allocators_vars().size())
|
||||
return emitError(
|
||||
@@ -127,24 +145,6 @@ LogicalResult ParallelOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
void ParallelOp::print(OpAsmPrinter &p) {
|
||||
p << " ";
|
||||
if (auto ifCond = if_expr_var())
|
||||
p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
|
||||
|
||||
if (auto threads = num_threads_var())
|
||||
p << "num_threads(" << threads << " : " << threads.getType() << ") ";
|
||||
|
||||
if (!allocate_vars().empty())
|
||||
printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
|
||||
|
||||
if (auto bind = proc_bind_val())
|
||||
p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") ";
|
||||
|
||||
p << ' ';
|
||||
p.printRegion(getRegion());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Parser and printer for Linear Clause
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -626,9 +626,10 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
|
||||
return failure();
|
||||
clauseSegments[pos[threadLimitClause]] = 1;
|
||||
} else if (clauseKeyword == "allocate") {
|
||||
if (checkAllowed(allocateClause) ||
|
||||
if (checkAllowed(allocateClause) || parser.parseLParen() ||
|
||||
parseAllocateAndAllocator(parser, allocates, allocateTypes,
|
||||
allocators, allocatorTypes))
|
||||
allocators, allocatorTypes) ||
|
||||
parser.parseRParen())
|
||||
return failure();
|
||||
clauseSegments[pos[allocateClause]] = allocates.size();
|
||||
clauseSegments[pos[allocateClause] + 1] = allocators.size();
|
||||
@@ -803,32 +804,6 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parses a parallel operation.
|
||||
///
|
||||
/// operation ::= `omp.parallel` clause-list
|
||||
/// clause-list ::= clause | clause clause-list
|
||||
/// clause ::= if | num-threads | allocate | proc-bind
|
||||
///
|
||||
ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<ClauseType> clauses = {ifClause, numThreadsClause, allocateClause,
|
||||
procBindClause};
|
||||
|
||||
SmallVector<int> segments;
|
||||
|
||||
if (failed(parseClauses(parser, result, clauses, segments)))
|
||||
return failure();
|
||||
|
||||
result.addAttribute("operand_segment_sizes",
|
||||
parser.getBuilder().getI32VectorAttr(segments));
|
||||
|
||||
Region *body = result.addRegion();
|
||||
SmallVector<OpAsmParser::OperandType> regionArgs;
|
||||
SmallVector<Type> regionArgTypes;
|
||||
if (parser.parseRegion(*body, regionArgs, regionArgTypes))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Parser, printer and verifier for SectionsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -863,8 +838,12 @@ void SectionsOp::print(OpAsmPrinter &p) {
|
||||
if (!reduction_vars().empty())
|
||||
printReductionVarList(p, reductions(), reduction_vars());
|
||||
|
||||
if (!allocate_vars().empty())
|
||||
printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
|
||||
if (!allocate_vars().empty()) {
|
||||
printAllocateAndAllocator(p << "allocate(", *this, allocate_vars(),
|
||||
allocate_vars().getTypes(), allocators_vars(),
|
||||
allocators_vars().getTypes());
|
||||
p << ")";
|
||||
}
|
||||
|
||||
if (nowait())
|
||||
p << "nowait";
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
|
||||
|
||||
func @unknown_clause() {
|
||||
// expected-error@+1 {{invalid is not a valid clause}}
|
||||
// expected-error@+1 {{expected '{' to begin a region}}
|
||||
omp.parallel invalid {
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ func @unknown_clause() {
|
||||
// -----
|
||||
|
||||
func @if_once(%n : i1) {
|
||||
// expected-error@+1 {{at most one if clause can appear on the omp.parallel operation}}
|
||||
// expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}}
|
||||
omp.parallel if(%n : i1) if(%n : i1) {
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ func @if_once(%n : i1) {
|
||||
// -----
|
||||
|
||||
func @num_threads_once(%n : si32) {
|
||||
// expected-error@+1 {{at most one num_threads clause can appear on the omp.parallel operation}}
|
||||
// expected-error@+1 {{`num_threads` clause can appear at most once in the expansion of the oilist directive}}
|
||||
omp.parallel num_threads(%n : si32) num_threads(%n : si32) {
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func @num_threads_once(%n : si32) {
|
||||
// -----
|
||||
|
||||
func @nowait_not_allowed(%n : memref<i32>) {
|
||||
// expected-error@+1 {{nowait is not a valid clause for the omp.parallel operation}}
|
||||
// expected-error@+1 {{expected '{' to begin a region}}
|
||||
omp.parallel nowait {}
|
||||
return
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func @nowait_not_allowed(%n : memref<i32>) {
|
||||
// -----
|
||||
|
||||
func @linear_not_allowed(%data_var : memref<i32>, %linear_var : i32) {
|
||||
// expected-error@+1 {{linear is not a valid clause for the omp.parallel operation}}
|
||||
// expected-error@+1 {{expected '{' to begin a region}}
|
||||
omp.parallel linear(%data_var = %linear_var : memref<i32>) {}
|
||||
return
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func @linear_not_allowed(%data_var : memref<i32>, %linear_var : i32) {
|
||||
// -----
|
||||
|
||||
func @schedule_not_allowed() {
|
||||
// expected-error@+1 {{schedule is not a valid clause for the omp.parallel operation}}
|
||||
// expected-error@+1 {{expected '{' to begin a region}}
|
||||
omp.parallel schedule(static) {}
|
||||
return
|
||||
}
|
||||
@@ -55,7 +55,7 @@ func @schedule_not_allowed() {
|
||||
// -----
|
||||
|
||||
func @collapse_not_allowed() {
|
||||
// expected-error@+1 {{collapse is not a valid clause for the omp.parallel operation}}
|
||||
// expected-error@+1 {{expected '{' to begin a region}}
|
||||
omp.parallel collapse(3) {}
|
||||
return
|
||||
}
|
||||
@@ -63,7 +63,7 @@ func @collapse_not_allowed() {
|
||||
// -----
|
||||
|
||||
func @order_not_allowed() {
|
||||
// expected-error@+1 {{order is not a valid clause for the omp.parallel operation}}
|
||||
// expected-error@+1 {{expected '{' to begin a region}}
|
||||
omp.parallel order(concurrent) {}
|
||||
return
|
||||
}
|
||||
@@ -71,14 +71,14 @@ func @order_not_allowed() {
|
||||
// -----
|
||||
|
||||
func @ordered_not_allowed() {
|
||||
// expected-error@+1 {{ordered is not a valid clause for the omp.parallel operation}}
|
||||
// expected-error@+1 {{expected '{' to begin a region}}
|
||||
omp.parallel ordered(2) {}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @proc_bind_once() {
|
||||
// expected-error@+1 {{at most one proc_bind clause can appear on the omp.parallel operation}}
|
||||
// expected-error@+1 {{`proc_bind` clause can appear at most once in the expansion of the oilist directive}}
|
||||
omp.parallel proc_bind(close) proc_bind(spread) {
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
|
||||
// CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
|
||||
"omp.parallel"(%num_threads, %data_var, %data_var) ({
|
||||
omp.terminator
|
||||
}) {operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
|
||||
}) {num_threads, allocate, operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
|
||||
|
||||
// CHECK: omp.barrier
|
||||
omp.barrier
|
||||
@@ -68,22 +68,22 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
|
||||
// CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
|
||||
"omp.parallel"(%if_cond, %data_var, %data_var) ({
|
||||
omp.terminator
|
||||
}) {operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
|
||||
}) {if, allocate, operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
|
||||
|
||||
// test without allocate
|
||||
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
|
||||
"omp.parallel"(%if_cond, %num_threads) ({
|
||||
omp.terminator
|
||||
}) {operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
|
||||
}) {if, num_threads, operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
|
||||
|
||||
omp.terminator
|
||||
}) {operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
|
||||
}) {if, num_threads, allocate, operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
|
||||
|
||||
// test with multiple parameters for single variadic argument
|
||||
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
|
||||
"omp.parallel" (%data_var, %data_var) ({
|
||||
omp.terminator
|
||||
}) {operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
|
||||
}) {allocate, operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user