[mlir][rocdl] Change the translation of GridDim*Op to __ockl_get_num_groups

Currently, `ROCDL::GridDim*Op` is translated to `__ockl_get_global_size`. To
match the meaning of `gpu.grid_dim`, it should instead be translated to
`__ockl_get_num_groups`. This change also makes it agree with the meaning
of `gridDim.*` in HIP; see:
https://github.com/ROCm-Developer-Tools/hipamd/blob/develop/include/hip/amd_detail/amd_hip_runtime.h#L257

Difference between the functions:
```
__ockl_get_global_size = blockDim * numBlocks
__ockl_get_num_groups  = numBlocks
```
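
As a rough illustration of the difference, here is a host-side sketch for a single dimension. It is not code from this patch: the helper names and the launch configuration are made up, and the functions only model the two OCKL queries using the relations stated above.

```python
# Hypothetical host-side model of the two OCKL queries for one dimension.
# block_dim  = work-items per workgroup (blockDim)
# num_blocks = workgroups in the grid (numBlocks)

def get_num_groups(block_dim: int, num_blocks: int) -> int:
    """Models __ockl_get_num_groups: the number of workgroups."""
    return num_blocks

def get_global_size(block_dim: int, num_blocks: int) -> int:
    """Models __ockl_get_global_size: the total number of work-items."""
    return block_dim * num_blocks

# Made-up launch configuration, for illustration only.
block_dim, num_blocks = 256, 4

old_grid_dim = get_global_size(block_dim, num_blocks)  # lowering before this change
new_grid_dim = get_num_groups(block_dim, num_blocks)   # lowering after this change
```

Under these assumed numbers, the old lowering would report 1024 for `rocdl.grid.dim.x`, while the new lowering reports 4, which is the workgroup count that `gpu.grid_dim` is meant to denote.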

Reviewed By: krzysz00

Differential Revision: https://reviews.llvm.org/D156009
This commit is contained in:
Fabian Mora
2023-07-22 11:32:42 +00:00
parent ff380ced05
commit 4538347fb2
2 changed files with 6 additions and 6 deletions

@@ -101,13 +101,13 @@ def ROCDL_BlockDimZOp : ROCDL_DeviceFunctionOp<"workgroup.dim.z",
                                         "__ockl_get_local_size", 2>;
 def ROCDL_GridDimXOp : ROCDL_DeviceFunctionOp<"grid.dim.x",
-                                        "__ockl_get_global_size", 0>;
+                                        "__ockl_get_num_groups", 0>;
 def ROCDL_GridDimYOp : ROCDL_DeviceFunctionOp<"grid.dim.y",
-                                        "__ockl_get_global_size", 1>;
+                                        "__ockl_get_num_groups", 1>;
 def ROCDL_GridDimZOp : ROCDL_DeviceFunctionOp<"grid.dim.z",
-                                        "__ockl_get_global_size", 2>;
+                                        "__ockl_get_num_groups", 2>;
 //===----------------------------------------------------------------------===//
 // Synchronization primitives

@@ -20,11 +20,11 @@ llvm.func @rocdl_special_regs() -> i32 {
   %8 = rocdl.workgroup.dim.y : i64
   // CHECK: call i64 @__ockl_get_local_size(i32 2)
   %9 = rocdl.workgroup.dim.z : i64
-  // CHECK: call i64 @__ockl_get_global_size(i32 0)
+  // CHECK: call i64 @__ockl_get_num_groups(i32 0)
   %10 = rocdl.grid.dim.x : i64
-  // CHECK: call i64 @__ockl_get_global_size(i32 1)
+  // CHECK: call i64 @__ockl_get_num_groups(i32 1)
   %11 = rocdl.grid.dim.y : i64
-  // CHECK: call i64 @__ockl_get_global_size(i32 2)
+  // CHECK: call i64 @__ockl_get_num_groups(i32 2)
   %12 = rocdl.grid.dim.z : i64
   // CHECK: call i32 @llvm.amdgcn.workitem.id.x(),{{.*}} !range ![[$RANGE:[0-9]+]]