mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 10:55:58 +08:00
[mlir][openacc] Introduce OpenACC dialect with parallel, data, loop operations
This patch introduces the OpenACC dialect with three operation defined parallel, data and loop operations with custom parsing and printing. OpenACC dialect RFC can be find here: https://llvm.discourse.group/t/rfc-openacc-dialect/546/2 Reviewed By: rriddle, kiranchandramohan Differential Revision: https://reviews.llvm.org/D84268
This commit is contained in:
committed by
clementval
parent
2916dd5669
commit
4225e7fa34
@@ -3,6 +3,7 @@ add_subdirectory(AVX512)
|
||||
add_subdirectory(GPU)
|
||||
add_subdirectory(Linalg)
|
||||
add_subdirectory(LLVMIR)
|
||||
add_subdirectory(OpenACC)
|
||||
add_subdirectory(OpenMP)
|
||||
add_subdirectory(Quant)
|
||||
add_subdirectory(SCF)
|
||||
|
||||
9
mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
Normal file
9
mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
set(LLVM_TARGET_DEFINITIONS OpenACCOps.td)
|
||||
mlir_tablegen(OpenACCOpsDialect.h.inc -gen-dialect-decls -dialect=acc)
|
||||
mlir_tablegen(OpenACCOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(OpenACCOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(OpenACCOpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpenACCOpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_mlir_doc(OpenACCOps -gen-dialect-doc OpenACCDialect Dialects/)
|
||||
add_public_tablegen_target(MLIROpenACCOpsIncGen)
|
||||
|
||||
44
mlir/include/mlir/Dialect/OpenACC/OpenACC.h
Normal file
44
mlir/include/mlir/Dialect/OpenACC/OpenACC.h
Normal file
@@ -0,0 +1,44 @@
|
||||
//===- OpenACC.h - MLIR OpenACC Dialect -------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
// ============================================================================
|
||||
//
|
||||
// This file declares the OpenACC dialect in MLIR.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_OPENACC_OPENACC_H_
|
||||
#define MLIR_DIALECT_OPENACC_OPENACC_H_
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace acc {
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/OpenACC/OpenACCOps.h.inc"
|
||||
|
||||
#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.h.inc"
|
||||
|
||||
/// Enumeration used to encode the execution mapping on a loop construct.
|
||||
/// They refer directly to the OpenACC 3.0 standard:
|
||||
/// 2.9.2. gang
|
||||
/// 2.9.3. worker
|
||||
/// 2.9.4. vector
|
||||
/// 2.9.5. seq
|
||||
///
|
||||
/// Value can be combined bitwise to reflect the mapping applied to the
|
||||
/// construct. e.g. `acc.loop gang vector`, the `gang` and `vector` could be
|
||||
/// combined and the final mapping value would be 5 (4 & 1).
|
||||
enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4, SEQ = 8 };
|
||||
|
||||
} // end namespace acc
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_OPENACC_OPENACC_H_
|
||||
273
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Normal file
273
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Normal file
@@ -0,0 +1,273 @@
|
||||
//===- OpenACC.td - OpenACC operation definitions ----------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines MLIR OpenACC operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef OPENACC_OPS
|
||||
#define OPENACC_OPS
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def OpenACC_Dialect : Dialect {
|
||||
let name = "acc";
|
||||
|
||||
let summary = "An OpenACC dialect for MLIR.";
|
||||
|
||||
let description = [{
|
||||
This dialect models the construct from the OpenACC 3.0 directive language.
|
||||
}];
|
||||
|
||||
let cppNamespace = "acc";
|
||||
}
|
||||
|
||||
// Base class for OpenACC dialect ops.
|
||||
class OpenACC_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<OpenACC_Dialect, mnemonic, traits> {
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
// Reduction operation enumeration
|
||||
def OpenACC_ReductionOpAdd : StrEnumAttrCase<"redop_add">;
|
||||
def OpenACC_ReductionOpMul : StrEnumAttrCase<"redop_mul">;
|
||||
def OpenACC_ReductionOpMax : StrEnumAttrCase<"redop_max">;
|
||||
def OpenACC_ReductionOpMin : StrEnumAttrCase<"redop_min">;
|
||||
def OpenACC_ReductionOpAnd : StrEnumAttrCase<"redop_and">;
|
||||
def OpenACC_ReductionOpOr : StrEnumAttrCase<"redop_or">;
|
||||
def OpenACC_ReductionOpXor : StrEnumAttrCase<"redop_xor">;
|
||||
def OpenACC_ReductionOpLogEqv : StrEnumAttrCase<"redop_leqv">;
|
||||
def OpenACC_ReductionOpLogNeqv : StrEnumAttrCase<"redop_lneqv">;
|
||||
def OpenACC_ReductionOpLogAnd : StrEnumAttrCase<"redop_land">;
|
||||
def OpenACC_ReductionOpLogOr : StrEnumAttrCase<"redop_lor">;
|
||||
|
||||
def OpenACC_ReductionOpAttr : StrEnumAttr<"ReductionOpAttr",
|
||||
"built-in reduction operations supported by OpenACC",
|
||||
[OpenACC_ReductionOpAdd, OpenACC_ReductionOpMul, OpenACC_ReductionOpMax,
|
||||
OpenACC_ReductionOpMin, OpenACC_ReductionOpAnd, OpenACC_ReductionOpOr,
|
||||
OpenACC_ReductionOpXor, OpenACC_ReductionOpLogEqv,
|
||||
OpenACC_ReductionOpLogNeqv, OpenACC_ReductionOpLogAnd,
|
||||
OpenACC_ReductionOpLogOr
|
||||
]> {
|
||||
let cppNamespace = "::mlir::acc";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// 2.5.1 parallel Construct
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def OpenACC_ParallelOp : OpenACC_Op<"parallel",
|
||||
[AttrSizedOperandSegments]> {
|
||||
let summary = "parallel construct";
|
||||
let description = [{
|
||||
The "acc.parallel" operation represents a parallel construct block. It has
|
||||
one region to be executued in parallel on the current device.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
acc.parallel num_gangs(%c10) num_workers(%c10)
|
||||
private(%c : memref<10xf32>) {
|
||||
// parallel region
|
||||
}
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Optional<Index>:$async,
|
||||
Variadic<Index>:$waitOperands,
|
||||
Optional<Index>:$numGangs,
|
||||
Optional<Index>:$numWorkers,
|
||||
Optional<Index>:$vectorLength,
|
||||
Optional<I1>:$ifCond,
|
||||
Optional<I1>:$selfCond,
|
||||
OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
|
||||
Variadic<AnyType>:$reductionOperands,
|
||||
Variadic<AnyType>:$copyOperands,
|
||||
Variadic<AnyType>:$copyinOperands,
|
||||
Variadic<AnyType>:$copyoutOperands,
|
||||
Variadic<AnyType>:$createOperands,
|
||||
Variadic<AnyType>:$noCreateOperands,
|
||||
Variadic<AnyType>:$presentOperands,
|
||||
Variadic<AnyType>:$devicePtrOperands,
|
||||
Variadic<AnyType>:$attachOperands,
|
||||
Variadic<AnyType>:$gangPrivateOperands,
|
||||
Variadic<AnyType>:$gangFirstPrivateOperands);
|
||||
|
||||
let regions = (region AnyRegion:$region);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getAsyncKeyword() { return "async"; }
|
||||
static StringRef getWaitKeyword() { return "wait"; }
|
||||
static StringRef getNumGangsKeyword() { return "num_gangs"; }
|
||||
static StringRef getNumWorkersKeyword() { return "num_workers"; }
|
||||
static StringRef getVectorLengthKeyword() { return "vector_length"; }
|
||||
static StringRef getIfKeyword() { return "if"; }
|
||||
static StringRef getSelfKeyword() { return "self"; }
|
||||
static StringRef getReductionKeyword() { return "reduction"; }
|
||||
static StringRef getCopyKeyword() { return "copy"; }
|
||||
static StringRef getCopyinKeyword() { return "copyin"; }
|
||||
static StringRef getCopyoutKeyword() { return "copyout"; }
|
||||
static StringRef getCreateKeyword() { return "create"; }
|
||||
static StringRef getNoCreateKeyword() { return "no_create"; }
|
||||
static StringRef getPresentKeyword() { return "present"; }
|
||||
static StringRef getDevicePtrKeyword() { return "deviceptr"; }
|
||||
static StringRef getAttachKeyword() { return "attach"; }
|
||||
static StringRef getPrivateKeyword() { return "private"; }
|
||||
static StringRef getFirstPrivateKeyword() { return "firstprivate"; }
|
||||
}];
|
||||
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// 2.6.5 data Construct
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def OpenACC_DataOp : OpenACC_Op<"data",
|
||||
[AttrSizedOperandSegments]> {
|
||||
let summary = "data construct";
|
||||
|
||||
let description = [{
|
||||
The "acc.data" operation represents a data construct. It defines vars to
|
||||
be allocated in the current device memory for the duration of the region,
|
||||
whether data should be copied from local memory to the current device
|
||||
memory upon region entry , and copied from device memory to local memory
|
||||
upon region exit.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
acc.data present(%a: memref<10x10xf32>, %b: memref<10x10xf32>,
|
||||
%c: memref<10xf32>, %d: memref<10xf32>) {
|
||||
// data region
|
||||
}
|
||||
```
|
||||
}];
|
||||
|
||||
|
||||
let arguments = (ins Variadic<AnyType>:$presentOperands,
|
||||
Variadic<AnyType>:$copyOperands,
|
||||
Variadic<AnyType>:$copyinOperands,
|
||||
Variadic<AnyType>:$copyoutOperands,
|
||||
Variadic<AnyType>:$createOperands,
|
||||
Variadic<AnyType>:$noCreateOperands,
|
||||
Variadic<AnyType>:$deleteOperands,
|
||||
Variadic<AnyType>:$attachOperands,
|
||||
Variadic<AnyType>:$detachOperands);
|
||||
|
||||
let regions = (region AnyRegion:$region);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getAttachKeyword() { return "attach"; }
|
||||
static StringRef getDeleteKeyword() { return "delete"; }
|
||||
static StringRef getDetachKeyword() { return "detach"; }
|
||||
static StringRef getCopyinKeyword() { return "copyin"; }
|
||||
static StringRef getCopyKeyword() { return "copy"; }
|
||||
static StringRef getCopyoutKeyword() { return "copyout"; }
|
||||
static StringRef getCreateKeyword() { return "create"; }
|
||||
static StringRef getNoCreateKeyword() { return "no_create"; }
|
||||
static StringRef getPresentKeyword() { return "present"; }
|
||||
}];
|
||||
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
def OpenACC_TerminatorOp : OpenACC_Op<"terminator", [Terminator]> {
|
||||
let summary = "Generic terminator for OpenACC regions";
|
||||
|
||||
let description = [{
|
||||
A terminator operation for regions that appear in the body of OpenACC
|
||||
operation. Generic OpenACC construct regions are not expected to return any
|
||||
value so the terminator takes no operands. The terminator op returns control
|
||||
to the enclosing op.
|
||||
}];
|
||||
|
||||
let verifier = ?;
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// 2.9 loop Construct
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def OpenACC_LoopOp : OpenACC_Op<"loop",
|
||||
[AttrSizedOperandSegments]> {
|
||||
let summary = "loop construct";
|
||||
|
||||
let description = [{
|
||||
The "acc.loop" operation represents the OpenACC loop construct.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
acc.loop gang vector {
|
||||
scf.for %arg3 = %c0 to %c10 step %c1 {
|
||||
scf.for %arg4 = %c0 to %c10 step %c1 {
|
||||
scf.for %arg5 = %c0 to %c10 step %c1 {
|
||||
// ... body
|
||||
}
|
||||
}
|
||||
}
|
||||
acc.yield
|
||||
} attributes { collapse = 3 }
|
||||
```
|
||||
}];
|
||||
|
||||
|
||||
let arguments = (ins OptionalAttr<I64Attr>:$collapse,
|
||||
Variadic<AnyType>:$privateOperands,
|
||||
OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
|
||||
Variadic<AnyType>:$reductionOperands);
|
||||
|
||||
let results = (outs Variadic<AnyType>:$results);
|
||||
|
||||
let regions = (region AnyRegion:$region);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getCollapseAttrName() { return "collapse"; }
|
||||
static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
|
||||
static StringRef getGangAttrName() { return "gang"; }
|
||||
static StringRef getSeqAttrName() { return "seq"; }
|
||||
static StringRef getVectorAttrName() { return "vector"; }
|
||||
static StringRef getWorkerAttrName() { return "worker"; }
|
||||
static StringRef getPrivateKeyword() { return "private"; }
|
||||
static StringRef getReductionKeyword() { return "reduction"; }
|
||||
}];
|
||||
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
// Yield operation for the acc.loop and acc.parallel operations.
|
||||
def OpenACC_YieldOp : OpenACC_Op<"yield", [Terminator,
|
||||
ParentOneOf<["ParallelOp, LoopOp"]>]> {
|
||||
let summary = "Acc yield and termination operation";
|
||||
|
||||
let description = [{
|
||||
`acc.yield` is a special terminator operation for block inside regions in
|
||||
acc ops (parallel and loop). It returns values to the immediately enclosing
|
||||
acc op.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>:$operands);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &result",
|
||||
[{ /* nothing to do */ }]>
|
||||
];
|
||||
|
||||
let verifier = ?;
|
||||
|
||||
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||
}
|
||||
|
||||
#endif // OPENACC_OPS
|
||||
|
||||
@@ -20,6 +20,7 @@ DEFINE_SYM_KIND_RANGE(QUANTIZATION)
|
||||
DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine
|
||||
DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
|
||||
DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect
|
||||
DEFINE_SYM_KIND_RANGE(OPENACC) // OpenACC IR Dialect
|
||||
DEFINE_SYM_KIND_RANGE(OPENMP) // OpenMP IR Dialect
|
||||
DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
|
||||
DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
@@ -38,6 +39,7 @@ namespace mlir {
|
||||
// all the possible dialects to be made available to the context automatically.
|
||||
inline void registerAllDialects() {
|
||||
static bool init_once = []() {
|
||||
registerDialect<acc::OpenACCDialect>();
|
||||
registerDialect<AffineDialect>();
|
||||
registerDialect<avx512::AVX512Dialect>();
|
||||
registerDialect<gpu::GPUDialect>();
|
||||
|
||||
@@ -3,6 +3,7 @@ add_subdirectory(AVX512)
|
||||
add_subdirectory(GPU)
|
||||
add_subdirectory(Linalg)
|
||||
add_subdirectory(LLVMIR)
|
||||
add_subdirectory(OpenACC)
|
||||
add_subdirectory(OpenMP)
|
||||
add_subdirectory(Quant)
|
||||
add_subdirectory(SCF)
|
||||
|
||||
13
mlir/lib/Dialect/OpenACC/CMakeLists.txt
Normal file
13
mlir/lib/Dialect/OpenACC/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
add_mlir_dialect_library(MLIROpenACC
|
||||
IR/OpenACC.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
|
||||
|
||||
DEPENDS
|
||||
MLIROpenACCOpsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
)
|
||||
|
||||
579
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Normal file
579
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Normal file
@@ -0,0 +1,579 @@
|
||||
//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
|
||||
//
|
||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
||||
#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace acc;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpenACC operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void OpenACCDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
template <typename StructureOp>
|
||||
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
|
||||
unsigned nRegions = 1) {
|
||||
|
||||
SmallVector<Region *, 2> regions;
|
||||
for (unsigned i = 0; i < nRegions; ++i)
|
||||
regions.push_back(state.addRegion());
|
||||
|
||||
for (Region *region : regions) {
|
||||
if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult
|
||||
parseOperandList(OpAsmParser &parser, StringRef keyword,
|
||||
SmallVectorImpl<OpAsmParser::OperandType> &args,
|
||||
SmallVectorImpl<Type> &argTypes, OperationState &result) {
|
||||
if (failed(parser.parseOptionalKeyword(keyword)))
|
||||
return success();
|
||||
|
||||
if (failed(parser.parseLParen()))
|
||||
return failure();
|
||||
|
||||
// Exit early if the list is empty.
|
||||
if (succeeded(parser.parseOptionalRParen()))
|
||||
return success();
|
||||
|
||||
do {
|
||||
OpAsmParser::OperandType arg;
|
||||
Type type;
|
||||
|
||||
if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
|
||||
return failure();
|
||||
|
||||
args.push_back(arg);
|
||||
argTypes.push_back(type);
|
||||
} while (succeeded(parser.parseOptionalComma()));
|
||||
|
||||
if (failed(parser.parseRParen()))
|
||||
return failure();
|
||||
|
||||
return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(),
|
||||
result.operands);
|
||||
}
|
||||
|
||||
static void printOperandList(Operation::operand_range operands,
|
||||
StringRef listName, OpAsmPrinter &printer) {
|
||||
|
||||
if (operands.size() > 0) {
|
||||
printer << " " << listName << "(";
|
||||
llvm::interleaveComma(operands, printer, [&](Value op) {
|
||||
printer << op << ": " << op.getType();
|
||||
});
|
||||
printer << ")";
|
||||
}
|
||||
}
|
||||
|
||||
static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
|
||||
OpAsmParser::OperandType &operand,
|
||||
Type type, bool &hasOptional,
|
||||
OperationState &result) {
|
||||
hasOptional = false;
|
||||
if (succeeded(parser.parseOptionalKeyword(keyword))) {
|
||||
hasOptional = true;
|
||||
if (parser.parseLParen() || parser.parseOperand(operand) ||
|
||||
parser.resolveOperand(operand, type, result.operands) ||
|
||||
parser.parseRParen())
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ParallelOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Parse acc.parallel operation
|
||||
/// operation := `acc.parallel` `async` `(` index `)`?
|
||||
/// `wait` `(` index-list `)`?
|
||||
/// `num_gangs` `(` value `)`?
|
||||
/// `num_workers` `(` value `)`?
|
||||
/// `vector_length` `(` value `)`?
|
||||
/// `if` `(` value `)`?
|
||||
/// `self` `(` value `)`?
|
||||
/// `reduction` `(` value-list `)`?
|
||||
/// `copy` `(` value-list `)`?
|
||||
/// `copyin` `(` value-list `)`?
|
||||
/// `copyout` `(` value-list `)`?
|
||||
/// `create` `(` value-list `)`?
|
||||
/// `no_create` `(` value-list `)`?
|
||||
/// `present` `(` value-list `)`?
|
||||
/// `deviceptr` `(` value-list `)`?
|
||||
/// `attach` `(` value-list `)`?
|
||||
/// `private` `(` value-list `)`?
|
||||
/// `firstprivate` `(` value-list `)`?
|
||||
/// region attr-dict?
|
||||
static ParseResult parseParallelOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
Builder &builder = parser.getBuilder();
|
||||
SmallVector<OpAsmParser::OperandType, 8> privateOperands,
|
||||
firstprivateOperands, createOperands, copyOperands, copyinOperands,
|
||||
copyoutOperands, noCreateOperands, presentOperands, devicePtrOperands,
|
||||
attachOperands, waitOperands, reductionOperands;
|
||||
SmallVector<Type, 8> operandTypes;
|
||||
OpAsmParser::OperandType async, numGangs, numWorkers, vectorLength, ifCond,
|
||||
selfCond;
|
||||
bool hasAsync = false, hasNumGangs = false, hasNumWorkers = false;
|
||||
bool hasVectorLength = false, hasIfCond = false, hasSelfCond = false;
|
||||
|
||||
Type indexType = builder.getIndexType();
|
||||
Type i1Type = builder.getI1Type();
|
||||
|
||||
// async()?
|
||||
if (failed(parseOptionalOperand(parser, ParallelOp::getAsyncKeyword(), async,
|
||||
indexType, hasAsync, result)))
|
||||
return failure();
|
||||
|
||||
// wait()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(),
|
||||
waitOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// num_gangs(value)?
|
||||
if (failed(parseOptionalOperand(parser, ParallelOp::getNumGangsKeyword(),
|
||||
numGangs, indexType, hasNumGangs, result)))
|
||||
return failure();
|
||||
|
||||
// num_workers(value)?
|
||||
if (failed(parseOptionalOperand(parser, ParallelOp::getNumWorkersKeyword(),
|
||||
numWorkers, indexType, hasNumWorkers,
|
||||
result)))
|
||||
return failure();
|
||||
|
||||
// vector_length(value)?
|
||||
if (failed(parseOptionalOperand(parser, ParallelOp::getVectorLengthKeyword(),
|
||||
vectorLength, indexType, hasVectorLength,
|
||||
result)))
|
||||
return failure();
|
||||
|
||||
// if()?
|
||||
if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond,
|
||||
i1Type, hasIfCond, result)))
|
||||
return failure();
|
||||
|
||||
// self()?
|
||||
if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(),
|
||||
selfCond, i1Type, hasSelfCond, result)))
|
||||
return failure();
|
||||
|
||||
// reduction()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(),
|
||||
reductionOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// copy()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(),
|
||||
copyOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// copyin()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(),
|
||||
copyinOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// copyout()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(),
|
||||
copyoutOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// create()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(),
|
||||
createOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// no_create()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(),
|
||||
noCreateOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// present()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(),
|
||||
presentOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// deviceptr()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(),
|
||||
devicePtrOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// attach()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(),
|
||||
attachOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// private()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(),
|
||||
privateOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// firstprivate()?
|
||||
if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(),
|
||||
firstprivateOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// Parallel op region
|
||||
if (failed(parseRegions<ParallelOp>(parser, result)))
|
||||
return failure();
|
||||
|
||||
result.addAttribute(ParallelOp::getOperandSegmentSizeAttr(),
|
||||
builder.getI32VectorAttr(
|
||||
{static_cast<int32_t>(hasAsync ? 1 : 0),
|
||||
static_cast<int32_t>(waitOperands.size()),
|
||||
static_cast<int32_t>(hasNumGangs ? 1 : 0),
|
||||
static_cast<int32_t>(hasNumWorkers ? 1 : 0),
|
||||
static_cast<int32_t>(hasVectorLength ? 1 : 0),
|
||||
static_cast<int32_t>(hasIfCond ? 1 : 0),
|
||||
static_cast<int32_t>(hasSelfCond ? 1 : 0),
|
||||
static_cast<int32_t>(reductionOperands.size()),
|
||||
static_cast<int32_t>(copyOperands.size()),
|
||||
static_cast<int32_t>(copyinOperands.size()),
|
||||
static_cast<int32_t>(copyoutOperands.size()),
|
||||
static_cast<int32_t>(createOperands.size()),
|
||||
static_cast<int32_t>(noCreateOperands.size()),
|
||||
static_cast<int32_t>(presentOperands.size()),
|
||||
static_cast<int32_t>(devicePtrOperands.size()),
|
||||
static_cast<int32_t>(attachOperands.size()),
|
||||
static_cast<int32_t>(privateOperands.size()),
|
||||
static_cast<int32_t>(firstprivateOperands.size())}));
|
||||
|
||||
// Additional attributes
|
||||
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &printer, ParallelOp &op) {
|
||||
printer << ParallelOp::getOperationName();
|
||||
|
||||
// async()?
|
||||
if (auto async = op.async())
|
||||
printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ")";
|
||||
|
||||
// wait()?
|
||||
printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer);
|
||||
|
||||
// num_gangs()?
|
||||
if (auto numGangs = op.numGangs())
|
||||
printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
|
||||
<< ")";
|
||||
|
||||
// num_workers()?
|
||||
if (auto numWorkers = op.numWorkers())
|
||||
printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
|
||||
<< ")";
|
||||
|
||||
// if()?
|
||||
if (Value ifCond = op.ifCond())
|
||||
printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")";
|
||||
|
||||
// self()?
|
||||
if (Value selfCond = op.selfCond())
|
||||
printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")";
|
||||
|
||||
// reduction()?
|
||||
printOperandList(op.reductionOperands(), ParallelOp::getReductionKeyword(),
|
||||
printer);
|
||||
|
||||
// copy()?
|
||||
printOperandList(op.copyOperands(), ParallelOp::getCopyKeyword(), printer);
|
||||
|
||||
// copyin()?
|
||||
printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(),
|
||||
printer);
|
||||
|
||||
// copyout()?
|
||||
printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(),
|
||||
printer);
|
||||
|
||||
// create()?
|
||||
printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(),
|
||||
printer);
|
||||
|
||||
// no_create()?
|
||||
printOperandList(op.noCreateOperands(), ParallelOp::getNoCreateKeyword(),
|
||||
printer);
|
||||
|
||||
// present()?
|
||||
printOperandList(op.presentOperands(), ParallelOp::getPresentKeyword(),
|
||||
printer);
|
||||
|
||||
// deviceptr()?
|
||||
printOperandList(op.devicePtrOperands(), ParallelOp::getDevicePtrKeyword(),
|
||||
printer);
|
||||
|
||||
// attach()?
|
||||
printOperandList(op.attachOperands(), ParallelOp::getAttachKeyword(),
|
||||
printer);
|
||||
|
||||
// private()?
|
||||
printOperandList(op.gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
|
||||
printer);
|
||||
|
||||
// firstprivate()?
|
||||
printOperandList(op.gangFirstPrivateOperands(),
|
||||
ParallelOp::getFirstPrivateKeyword(), printer);
|
||||
|
||||
printer.printRegion(op.region(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/true);
|
||||
printer.printOptionalAttrDictWithKeyword(
|
||||
op.getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DataOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Parse acc.data operation
|
||||
/// operation := `acc.parallel` `present` `(` value-list `)`?
|
||||
/// `copy` `(` value-list `)`?
|
||||
/// `copyin` `(` value-list `)`?
|
||||
/// `copyout` `(` value-list `)`?
|
||||
/// `create` `(` value-list `)`?
|
||||
/// `no_create` `(` value-list `)`?
|
||||
/// `delete` `(` value-list `)`?
|
||||
/// `attach` `(` value-list `)`?
|
||||
/// `detach` `(` value-list `)`?
|
||||
/// region attr-dict?
|
||||
static ParseResult parseDataOp(OpAsmParser &parser, OperationState &result) {
|
||||
Builder &builder = parser.getBuilder();
|
||||
SmallVector<OpAsmParser::OperandType, 8> presentOperands, copyOperands,
|
||||
copyinOperands, copyoutOperands, createOperands, noCreateOperands,
|
||||
deleteOperands, attachOperands, detachOperands;
|
||||
SmallVector<Type, 8> operandsTypes;
|
||||
|
||||
// present(value-list)?
|
||||
if (failed(parseOperandList(parser, DataOp::getPresentKeyword(),
|
||||
presentOperands, operandsTypes, result)))
|
||||
return failure();
|
||||
|
||||
// copy(value-list)?
|
||||
if (failed(parseOperandList(parser, DataOp::getCopyKeyword(), copyOperands,
|
||||
operandsTypes, result)))
|
||||
return failure();
|
||||
|
||||
// copyin(value-list)?
|
||||
if (failed(parseOperandList(parser, DataOp::getCopyinKeyword(),
|
||||
copyinOperands, operandsTypes, result)))
|
||||
return failure();
|
||||
|
||||
// copyout(value-list)?
|
||||
if (failed(parseOperandList(parser, DataOp::getCopyoutKeyword(),
|
||||
copyoutOperands, operandsTypes, result)))
|
||||
return failure();
|
||||
|
||||
// create(value-list)?
|
||||
if (failed(parseOperandList(parser, DataOp::getCreateKeyword(),
|
||||
createOperands, operandsTypes, result)))
|
||||
return failure();
|
||||
|
||||
// no_create(value-list)?
|
||||
if (failed(parseOperandList(parser, DataOp::getCreateKeyword(),
|
||||
noCreateOperands, operandsTypes, result)))
|
||||
return failure();
|
||||
|
||||
// delete(value-list)?
|
||||
if (failed(parseOperandList(parser, DataOp::getDeleteKeyword(),
|
||||
deleteOperands, operandsTypes, result)))
|
||||
return failure();
|
||||
|
||||
// attach(value-list)?
|
||||
if (failed(parseOperandList(parser, DataOp::getAttachKeyword(),
|
||||
attachOperands, operandsTypes, result)))
|
||||
return failure();
|
||||
|
||||
// detach(value-list)?
|
||||
if (failed(parseOperandList(parser, DataOp::getDetachKeyword(),
|
||||
detachOperands, operandsTypes, result)))
|
||||
return failure();
|
||||
|
||||
// Data op region
|
||||
if (failed(parseRegions<ParallelOp>(parser, result)))
|
||||
return failure();
|
||||
|
||||
result.addAttribute(
|
||||
ParallelOp::getOperandSegmentSizeAttr(),
|
||||
builder.getI32VectorAttr({static_cast<int32_t>(presentOperands.size()),
|
||||
static_cast<int32_t>(copyOperands.size()),
|
||||
static_cast<int32_t>(copyinOperands.size()),
|
||||
static_cast<int32_t>(copyoutOperands.size()),
|
||||
static_cast<int32_t>(createOperands.size()),
|
||||
static_cast<int32_t>(noCreateOperands.size()),
|
||||
static_cast<int32_t>(deleteOperands.size()),
|
||||
static_cast<int32_t>(attachOperands.size()),
|
||||
static_cast<int32_t>(detachOperands.size())}));
|
||||
|
||||
// Additional attributes
|
||||
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &printer, DataOp &op) {
|
||||
printer << DataOp::getOperationName();
|
||||
|
||||
// present(value-list)?
|
||||
printOperandList(op.presentOperands(), DataOp::getPresentKeyword(), printer);
|
||||
|
||||
// copy(value-list)?
|
||||
printOperandList(op.copyOperands(), DataOp::getCopyKeyword(), printer);
|
||||
|
||||
// copyin(value-list)?
|
||||
printOperandList(op.copyinOperands(), DataOp::getCopyinKeyword(), printer);
|
||||
|
||||
// copyout(value-list)?
|
||||
printOperandList(op.copyoutOperands(), DataOp::getCopyoutKeyword(), printer);
|
||||
|
||||
// create(value-list)?
|
||||
printOperandList(op.createOperands(), DataOp::getCreateKeyword(), printer);
|
||||
|
||||
// no_create(value-list)?
|
||||
printOperandList(op.noCreateOperands(), DataOp::getNoCreateKeyword(),
|
||||
printer);
|
||||
|
||||
// delete(value-list)?
|
||||
printOperandList(op.deleteOperands(), DataOp::getDeleteKeyword(), printer);
|
||||
|
||||
// attach(value-list)?
|
||||
printOperandList(op.attachOperands(), DataOp::getAttachKeyword(), printer);
|
||||
|
||||
// detach(value-list)?
|
||||
printOperandList(op.detachOperands(), DataOp::getDetachKeyword(), printer);
|
||||
|
||||
printer.printRegion(op.region(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/true);
|
||||
printer.printOptionalAttrDictWithKeyword(
|
||||
op.getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LoopOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Parse acc.loop operation
|
||||
/// operation := `acc.loop` `gang`? `vector`? `worker`? `seq`?
|
||||
/// `private` `(` value-list `)`?
|
||||
/// `reduction` `(` value-list `)`?
|
||||
/// region attr-dict?
|
||||
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
||||
Builder &builder = parser.getBuilder();
|
||||
unsigned executionMapping = 0;
|
||||
SmallVector<Type, 8> operandTypes;
|
||||
SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands;
|
||||
|
||||
// gang?
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangAttrName())))
|
||||
executionMapping |= OpenACCExecMapping::GANG;
|
||||
|
||||
// vector?
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorAttrName())))
|
||||
executionMapping |= OpenACCExecMapping::VECTOR;
|
||||
|
||||
// worker?
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerAttrName())))
|
||||
executionMapping |= OpenACCExecMapping::WORKER;
|
||||
|
||||
// seq?
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getSeqAttrName())))
|
||||
executionMapping |= OpenACCExecMapping::SEQ;
|
||||
|
||||
// private()?
|
||||
if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
|
||||
privateOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// reduction()?
|
||||
if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(),
|
||||
reductionOperands, operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
if (executionMapping != 0)
|
||||
result.addAttribute(LoopOp::getExecutionMappingAttrName(),
|
||||
builder.getI64IntegerAttr(executionMapping));
|
||||
|
||||
// Parse optional results in case there is a reduce.
|
||||
if (parser.parseOptionalArrowTypeList(result.types))
|
||||
return failure();
|
||||
|
||||
if (failed(parseRegions<LoopOp>(parser, result)))
|
||||
return failure();
|
||||
|
||||
result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
|
||||
builder.getI32VectorAttr(
|
||||
{static_cast<int32_t>(privateOperands.size()),
|
||||
static_cast<int32_t>(reductionOperands.size())}));
|
||||
|
||||
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &printer, LoopOp &op) {
|
||||
printer << LoopOp::getOperationName();
|
||||
|
||||
unsigned execMapping =
|
||||
(op.getAttrOfType<IntegerAttr>(LoopOp::getExecutionMappingAttrName()) !=
|
||||
nullptr)
|
||||
? op.getAttrOfType<IntegerAttr>(LoopOp::getExecutionMappingAttrName())
|
||||
.getInt()
|
||||
: 0;
|
||||
if ((execMapping & OpenACCExecMapping::GANG) == OpenACCExecMapping::GANG)
|
||||
printer << " " << LoopOp::getGangAttrName();
|
||||
|
||||
if ((execMapping & OpenACCExecMapping::WORKER) == OpenACCExecMapping::WORKER)
|
||||
printer << " " << LoopOp::getWorkerAttrName();
|
||||
|
||||
if ((execMapping & OpenACCExecMapping::VECTOR) == OpenACCExecMapping::VECTOR)
|
||||
printer << " " << LoopOp::getVectorAttrName();
|
||||
|
||||
if ((execMapping & OpenACCExecMapping::SEQ) == OpenACCExecMapping::SEQ)
|
||||
printer << " " << LoopOp::getSeqAttrName();
|
||||
|
||||
// private()?
|
||||
printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer);
|
||||
|
||||
// reduction()?
|
||||
printOperandList(op.reductionOperands(), LoopOp::getReductionKeyword(),
|
||||
printer);
|
||||
|
||||
if (op.getNumResults() > 0)
|
||||
printer << " -> (" << op.getResultTypes() << ")";
|
||||
|
||||
printer.printRegion(op.region(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/true);
|
||||
|
||||
printer.printOptionalAttrDictWithKeyword(
|
||||
op.getAttrs(), {LoopOp::getExecutionMappingAttrName(),
|
||||
LoopOp::getOperandSegmentSizeAttr()});
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
|
||||
187
mlir/test/Dialect/OpenACC/ops.mlir
Normal file
187
mlir/test/Dialect/OpenACC/ops.mlir
Normal file
@@ -0,0 +1,187 @@
|
||||
// RUN: mlir-opt %s | FileCheck %s
|
||||
// Verify the printed output can be parsed.
|
||||
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
|
||||
// Verify the generic form can be parsed.
|
||||
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
|
||||
|
||||
func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%c10 = constant 10 : index
|
||||
%c1 = constant 1 : index
|
||||
|
||||
acc.parallel async(%c1) {
|
||||
acc.loop gang vector {
|
||||
scf.for %arg3 = %c0 to %c10 step %c1 {
|
||||
scf.for %arg4 = %c0 to %c10 step %c1 {
|
||||
scf.for %arg5 = %c0 to %c10 step %c1 {
|
||||
%a = load %A[%arg3, %arg5] : memref<10x10xf32>
|
||||
%b = load %B[%arg5, %arg4] : memref<10x10xf32>
|
||||
%cij = load %C[%arg3, %arg4] : memref<10x10xf32>
|
||||
%p = mulf %a, %b : f32
|
||||
%co = addf %cij, %p : f32
|
||||
store %co, %C[%arg3, %arg4] : memref<10x10xf32>
|
||||
}
|
||||
}
|
||||
}
|
||||
acc.yield
|
||||
} attributes { collapse = 3 }
|
||||
acc.yield
|
||||
}
|
||||
|
||||
return %C : memref<10x10xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @compute1(
|
||||
// CHECK-NEXT: %{{.*}} = constant 0 : index
|
||||
// CHECK-NEXT: %{{.*}} = constant 10 : index
|
||||
// CHECK-NEXT: %{{.*}} = constant 1 : index
|
||||
// CHECK-NEXT: acc.parallel async(%{{.*}}) {
|
||||
// CHECK-NEXT: acc.loop gang vector {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
|
||||
// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
|
||||
// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: } attributes {collapse = 3 : i64}
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %{{.*}} : memref<10x10xf32>
|
||||
// CHECK-NEXT: }
|
||||
|
||||
func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%c10 = constant 10 : index
|
||||
%c1 = constant 1 : index
|
||||
|
||||
acc.parallel {
|
||||
acc.loop seq {
|
||||
scf.for %arg3 = %c0 to %c10 step %c1 {
|
||||
scf.for %arg4 = %c0 to %c10 step %c1 {
|
||||
scf.for %arg5 = %c0 to %c10 step %c1 {
|
||||
%a = load %A[%arg3, %arg5] : memref<10x10xf32>
|
||||
%b = load %B[%arg5, %arg4] : memref<10x10xf32>
|
||||
%cij = load %C[%arg3, %arg4] : memref<10x10xf32>
|
||||
%p = mulf %a, %b : f32
|
||||
%co = addf %cij, %p : f32
|
||||
store %co, %C[%arg3, %arg4] : memref<10x10xf32>
|
||||
}
|
||||
}
|
||||
}
|
||||
acc.yield
|
||||
}
|
||||
acc.yield
|
||||
}
|
||||
|
||||
return %C : memref<10x10xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @compute2(
|
||||
// CHECK-NEXT: %{{.*}} = constant 0 : index
|
||||
// CHECK-NEXT: %{{.*}} = constant 10 : index
|
||||
// CHECK-NEXT: %{{.*}} = constant 1 : index
|
||||
// CHECK-NEXT: acc.parallel {
|
||||
// CHECK-NEXT: acc.loop seq {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
|
||||
// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
|
||||
// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %{{.*}} : memref<10x10xf32>
|
||||
// CHECK-NEXT: }
|
||||
|
||||
|
||||
func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> {
|
||||
%lb = constant 0 : index
|
||||
%st = constant 1 : index
|
||||
%c10 = constant 10 : index
|
||||
|
||||
acc.data present(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) {
|
||||
acc.parallel num_gangs(%c10) num_workers(%c10) private(%c : memref<10xf32>) {
|
||||
acc.loop gang {
|
||||
scf.for %x = %lb to %c10 step %st {
|
||||
acc.loop worker {
|
||||
scf.for %y = %lb to %c10 step %st {
|
||||
%axy = load %a[%x, %y] : memref<10x10xf32>
|
||||
%bxy = load %b[%x, %y] : memref<10x10xf32>
|
||||
%tmp = addf %axy, %bxy : f32
|
||||
store %tmp, %c[%y] : memref<10xf32>
|
||||
}
|
||||
acc.yield
|
||||
}
|
||||
|
||||
acc.loop seq {
|
||||
// for i = 0 to 10 step 1
|
||||
// d[x] += c[i]
|
||||
scf.for %i = %lb to %c10 step %st {
|
||||
%ci = load %c[%i] : memref<10xf32>
|
||||
%dx = load %d[%x] : memref<10xf32>
|
||||
%z = addf %ci, %dx : f32
|
||||
store %z, %d[%x] : memref<10xf32>
|
||||
}
|
||||
acc.yield
|
||||
}
|
||||
}
|
||||
acc.yield
|
||||
}
|
||||
acc.yield
|
||||
}
|
||||
acc.terminator
|
||||
}
|
||||
|
||||
return %d : memref<10xf32>
|
||||
}
|
||||
|
||||
// CHECK: func @compute3({{.*}}: memref<10x10xf32>, {{.*}}: memref<10x10xf32>, [[ARG2:%.*]]: memref<10xf32>, {{.*}}: memref<10xf32>) -> memref<10xf32> {
|
||||
// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
|
||||
// CHECK-NEXT: [[C1:%.*]] = constant 1 : index
|
||||
// CHECK-NEXT: [[C10:%.*]] = constant 10 : index
|
||||
// CHECK-NEXT: acc.data present(%{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10xf32>, %{{.*}}: memref<10xf32>) {
|
||||
// CHECK-NEXT: acc.parallel num_gangs([[C10]]) num_workers([[C10]]) private([[ARG2]]: memref<10xf32>) {
|
||||
// CHECK-NEXT: acc.loop gang {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
|
||||
// CHECK-NEXT: acc.loop worker {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
|
||||
// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
|
||||
// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.loop seq {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
|
||||
// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
|
||||
// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.terminator
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %{{.*}} : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
Reference in New Issue
Block a user