[flang][cuda] Update attribute compatibility check for unified matching rule (#90679)

This patch updates the compatibility checks for CUDA attributes in
preparation for implementing the matching rules described in section 3.2.3.
With this patch the compiler will still emit an error when there are
multiple specific procedures that match, since the matching distance
is not yet implemented. This will be done in a separate patch.


https://docs.nvidia.com/hpc-sdk/archive/24.3/compilers/cuda-fortran-prog-guide/index.html#cfref-var-attr-unified-data

gpu=unified and gpu=managed are not part of this patch since these
options are not recognized by flang yet.
This commit is contained in:
Valentin Clement (バレンタイン クレメン)
2024-04-30 19:58:46 -07:00
committed by GitHub
parent ef1dbcd60f
commit 86e5d6f1d8
5 changed files with 43 additions and 9 deletions

View File

@@ -114,8 +114,8 @@ static constexpr IgnoreTKRSet ignoreTKRAll{IgnoreTKR::Type, IgnoreTKR::Kind,
IgnoreTKR::Rank, IgnoreTKR::Device, IgnoreTKR::Managed};
std::string AsFortran(IgnoreTKRSet);
bool AreCompatibleCUDADataAttrs(
std::optional<CUDADataAttr>, std::optional<CUDADataAttr>, IgnoreTKRSet);
bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr>,
std::optional<CUDADataAttr>, IgnoreTKRSet, bool allowUnifiedMatchingRule);
static constexpr char blankCommonObjectName[] = "__BLNK__";

View File

@@ -97,8 +97,12 @@ std::string AsFortran(IgnoreTKRSet tkr) {
return result;
}
/// Check compatibility of CUDA attribute.
/// When `allowUnifiedMatchingRule` is enabled, argument `x` represents the
/// dummy argument attribute while `y` represents the actual argument attribute.
bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR) {
std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR,
bool allowUnifiedMatchingRule) {
if (!x && !y) {
return true;
} else if (x && y && *x == *y) {
@@ -114,6 +118,24 @@ bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
x.value_or(CUDADataAttr::Managed) == CUDADataAttr::Managed &&
y.value_or(CUDADataAttr::Managed) == CUDADataAttr::Managed) {
return true;
} else if (allowUnifiedMatchingRule) {
if (!x) { // Dummy argument has no attribute -> host
  // A host dummy accepts a managed or unified actual argument.
  // The alternatives are parenthesized so that the `y` engagement check
  // guards both comparisons; without the parentheses `&&` binds tighter
  // than `||` and `*y == Unified` would be evaluated even when `y` is empty.
  if (y && (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) {
    return true;
  }
} else {
  // Dummy argument carries an explicit attribute; every case below also
  // requires the actual argument to carry one.
  if (*x == CUDADataAttr::Device && y &&
      (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) {
    // Device dummy accepts a managed or unified actual.
    return true;
  } else if (*x == CUDADataAttr::Managed && y &&
      *y == CUDADataAttr::Unified) {
    // Managed dummy accepts a unified actual.
    return true;
  } else if (*x == CUDADataAttr::Unified && y &&
      *y == CUDADataAttr::Managed) {
    // Unified dummy accepts a managed actual.
    return true;
  }
}
return false;
} else {
return false;
}

View File

@@ -362,8 +362,9 @@ bool DummyDataObject::IsCompatibleWith(const DummyDataObject &actual,
}
}
if (!attrs.test(Attr::Value) &&
!common::AreCompatibleCUDADataAttrs(
cudaDataAttr, actual.cudaDataAttr, ignoreTKR)) {
!common::AreCompatibleCUDADataAttrs(cudaDataAttr, actual.cudaDataAttr,
ignoreTKR,
/*allowUnifiedMatchingRule=*/false)) {
if (whyNot) {
*whyNot = "incompatible CUDA data attributes";
}
@@ -1754,8 +1755,9 @@ bool DistinguishUtils::Distinguishable(
} else if (y.attrs.test(Attr::Allocatable) && x.attrs.test(Attr::Pointer) &&
x.intent != common::Intent::In) {
return true;
} else if (!common::AreCompatibleCUDADataAttrs(
x.cudaDataAttr, y.cudaDataAttr, x.ignoreTKR | y.ignoreTKR)) {
} else if (!common::AreCompatibleCUDADataAttrs(x.cudaDataAttr, y.cudaDataAttr,
x.ignoreTKR | y.ignoreTKR,
/*allowUnifiedMatchingRule=*/false)) {
return true;
} else if (features_.IsEnabled(
common::LanguageFeature::DistinguishableSpecifics) &&

View File

@@ -897,8 +897,9 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
actualDataAttr = common::CUDADataAttr::Device;
}
}
if (!common::AreCompatibleCUDADataAttrs(
dummyDataAttr, actualDataAttr, dummy.ignoreTKR)) {
if (!common::AreCompatibleCUDADataAttrs(dummyDataAttr, actualDataAttr,
dummy.ignoreTKR,
/*allowUnifiedMatchingRule=*/true)) {
auto toStr{[](std::optional<common::CUDADataAttr> x) {
return x ? "ATTRIBUTES("s +
parser::ToUpperCaseLetters(common::EnumToString(*x)) + ")"s

View File

@@ -6,6 +6,10 @@ module matching
module procedure sub_device
end interface
interface subman
module procedure sub_host
end interface
contains
subroutine sub_host(a)
integer :: a(:)
@@ -21,8 +25,13 @@ program m
use matching
integer, pinned, allocatable :: a(:)
integer, managed, allocatable :: b(:)
logical :: plog
allocate(a(100), pinned = plog)
allocate(b(200))
call sub(a)
call subman(b)
end