mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[flang][cuda][NFC] Extract is cuda device attribute logic (#100809)
This commit is contained in:
committed by
GitHub
parent
dbb8b7a0f4
commit
0ff92593d2
@@ -1243,22 +1243,30 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
|
||||
const std::optional<ActualArgument> &, const std::string &procName,
|
||||
const std::string &argName);
|
||||
|
||||
// Get the number of distinct symbols with CUDA attribute in the expression.
|
||||
inline bool IsCUDADeviceSymbol(const Symbol &sym) {
|
||||
if (const auto *details =
|
||||
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
|
||||
if (details->cudaDataAttr() &&
|
||||
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get the number of distinct symbols with CUDA device
|
||||
// attribute in the expression.
|
||||
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
|
||||
semantics::UnorderedSymbolSet symbols;
|
||||
for (const Symbol &sym : CollectCudaSymbols(expr)) {
|
||||
if (const auto *details =
|
||||
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
|
||||
if (details->cudaDataAttr() &&
|
||||
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
|
||||
symbols.insert(sym);
|
||||
}
|
||||
if (IsCUDADeviceSymbol(sym)) {
|
||||
symbols.insert(sym);
|
||||
}
|
||||
}
|
||||
return symbols.size();
|
||||
}
|
||||
|
||||
// Check if any of the symbols part of the expression has a CUDA data
|
||||
// Check if any of the symbols part of the expression has a CUDA device
|
||||
// attribute.
|
||||
template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
|
||||
return GetNbOfCUDADeviceSymbols(expr) > 0;
|
||||
@@ -1270,26 +1278,15 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
|
||||
unsigned hostSymbols{0};
|
||||
unsigned deviceSymbols{0};
|
||||
for (const Symbol &sym : CollectCudaSymbols(expr)) {
|
||||
if (const auto *details =
|
||||
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
|
||||
if (details->cudaDataAttr() &&
|
||||
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
|
||||
++deviceSymbols;
|
||||
} else {
|
||||
if (sym.owner().IsDerivedType()) {
|
||||
if (const auto *details =
|
||||
sym.owner()
|
||||
.GetSymbol()
|
||||
->GetUltimate()
|
||||
.detailsIf<semantics::ObjectEntityDetails>()) {
|
||||
if (details->cudaDataAttr() &&
|
||||
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
|
||||
++deviceSymbols;
|
||||
}
|
||||
}
|
||||
if (IsCUDADeviceSymbol(sym)) {
|
||||
++deviceSymbols;
|
||||
} else {
|
||||
if (sym.owner().IsDerivedType()) {
|
||||
if (IsCUDADeviceSymbol(sym.owner().GetSymbol()->GetUltimate())) {
|
||||
++deviceSymbols;
|
||||
}
|
||||
++hostSymbols;
|
||||
}
|
||||
++hostSymbols;
|
||||
}
|
||||
}
|
||||
return hostSymbols > 0 && deviceSymbols > 0;
|
||||
|
||||
Reference in New Issue
Block a user