mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 02:38:07 +08:00
[flang] Update target rewrite to support workgroup and private attributions (#164515)
Some operations like the gpu.func have arguments that need to stay in place while rewriting the signature. This is the case for the workgroup and private attribution. Update the target rewrite pass to be aware of that when adding argument at the end of the function signature. If any trailing arguments are present, the new argument will be inserted just before them.
This commit is contained in:
committed by
GitHub
parent
23ead47655
commit
47ea8543e2
@@ -872,6 +872,14 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// Count the number of arguments that have to stay in place at the end of
|
||||
// the argument list.
|
||||
unsigned trailingArgs = 0;
|
||||
if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) {
|
||||
trailingArgs =
|
||||
func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions();
|
||||
}
|
||||
|
||||
// Convert return value(s)
|
||||
for (auto ty : funcTy.getResults())
|
||||
llvm::TypeSwitch<mlir::Type>(ty)
|
||||
@@ -981,6 +989,16 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// Add the argument at the end if the number of trailing arguments is 0,
|
||||
// otherwise insert the argument at the appropriate index.
|
||||
auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) {
|
||||
unsigned inputIndex = func.front().getArguments().size() - trailingArgs;
|
||||
auto newArg = trailingArgs == 0
|
||||
? func.front().addArgument(ty, loc)
|
||||
: func.front().insertArgument(inputIndex, ty, loc);
|
||||
return newArg;
|
||||
};
|
||||
|
||||
if (!func.empty()) {
|
||||
// If the function has a body, then apply the fixups to the arguments and
|
||||
// return ops as required. These fixups are done in place.
|
||||
@@ -1117,8 +1135,7 @@ public:
|
||||
// original arguments. (Boxchar arguments.)
|
||||
auto newBufArg =
|
||||
func.front().insertArgument(fixup.index, fixupType, loc);
|
||||
auto newLenArg =
|
||||
func.front().addArgument(trailingTys[fixup.second], loc);
|
||||
auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
|
||||
auto boxTy = oldArgTys[fixup.index - offset];
|
||||
rewriter->setInsertionPointToStart(&func.front());
|
||||
auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg,
|
||||
@@ -1133,8 +1150,7 @@ public:
|
||||
// appended after all the original arguments.
|
||||
auto newProcPointerArg =
|
||||
func.front().insertArgument(fixup.index, fixupType, loc);
|
||||
auto newLenArg =
|
||||
func.front().addArgument(trailingTys[fixup.second], loc);
|
||||
auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
|
||||
auto tupleType = oldArgTys[fixup.index - offset];
|
||||
rewriter->setInsertionPointToStart(&func.front());
|
||||
fir::FirOpBuilder builder(*rewriter, getModule());
|
||||
|
||||
@@ -55,3 +55,56 @@ func.func @main(%arg0: complex<f64>) {
|
||||
// CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
|
||||
// CHECK: gpu.return
|
||||
// CHECK: gpu.launch_func @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64) {cuf.proc_attr = #cuf.cuda_proc<global>}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {gpu.container_module, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
|
||||
gpu.module @testmod {
|
||||
gpu.func @_QMbarPfoo(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
|
||||
%c0 = arith.constant 0 : index
|
||||
memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
|
||||
gpu.return
|
||||
}
|
||||
// CHECK-LABEL: gpu.func @_QMbarPfoo(
|
||||
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WORKGROUP:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
|
||||
// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
|
||||
// CHECK: memref.store %{{.*}}, %[[WORKGROUP]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
|
||||
|
||||
gpu.func @_QMbarPfoo2(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %arg4 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
|
||||
%c0 = arith.constant 0 : index
|
||||
memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
|
||||
memref.store %arg0, %arg4[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
|
||||
gpu.return
|
||||
}
|
||||
// CHECK-LABEL: gpu.func @_QMbarPfoo2(
|
||||
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG1:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %[[WG2:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
|
||||
// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
|
||||
// CHECK: memref.store %{{.*}}, %[[WG1]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
|
||||
// CHECK: memref.store %{{.*}}, %[[WG2]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
|
||||
|
||||
gpu.func @_QMbarPprivate(%arg0: f32, %arg1: !fir.boxchar<1>) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%arg3 : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
|
||||
%c0 = arith.constant 0 : index
|
||||
memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
|
||||
memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<private>>
|
||||
gpu.return
|
||||
}
|
||||
// CHECK-LABEL: gpu.func @_QMbarPprivate(
|
||||
// CHECK-SAME: %{{.*}}: f32, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%[[PRIVATE:.*]] : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
|
||||
// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
|
||||
// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
|
||||
// CHECK: memref.store %{{.*}}, %[[PRIVATE]][%{{.*}}] : memref<1xf32, #gpu.address_space<private>>
|
||||
|
||||
gpu.func @test_with_char_proc(%arg0: f32, %arg1: tuple<() -> (), i64> {fir.char_proc}) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
|
||||
gpu.return
|
||||
}
|
||||
// CHECK-LABEL: gpu.func @test_with_char_proc(
|
||||
// CHECK-SAME: %{{.*}}: f32, %[[CHARPROC:.*]]: () -> () {fir.char_proc}, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>>) {
|
||||
// CHECK: %{{.*}} = fir.undefined tuple<() -> (), i64>
|
||||
// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[CHARPROC]], [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64>
|
||||
// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[LENGTH]], [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64>
|
||||
// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ int main(int argc, char **argv) {
|
||||
#endif
|
||||
DialectRegistry registry;
|
||||
fir::support::registerDialects(registry);
|
||||
registry.insert<mlir::memref::MemRefDialect>();
|
||||
fir::support::addFIRExtensions(registry);
|
||||
return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n",
|
||||
registry));
|
||||
|
||||
Reference in New Issue
Block a user