mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[flang][cuda] Accept scalar expression for bytes in kernel call (#165040)
This commit is contained in:
committed by
GitHub
parent
8c29bce1e9
commit
825eefe856
@@ -3274,13 +3274,13 @@ struct FunctionReference {
|
||||
// R1521 call-stmt -> CALL procedure-designator [ chevrons ]
|
||||
// [( [actual-arg-spec-list] )]
|
||||
// (CUDA) chevrons -> <<< * | scalar-expr, scalar-expr [,
|
||||
// scalar-int-expr [, scalar-int-expr ] ] >>>
|
||||
// scalar-expr [, scalar-int-expr ] ] >>>
|
||||
struct CallStmt {
|
||||
BOILERPLATE(CallStmt);
|
||||
WRAPPER_CLASS(StarOrExpr, std::optional<ScalarExpr>);
|
||||
struct Chevrons {
|
||||
TUPLE_CLASS_BOILERPLATE(Chevrons);
|
||||
std::tuple<StarOrExpr, ScalarExpr, std::optional<ScalarIntExpr>,
|
||||
std::tuple<StarOrExpr, ScalarExpr, std::optional<ScalarExpr>,
|
||||
std::optional<ScalarIntExpr>>
|
||||
t;
|
||||
};
|
||||
|
||||
@@ -484,7 +484,7 @@ constexpr auto starOrExpr{
|
||||
applyFunction(presentOptional<ScalarExpr>, scalarExpr))};
|
||||
TYPE_PARSER(extension<LanguageFeature::CUDA>(
|
||||
"<<<" >> construct<CallStmt::Chevrons>(starOrExpr, ", " >> scalarExpr,
|
||||
maybe("," >> scalarIntExpr), maybe("," >> scalarIntExpr)) /
|
||||
maybe("," >> scalarExpr), maybe("," >> scalarIntExpr)) /
|
||||
">>>"))
|
||||
constexpr auto actualArgSpecList{optionalList(actualArgSpec)};
|
||||
TYPE_CONTEXT_PARSER("CALL statement"_en_US,
|
||||
|
||||
@@ -16,6 +16,7 @@ contains
|
||||
subroutine host()
|
||||
real, device :: a
|
||||
integer(8) :: stream
|
||||
integer(4) :: nbytes
|
||||
|
||||
! CHECK-LABEL: func.func @_QMtest_callPhost()
|
||||
! CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QMtest_callFhostEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
|
||||
@@ -57,6 +58,10 @@ contains
|
||||
call dev_kernel1<<<*,32,0,stream>>>(a)
|
||||
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}, %c0{{.*}}, %{{.*}} : !fir.ref<i64>>>>(%{{.*}}) : (!fir.ref<f32>)
|
||||
|
||||
call dev_kernel1<<<*, 32, 0.8 * nbytes>>>(a)
|
||||
! CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} fastmath<contract> : f32
|
||||
! CHECK: %[[BYTES:.*]] = fir.convert %[[MUL]] : (f32) -> i32
|
||||
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[BYTES]]>>>(%{{.*}}) : (!fir.ref<f32>)
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
@@ -43,6 +43,7 @@ module m
|
||||
call globalsub<<<1, 2>>>
|
||||
call globalsub<<<1, 2, 3>>>
|
||||
call globalsub<<<1, 2, 3, 4>>>
|
||||
call globalsub<<<1, 2, 0.9*10, 4>>>
|
||||
call globalsub<<<*,5>>>
|
||||
allocate(pa(32), pinned = isPinned)
|
||||
end subroutine
|
||||
|
||||
@@ -178,7 +178,7 @@ include "cuf-sanity-common"
|
||||
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
|
||||
!CHECK: | | | | | | Scalar -> Expr = '2_4'
|
||||
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
|
||||
!CHECK: | | | | | | Scalar -> Integer -> Expr = '3_4'
|
||||
!CHECK: | | | | | | Scalar -> Expr = '3_4'
|
||||
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '3'
|
||||
!CHECK: | | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> CallStmt = 'CALL globalsub<<<1_4,2_4,3_4,4_4>>>()'
|
||||
!CHECK: | | | | | Call
|
||||
@@ -188,10 +188,27 @@ include "cuf-sanity-common"
|
||||
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
|
||||
!CHECK: | | | | | | Scalar -> Expr = '2_4'
|
||||
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
|
||||
!CHECK: | | | | | | Scalar -> Integer -> Expr = '3_4'
|
||||
!CHECK: | | | | | | Scalar -> Expr = '3_4'
|
||||
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '3'
|
||||
!CHECK: | | | | | | Scalar -> Integer -> Expr = '4_4'
|
||||
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '4'
|
||||
!CHECK: | | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> CallStmt = 'CALL globalsub<<<1_4,2_4,9._4,4_4>>>()'
|
||||
!CHECK: | | | | | Call
|
||||
!CHECK: | | | | | | ProcedureDesignator -> Name = 'globalsub'
|
||||
!CHECK: | | | | | Chevrons
|
||||
!CHECK: | | | | | | StarOrExpr -> Scalar -> Expr = '1_4'
|
||||
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
|
||||
!CHECK: | | | | | | Scalar -> Expr = '2_4'
|
||||
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
|
||||
!CHECK: | | | | | | Scalar -> Expr = '9._4'
|
||||
!CHECK: | | | | | | | Multiply
|
||||
!CHECK: | | | | | | | | Expr = '8.9999997615814208984375e-1_4'
|
||||
!CHECK: | | | | | | | | | LiteralConstant -> RealLiteralConstant
|
||||
!CHECK: | | | | | | | | | | Real = '0.9'
|
||||
!CHECK: | | | | | | | | Expr = '10_4'
|
||||
!CHECK: | | | | | | | | | LiteralConstant -> IntLiteralConstant = '10'
|
||||
!CHECK: | | | | | | Scalar -> Integer -> Expr = '4_4'
|
||||
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '4'
|
||||
!CHECK: | | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AllocateStmt
|
||||
!CHECK: | | | | | Allocation
|
||||
!CHECK: | | | | | | AllocateObject = 'pa'
|
||||
|
||||
Reference in New Issue
Block a user