[mlir][openacc] Introduce OpenACC dialect with parallel, data, loop operations

This patch introduces the OpenACC dialect with three operation defined
parallel, data and loop operations with custom parsing and printing.

OpenACC dialect RFC can be find here: https://llvm.discourse.group/t/rfc-openacc-dialect/546/2

Reviewed By: rriddle, kiranchandramohan

Differential Revision: https://reviews.llvm.org/D84268
This commit is contained in:
Valentin Clement
2020-08-12 11:48:24 -04:00
committed by clementval
parent 2916dd5669
commit 4225e7fa34
10 changed files with 1110 additions and 0 deletions

View File

@@ -3,6 +3,7 @@ add_subdirectory(AVX512)
add_subdirectory(GPU)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(Quant)
add_subdirectory(SCF)

View File

@@ -0,0 +1,9 @@
set(LLVM_TARGET_DEFINITIONS OpenACCOps.td)
mlir_tablegen(OpenACCOpsDialect.h.inc -gen-dialect-decls -dialect=acc)
mlir_tablegen(OpenACCOps.h.inc -gen-op-decls)
mlir_tablegen(OpenACCOps.cpp.inc -gen-op-defs)
mlir_tablegen(OpenACCOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpenACCOpsEnums.cpp.inc -gen-enum-defs)
add_mlir_doc(OpenACCOps -gen-dialect-doc OpenACCDialect Dialects/)
add_public_tablegen_target(MLIROpenACCOpsIncGen)

View File

@@ -0,0 +1,44 @@
//===- OpenACC.h - MLIR OpenACC Dialect -------------------------*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// ============================================================================
//
// This file declares the OpenACC dialect in MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_OPENACC_OPENACC_H_
#define MLIR_DIALECT_OPENACC_OPENACC_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.h.inc"
namespace mlir {
namespace acc {
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenACC/OpenACCOps.h.inc"
#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.h.inc"
/// Enumeration used to encode the execution mapping on a loop construct.
/// They refer directly to the OpenACC 3.0 standard:
/// 2.9.2. gang
/// 2.9.3. worker
/// 2.9.4. vector
/// 2.9.5. seq
///
/// Value can be combined bitwise to reflect the mapping applied to the
/// construct. e.g. `acc.loop gang vector`, the `gang` and `vector` could be
/// combined and the final mapping value would be 5 (4 & 1).
enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4, SEQ = 8 };
} // end namespace acc
} // end namespace mlir
#endif // MLIR_DIALECT_OPENACC_OPENACC_H_

View File

@@ -0,0 +1,273 @@
//===- OpenACC.td - OpenACC operation definitions ----------*- tablegen -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// =============================================================================
//
// Defines MLIR OpenACC operations.
//
//===----------------------------------------------------------------------===//
#ifndef OPENACC_OPS
#define OPENACC_OPS
include "mlir/IR/OpBase.td"
def OpenACC_Dialect : Dialect {
let name = "acc";
let summary = "An OpenACC dialect for MLIR.";
let description = [{
This dialect models the construct from the OpenACC 3.0 directive language.
}];
let cppNamespace = "acc";
}
// Base class for OpenACC dialect ops.
class OpenACC_Op<string mnemonic, list<OpTrait> traits = []> :
Op<OpenACC_Dialect, mnemonic, traits> {
let printer = [{ return ::print(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
// Reduction operation enumeration
def OpenACC_ReductionOpAdd : StrEnumAttrCase<"redop_add">;
def OpenACC_ReductionOpMul : StrEnumAttrCase<"redop_mul">;
def OpenACC_ReductionOpMax : StrEnumAttrCase<"redop_max">;
def OpenACC_ReductionOpMin : StrEnumAttrCase<"redop_min">;
def OpenACC_ReductionOpAnd : StrEnumAttrCase<"redop_and">;
def OpenACC_ReductionOpOr : StrEnumAttrCase<"redop_or">;
def OpenACC_ReductionOpXor : StrEnumAttrCase<"redop_xor">;
def OpenACC_ReductionOpLogEqv : StrEnumAttrCase<"redop_leqv">;
def OpenACC_ReductionOpLogNeqv : StrEnumAttrCase<"redop_lneqv">;
def OpenACC_ReductionOpLogAnd : StrEnumAttrCase<"redop_land">;
def OpenACC_ReductionOpLogOr : StrEnumAttrCase<"redop_lor">;
def OpenACC_ReductionOpAttr : StrEnumAttr<"ReductionOpAttr",
"built-in reduction operations supported by OpenACC",
[OpenACC_ReductionOpAdd, OpenACC_ReductionOpMul, OpenACC_ReductionOpMax,
OpenACC_ReductionOpMin, OpenACC_ReductionOpAnd, OpenACC_ReductionOpOr,
OpenACC_ReductionOpXor, OpenACC_ReductionOpLogEqv,
OpenACC_ReductionOpLogNeqv, OpenACC_ReductionOpLogAnd,
OpenACC_ReductionOpLogOr
]> {
let cppNamespace = "::mlir::acc";
}
//===----------------------------------------------------------------------===//
// 2.5.1 parallel Construct
//===----------------------------------------------------------------------===//
def OpenACC_ParallelOp : OpenACC_Op<"parallel",
[AttrSizedOperandSegments]> {
let summary = "parallel construct";
let description = [{
The "acc.parallel" operation represents a parallel construct block. It has
one region to be executued in parallel on the current device.
Example:
```mlir
acc.parallel num_gangs(%c10) num_workers(%c10)
private(%c : memref<10xf32>) {
// parallel region
}
```
}];
let arguments = (ins Optional<Index>:$async,
Variadic<Index>:$waitOperands,
Optional<Index>:$numGangs,
Optional<Index>:$numWorkers,
Optional<Index>:$vectorLength,
Optional<I1>:$ifCond,
Optional<I1>:$selfCond,
OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
Variadic<AnyType>:$reductionOperands,
Variadic<AnyType>:$copyOperands,
Variadic<AnyType>:$copyinOperands,
Variadic<AnyType>:$copyoutOperands,
Variadic<AnyType>:$createOperands,
Variadic<AnyType>:$noCreateOperands,
Variadic<AnyType>:$presentOperands,
Variadic<AnyType>:$devicePtrOperands,
Variadic<AnyType>:$attachOperands,
Variadic<AnyType>:$gangPrivateOperands,
Variadic<AnyType>:$gangFirstPrivateOperands);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = [{
static StringRef getAsyncKeyword() { return "async"; }
static StringRef getWaitKeyword() { return "wait"; }
static StringRef getNumGangsKeyword() { return "num_gangs"; }
static StringRef getNumWorkersKeyword() { return "num_workers"; }
static StringRef getVectorLengthKeyword() { return "vector_length"; }
static StringRef getIfKeyword() { return "if"; }
static StringRef getSelfKeyword() { return "self"; }
static StringRef getReductionKeyword() { return "reduction"; }
static StringRef getCopyKeyword() { return "copy"; }
static StringRef getCopyinKeyword() { return "copyin"; }
static StringRef getCopyoutKeyword() { return "copyout"; }
static StringRef getCreateKeyword() { return "create"; }
static StringRef getNoCreateKeyword() { return "no_create"; }
static StringRef getPresentKeyword() { return "present"; }
static StringRef getDevicePtrKeyword() { return "deviceptr"; }
static StringRef getAttachKeyword() { return "attach"; }
static StringRef getPrivateKeyword() { return "private"; }
static StringRef getFirstPrivateKeyword() { return "firstprivate"; }
}];
let verifier = ?;
}
//===----------------------------------------------------------------------===//
// 2.6.5 data Construct
//===----------------------------------------------------------------------===//
def OpenACC_DataOp : OpenACC_Op<"data",
[AttrSizedOperandSegments]> {
let summary = "data construct";
let description = [{
The "acc.data" operation represents a data construct. It defines vars to
be allocated in the current device memory for the duration of the region,
whether data should be copied from local memory to the current device
memory upon region entry , and copied from device memory to local memory
upon region exit.
Example:
```mlir
acc.data present(%a: memref<10x10xf32>, %b: memref<10x10xf32>,
%c: memref<10xf32>, %d: memref<10xf32>) {
// data region
}
```
}];
let arguments = (ins Variadic<AnyType>:$presentOperands,
Variadic<AnyType>:$copyOperands,
Variadic<AnyType>:$copyinOperands,
Variadic<AnyType>:$copyoutOperands,
Variadic<AnyType>:$createOperands,
Variadic<AnyType>:$noCreateOperands,
Variadic<AnyType>:$deleteOperands,
Variadic<AnyType>:$attachOperands,
Variadic<AnyType>:$detachOperands);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = [{
static StringRef getAttachKeyword() { return "attach"; }
static StringRef getDeleteKeyword() { return "delete"; }
static StringRef getDetachKeyword() { return "detach"; }
static StringRef getCopyinKeyword() { return "copyin"; }
static StringRef getCopyKeyword() { return "copy"; }
static StringRef getCopyoutKeyword() { return "copyout"; }
static StringRef getCreateKeyword() { return "create"; }
static StringRef getNoCreateKeyword() { return "no_create"; }
static StringRef getPresentKeyword() { return "present"; }
}];
let verifier = ?;
}
def OpenACC_TerminatorOp : OpenACC_Op<"terminator", [Terminator]> {
let summary = "Generic terminator for OpenACC regions";
let description = [{
A terminator operation for regions that appear in the body of OpenACC
operation. Generic OpenACC construct regions are not expected to return any
value so the terminator takes no operands. The terminator op returns control
to the enclosing op.
}];
let verifier = ?;
let assemblyFormat = "attr-dict";
}
//===----------------------------------------------------------------------===//
// 2.9 loop Construct
//===----------------------------------------------------------------------===//
def OpenACC_LoopOp : OpenACC_Op<"loop",
[AttrSizedOperandSegments]> {
let summary = "loop construct";
let description = [{
The "acc.loop" operation represents the OpenACC loop construct.
Example:
```mlir
acc.loop gang vector {
scf.for %arg3 = %c0 to %c10 step %c1 {
scf.for %arg4 = %c0 to %c10 step %c1 {
scf.for %arg5 = %c0 to %c10 step %c1 {
// ... body
}
}
}
acc.yield
} attributes { collapse = 3 }
```
}];
let arguments = (ins OptionalAttr<I64Attr>:$collapse,
Variadic<AnyType>:$privateOperands,
OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
Variadic<AnyType>:$reductionOperands);
let results = (outs Variadic<AnyType>:$results);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = [{
static StringRef getCollapseAttrName() { return "collapse"; }
static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
static StringRef getGangAttrName() { return "gang"; }
static StringRef getSeqAttrName() { return "seq"; }
static StringRef getVectorAttrName() { return "vector"; }
static StringRef getWorkerAttrName() { return "worker"; }
static StringRef getPrivateKeyword() { return "private"; }
static StringRef getReductionKeyword() { return "reduction"; }
}];
let verifier = ?;
}
// Yield operation for the acc.loop and acc.parallel operations.
def OpenACC_YieldOp : OpenACC_Op<"yield", [Terminator,
ParentOneOf<["ParallelOp, LoopOp"]>]> {
let summary = "Acc yield and termination operation";
let description = [{
`acc.yield` is a special terminator operation for block inside regions in
acc ops (parallel and loop). It returns values to the immediately enclosing
acc op.
}];
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result",
[{ /* nothing to do */ }]>
];
let verifier = ?;
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
#endif // OPENACC_OPS

View File

@@ -20,6 +20,7 @@ DEFINE_SYM_KIND_RANGE(QUANTIZATION)
DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine
DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect
DEFINE_SYM_KIND_RANGE(OPENACC) // OpenACC IR Dialect
DEFINE_SYM_KIND_RANGE(OPENMP) // OpenMP IR Dialect
DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect

View File

@@ -22,6 +22,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/SCF.h"
@@ -38,6 +39,7 @@ namespace mlir {
// all the possible dialects to be made available to the context automatically.
inline void registerAllDialects() {
static bool init_once = []() {
registerDialect<acc::OpenACCDialect>();
registerDialect<AffineDialect>();
registerDialect<avx512::AVX512Dialect>();
registerDialect<gpu::GPUDialect>();

View File

@@ -3,6 +3,7 @@ add_subdirectory(AVX512)
add_subdirectory(GPU)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(Quant)
add_subdirectory(SCF)

View File

@@ -0,0 +1,13 @@
add_mlir_dialect_library(MLIROpenACC
IR/OpenACC.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
DEPENDS
MLIROpenACCOpsIncGen
LINK_LIBS PUBLIC
MLIRIR
)

View File

@@ -0,0 +1,579 @@
//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// =============================================================================
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
using namespace acc;
//===----------------------------------------------------------------------===//
// OpenACC operations
//===----------------------------------------------------------------------===//
void OpenACCDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
>();
}
template <typename StructureOp>
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
unsigned nRegions = 1) {
SmallVector<Region *, 2> regions;
for (unsigned i = 0; i < nRegions; ++i)
regions.push_back(state.addRegion());
for (Region *region : regions) {
if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
}
return success();
}
static ParseResult
parseOperandList(OpAsmParser &parser, StringRef keyword,
SmallVectorImpl<OpAsmParser::OperandType> &args,
SmallVectorImpl<Type> &argTypes, OperationState &result) {
if (failed(parser.parseOptionalKeyword(keyword)))
return success();
if (failed(parser.parseLParen()))
return failure();
// Exit early if the list is empty.
if (succeeded(parser.parseOptionalRParen()))
return success();
do {
OpAsmParser::OperandType arg;
Type type;
if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
return failure();
args.push_back(arg);
argTypes.push_back(type);
} while (succeeded(parser.parseOptionalComma()));
if (failed(parser.parseRParen()))
return failure();
return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(),
result.operands);
}
static void printOperandList(Operation::operand_range operands,
StringRef listName, OpAsmPrinter &printer) {
if (operands.size() > 0) {
printer << " " << listName << "(";
llvm::interleaveComma(operands, printer, [&](Value op) {
printer << op << ": " << op.getType();
});
printer << ")";
}
}
static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
OpAsmParser::OperandType &operand,
Type type, bool &hasOptional,
OperationState &result) {
hasOptional = false;
if (succeeded(parser.parseOptionalKeyword(keyword))) {
hasOptional = true;
if (parser.parseLParen() || parser.parseOperand(operand) ||
parser.resolveOperand(operand, type, result.operands) ||
parser.parseRParen())
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//
/// Parse acc.parallel operation
/// operation := `acc.parallel` `async` `(` index `)`?
/// `wait` `(` index-list `)`?
/// `num_gangs` `(` value `)`?
/// `num_workers` `(` value `)`?
/// `vector_length` `(` value `)`?
/// `if` `(` value `)`?
/// `self` `(` value `)`?
/// `reduction` `(` value-list `)`?
/// `copy` `(` value-list `)`?
/// `copyin` `(` value-list `)`?
/// `copyout` `(` value-list `)`?
/// `create` `(` value-list `)`?
/// `no_create` `(` value-list `)`?
/// `present` `(` value-list `)`?
/// `deviceptr` `(` value-list `)`?
/// `attach` `(` value-list `)`?
/// `private` `(` value-list `)`?
/// `firstprivate` `(` value-list `)`?
/// region attr-dict?
static ParseResult parseParallelOp(OpAsmParser &parser,
OperationState &result) {
Builder &builder = parser.getBuilder();
SmallVector<OpAsmParser::OperandType, 8> privateOperands,
firstprivateOperands, createOperands, copyOperands, copyinOperands,
copyoutOperands, noCreateOperands, presentOperands, devicePtrOperands,
attachOperands, waitOperands, reductionOperands;
SmallVector<Type, 8> operandTypes;
OpAsmParser::OperandType async, numGangs, numWorkers, vectorLength, ifCond,
selfCond;
bool hasAsync = false, hasNumGangs = false, hasNumWorkers = false;
bool hasVectorLength = false, hasIfCond = false, hasSelfCond = false;
Type indexType = builder.getIndexType();
Type i1Type = builder.getI1Type();
// async()?
if (failed(parseOptionalOperand(parser, ParallelOp::getAsyncKeyword(), async,
indexType, hasAsync, result)))
return failure();
// wait()?
if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(),
waitOperands, operandTypes, result)))
return failure();
// num_gangs(value)?
if (failed(parseOptionalOperand(parser, ParallelOp::getNumGangsKeyword(),
numGangs, indexType, hasNumGangs, result)))
return failure();
// num_workers(value)?
if (failed(parseOptionalOperand(parser, ParallelOp::getNumWorkersKeyword(),
numWorkers, indexType, hasNumWorkers,
result)))
return failure();
// vector_length(value)?
if (failed(parseOptionalOperand(parser, ParallelOp::getVectorLengthKeyword(),
vectorLength, indexType, hasVectorLength,
result)))
return failure();
// if()?
if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond,
i1Type, hasIfCond, result)))
return failure();
// self()?
if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(),
selfCond, i1Type, hasSelfCond, result)))
return failure();
// reduction()?
if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(),
reductionOperands, operandTypes, result)))
return failure();
// copy()?
if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(),
copyOperands, operandTypes, result)))
return failure();
// copyin()?
if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(),
copyinOperands, operandTypes, result)))
return failure();
// copyout()?
if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(),
copyoutOperands, operandTypes, result)))
return failure();
// create()?
if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(),
createOperands, operandTypes, result)))
return failure();
// no_create()?
if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(),
noCreateOperands, operandTypes, result)))
return failure();
// present()?
if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(),
presentOperands, operandTypes, result)))
return failure();
// deviceptr()?
if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(),
devicePtrOperands, operandTypes, result)))
return failure();
// attach()?
if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(),
attachOperands, operandTypes, result)))
return failure();
// private()?
if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(),
privateOperands, operandTypes, result)))
return failure();
// firstprivate()?
if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(),
firstprivateOperands, operandTypes, result)))
return failure();
// Parallel op region
if (failed(parseRegions<ParallelOp>(parser, result)))
return failure();
result.addAttribute(ParallelOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr(
{static_cast<int32_t>(hasAsync ? 1 : 0),
static_cast<int32_t>(waitOperands.size()),
static_cast<int32_t>(hasNumGangs ? 1 : 0),
static_cast<int32_t>(hasNumWorkers ? 1 : 0),
static_cast<int32_t>(hasVectorLength ? 1 : 0),
static_cast<int32_t>(hasIfCond ? 1 : 0),
static_cast<int32_t>(hasSelfCond ? 1 : 0),
static_cast<int32_t>(reductionOperands.size()),
static_cast<int32_t>(copyOperands.size()),
static_cast<int32_t>(copyinOperands.size()),
static_cast<int32_t>(copyoutOperands.size()),
static_cast<int32_t>(createOperands.size()),
static_cast<int32_t>(noCreateOperands.size()),
static_cast<int32_t>(presentOperands.size()),
static_cast<int32_t>(devicePtrOperands.size()),
static_cast<int32_t>(attachOperands.size()),
static_cast<int32_t>(privateOperands.size()),
static_cast<int32_t>(firstprivateOperands.size())}));
// Additional attributes
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
return failure();
return success();
}
static void print(OpAsmPrinter &printer, ParallelOp &op) {
printer << ParallelOp::getOperationName();
// async()?
if (auto async = op.async())
printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ")";
// wait()?
printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer);
// num_gangs()?
if (auto numGangs = op.numGangs())
printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
<< ")";
// num_workers()?
if (auto numWorkers = op.numWorkers())
printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
<< ")";
// if()?
if (Value ifCond = op.ifCond())
printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")";
// self()?
if (Value selfCond = op.selfCond())
printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")";
// reduction()?
printOperandList(op.reductionOperands(), ParallelOp::getReductionKeyword(),
printer);
// copy()?
printOperandList(op.copyOperands(), ParallelOp::getCopyKeyword(), printer);
// copyin()?
printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(),
printer);
// copyout()?
printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(),
printer);
// create()?
printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(),
printer);
// no_create()?
printOperandList(op.noCreateOperands(), ParallelOp::getNoCreateKeyword(),
printer);
// present()?
printOperandList(op.presentOperands(), ParallelOp::getPresentKeyword(),
printer);
// deviceptr()?
printOperandList(op.devicePtrOperands(), ParallelOp::getDevicePtrKeyword(),
printer);
// attach()?
printOperandList(op.attachOperands(), ParallelOp::getAttachKeyword(),
printer);
// private()?
printOperandList(op.gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
printer);
// firstprivate()?
printOperandList(op.gangFirstPrivateOperands(),
ParallelOp::getFirstPrivateKeyword(), printer);
printer.printRegion(op.region(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
printer.printOptionalAttrDictWithKeyword(
op.getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
}
//===----------------------------------------------------------------------===//
// DataOp
//===----------------------------------------------------------------------===//
/// Parse acc.data operation
/// operation := `acc.parallel` `present` `(` value-list `)`?
/// `copy` `(` value-list `)`?
/// `copyin` `(` value-list `)`?
/// `copyout` `(` value-list `)`?
/// `create` `(` value-list `)`?
/// `no_create` `(` value-list `)`?
/// `delete` `(` value-list `)`?
/// `attach` `(` value-list `)`?
/// `detach` `(` value-list `)`?
/// region attr-dict?
static ParseResult parseDataOp(OpAsmParser &parser, OperationState &result) {
Builder &builder = parser.getBuilder();
SmallVector<OpAsmParser::OperandType, 8> presentOperands, copyOperands,
copyinOperands, copyoutOperands, createOperands, noCreateOperands,
deleteOperands, attachOperands, detachOperands;
SmallVector<Type, 8> operandsTypes;
// present(value-list)?
if (failed(parseOperandList(parser, DataOp::getPresentKeyword(),
presentOperands, operandsTypes, result)))
return failure();
// copy(value-list)?
if (failed(parseOperandList(parser, DataOp::getCopyKeyword(), copyOperands,
operandsTypes, result)))
return failure();
// copyin(value-list)?
if (failed(parseOperandList(parser, DataOp::getCopyinKeyword(),
copyinOperands, operandsTypes, result)))
return failure();
// copyout(value-list)?
if (failed(parseOperandList(parser, DataOp::getCopyoutKeyword(),
copyoutOperands, operandsTypes, result)))
return failure();
// create(value-list)?
if (failed(parseOperandList(parser, DataOp::getCreateKeyword(),
createOperands, operandsTypes, result)))
return failure();
// no_create(value-list)?
if (failed(parseOperandList(parser, DataOp::getCreateKeyword(),
noCreateOperands, operandsTypes, result)))
return failure();
// delete(value-list)?
if (failed(parseOperandList(parser, DataOp::getDeleteKeyword(),
deleteOperands, operandsTypes, result)))
return failure();
// attach(value-list)?
if (failed(parseOperandList(parser, DataOp::getAttachKeyword(),
attachOperands, operandsTypes, result)))
return failure();
// detach(value-list)?
if (failed(parseOperandList(parser, DataOp::getDetachKeyword(),
detachOperands, operandsTypes, result)))
return failure();
// Data op region
if (failed(parseRegions<ParallelOp>(parser, result)))
return failure();
result.addAttribute(
ParallelOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr({static_cast<int32_t>(presentOperands.size()),
static_cast<int32_t>(copyOperands.size()),
static_cast<int32_t>(copyinOperands.size()),
static_cast<int32_t>(copyoutOperands.size()),
static_cast<int32_t>(createOperands.size()),
static_cast<int32_t>(noCreateOperands.size()),
static_cast<int32_t>(deleteOperands.size()),
static_cast<int32_t>(attachOperands.size()),
static_cast<int32_t>(detachOperands.size())}));
// Additional attributes
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
return failure();
return success();
}
static void print(OpAsmPrinter &printer, DataOp &op) {
printer << DataOp::getOperationName();
// present(value-list)?
printOperandList(op.presentOperands(), DataOp::getPresentKeyword(), printer);
// copy(value-list)?
printOperandList(op.copyOperands(), DataOp::getCopyKeyword(), printer);
// copyin(value-list)?
printOperandList(op.copyinOperands(), DataOp::getCopyinKeyword(), printer);
// copyout(value-list)?
printOperandList(op.copyoutOperands(), DataOp::getCopyoutKeyword(), printer);
// create(value-list)?
printOperandList(op.createOperands(), DataOp::getCreateKeyword(), printer);
// no_create(value-list)?
printOperandList(op.noCreateOperands(), DataOp::getNoCreateKeyword(),
printer);
// delete(value-list)?
printOperandList(op.deleteOperands(), DataOp::getDeleteKeyword(), printer);
// attach(value-list)?
printOperandList(op.attachOperands(), DataOp::getAttachKeyword(), printer);
// detach(value-list)?
printOperandList(op.detachOperands(), DataOp::getDetachKeyword(), printer);
printer.printRegion(op.region(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
printer.printOptionalAttrDictWithKeyword(
op.getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
}
//===----------------------------------------------------------------------===//
// LoopOp
//===----------------------------------------------------------------------===//
/// Parse acc.loop operation
/// operation := `acc.loop` `gang`? `vector`? `worker`? `seq`?
/// `private` `(` value-list `)`?
/// `reduction` `(` value-list `)`?
/// region attr-dict?
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
Builder &builder = parser.getBuilder();
unsigned executionMapping = 0;
SmallVector<Type, 8> operandTypes;
SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands;
// gang?
if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangAttrName())))
executionMapping |= OpenACCExecMapping::GANG;
// vector?
if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorAttrName())))
executionMapping |= OpenACCExecMapping::VECTOR;
// worker?
if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerAttrName())))
executionMapping |= OpenACCExecMapping::WORKER;
// seq?
if (succeeded(parser.parseOptionalKeyword(LoopOp::getSeqAttrName())))
executionMapping |= OpenACCExecMapping::SEQ;
// private()?
if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
privateOperands, operandTypes, result)))
return failure();
// reduction()?
if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(),
reductionOperands, operandTypes, result)))
return failure();
if (executionMapping != 0)
result.addAttribute(LoopOp::getExecutionMappingAttrName(),
builder.getI64IntegerAttr(executionMapping));
// Parse optional results in case there is a reduce.
if (parser.parseOptionalArrowTypeList(result.types))
return failure();
if (failed(parseRegions<LoopOp>(parser, result)))
return failure();
result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr(
{static_cast<int32_t>(privateOperands.size()),
static_cast<int32_t>(reductionOperands.size())}));
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
return failure();
return success();
}
static void print(OpAsmPrinter &printer, LoopOp &op) {
printer << LoopOp::getOperationName();
unsigned execMapping =
(op.getAttrOfType<IntegerAttr>(LoopOp::getExecutionMappingAttrName()) !=
nullptr)
? op.getAttrOfType<IntegerAttr>(LoopOp::getExecutionMappingAttrName())
.getInt()
: 0;
if ((execMapping & OpenACCExecMapping::GANG) == OpenACCExecMapping::GANG)
printer << " " << LoopOp::getGangAttrName();
if ((execMapping & OpenACCExecMapping::WORKER) == OpenACCExecMapping::WORKER)
printer << " " << LoopOp::getWorkerAttrName();
if ((execMapping & OpenACCExecMapping::VECTOR) == OpenACCExecMapping::VECTOR)
printer << " " << LoopOp::getVectorAttrName();
if ((execMapping & OpenACCExecMapping::SEQ) == OpenACCExecMapping::SEQ)
printer << " " << LoopOp::getSeqAttrName();
// private()?
printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer);
// reduction()?
printOperandList(op.reductionOperands(), LoopOp::getReductionKeyword(),
printer);
if (op.getNumResults() > 0)
printer << " -> (" << op.getResultTypes() << ")";
printer.printRegion(op.region(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
printer.printOptionalAttrDictWithKeyword(
op.getAttrs(), {LoopOp::getExecutionMappingAttrName(),
LoopOp::getOperandSegmentSizeAttr()});
}
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"

View File

@@ -0,0 +1,187 @@
// RUN: mlir-opt %s | FileCheck %s
// Verify the printed output can be parsed.
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// Verify the generic form can be parsed.
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> {
%c0 = constant 0 : index
%c10 = constant 10 : index
%c1 = constant 1 : index
acc.parallel async(%c1) {
acc.loop gang vector {
scf.for %arg3 = %c0 to %c10 step %c1 {
scf.for %arg4 = %c0 to %c10 step %c1 {
scf.for %arg5 = %c0 to %c10 step %c1 {
%a = load %A[%arg3, %arg5] : memref<10x10xf32>
%b = load %B[%arg5, %arg4] : memref<10x10xf32>
%cij = load %C[%arg3, %arg4] : memref<10x10xf32>
%p = mulf %a, %b : f32
%co = addf %cij, %p : f32
store %co, %C[%arg3, %arg4] : memref<10x10xf32>
}
}
}
acc.yield
} attributes { collapse = 3 }
acc.yield
}
return %C : memref<10x10xf32>
}
// CHECK-LABEL: func @compute1(
// CHECK-NEXT: %{{.*}} = constant 0 : index
// CHECK-NEXT: %{{.*}} = constant 10 : index
// CHECK-NEXT: %{{.*}} = constant 1 : index
// CHECK-NEXT: acc.parallel async(%{{.*}}) {
// CHECK-NEXT: acc.loop gang vector {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
// CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: acc.yield
// CHECK-NEXT: } attributes {collapse = 3 : i64}
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
// CHECK-NEXT: return %{{.*}} : memref<10x10xf32>
// CHECK-NEXT: }
func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> {
%c0 = constant 0 : index
%c10 = constant 10 : index
%c1 = constant 1 : index
acc.parallel {
acc.loop seq {
scf.for %arg3 = %c0 to %c10 step %c1 {
scf.for %arg4 = %c0 to %c10 step %c1 {
scf.for %arg5 = %c0 to %c10 step %c1 {
%a = load %A[%arg3, %arg5] : memref<10x10xf32>
%b = load %B[%arg5, %arg4] : memref<10x10xf32>
%cij = load %C[%arg3, %arg4] : memref<10x10xf32>
%p = mulf %a, %b : f32
%co = addf %cij, %p : f32
store %co, %C[%arg3, %arg4] : memref<10x10xf32>
}
}
}
acc.yield
}
acc.yield
}
return %C : memref<10x10xf32>
}
// CHECK-LABEL: func @compute2(
// CHECK-NEXT: %{{.*}} = constant 0 : index
// CHECK-NEXT: %{{.*}} = constant 10 : index
// CHECK-NEXT: %{{.*}} = constant 1 : index
// CHECK-NEXT: acc.parallel {
// CHECK-NEXT: acc.loop seq {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
// CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
// CHECK-NEXT: return %{{.*}} : memref<10x10xf32>
// CHECK-NEXT: }
func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> {
%lb = constant 0 : index
%st = constant 1 : index
%c10 = constant 10 : index
acc.data present(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) {
acc.parallel num_gangs(%c10) num_workers(%c10) private(%c : memref<10xf32>) {
acc.loop gang {
scf.for %x = %lb to %c10 step %st {
acc.loop worker {
scf.for %y = %lb to %c10 step %st {
%axy = load %a[%x, %y] : memref<10x10xf32>
%bxy = load %b[%x, %y] : memref<10x10xf32>
%tmp = addf %axy, %bxy : f32
store %tmp, %c[%y] : memref<10xf32>
}
acc.yield
}
acc.loop seq {
// for i = 0 to 10 step 1
// d[x] += c[i]
scf.for %i = %lb to %c10 step %st {
%ci = load %c[%i] : memref<10xf32>
%dx = load %d[%x] : memref<10xf32>
%z = addf %ci, %dx : f32
store %z, %d[%x] : memref<10xf32>
}
acc.yield
}
}
acc.yield
}
acc.yield
}
acc.terminator
}
return %d : memref<10xf32>
}
// CHECK: func @compute3({{.*}}: memref<10x10xf32>, {{.*}}: memref<10x10xf32>, [[ARG2:%.*]]: memref<10xf32>, {{.*}}: memref<10xf32>) -> memref<10xf32> {
// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
// CHECK-NEXT: [[C1:%.*]] = constant 1 : index
// CHECK-NEXT: [[C10:%.*]] = constant 10 : index
// CHECK-NEXT: acc.data present(%{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10xf32>, %{{.*}}: memref<10xf32>) {
// CHECK-NEXT: acc.parallel num_gangs([[C10]]) num_workers([[C10]]) private([[ARG2]]: memref<10xf32>) {
// CHECK-NEXT: acc.loop gang {
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
// CHECK-NEXT: acc.loop worker {
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
// CHECK-NEXT: acc.loop seq {
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
// CHECK-NEXT: acc.terminator
// CHECK-NEXT: }
// CHECK-NEXT: return %{{.*}} : memref<10xf32>
// CHECK-NEXT: }