diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index 8555073a2d0d..8c6d3b37166a 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -1243,22 +1243,30 @@ bool CheckForCoindexedObject(parser::ContextualMessages &, const std::optional &, const std::string &procName, const std::string &argName); -// Get the number of distinct symbols with CUDA attribute in the expression. +inline bool IsCUDADeviceSymbol(const Symbol &sym) { + if (const auto *details = + sym.GetUltimate().detailsIf()) { + if (details->cudaDataAttr() && + *details->cudaDataAttr() != common::CUDADataAttr::Pinned) { + return true; + } + } + return false; +} + +// Get the number of distinct symbols with CUDA device +// attribute in the expression. template inline int GetNbOfCUDADeviceSymbols(const A &expr) { semantics::UnorderedSymbolSet symbols; for (const Symbol &sym : CollectCudaSymbols(expr)) { - if (const auto *details = - sym.GetUltimate().detailsIf()) { - if (details->cudaDataAttr() && - *details->cudaDataAttr() != common::CUDADataAttr::Pinned) { - symbols.insert(sym); - } + if (IsCUDADeviceSymbol(sym)) { + symbols.insert(sym); } } return symbols.size(); } -// Check if any of the symbols part of the expression has a CUDA data +// Check if any of the symbols part of the expression has a CUDA device // attribute. template inline bool HasCUDADeviceAttrs(const A &expr) { return GetNbOfCUDADeviceSymbols(expr) > 0; @@ -1270,26 +1278,15 @@ inline bool HasCUDAImplicitTransfer(const Expr &expr) { unsigned hostSymbols{0}; unsigned deviceSymbols{0}; for (const Symbol &sym : CollectCudaSymbols(expr)) { - if (const auto *details = - sym.GetUltimate().detailsIf()) { - if (details->cudaDataAttr() && - *details->cudaDataAttr() != common::CUDADataAttr::Pinned) { - ++deviceSymbols; - } else { - if (sym.owner().IsDerivedType()) { - if (const auto *details = - sym.owner() - .GetSymbol() - ->GetUltimate() - .detailsIf()) { - if (details->cudaDataAttr() && - *details->cudaDataAttr() != common::CUDADataAttr::Pinned) { - ++deviceSymbols; - } - } + if (IsCUDADeviceSymbol(sym)) { + ++deviceSymbols; + } else { + if (sym.owner().IsDerivedType()) { + if (IsCUDADeviceSymbol(sym.owner().GetSymbol()->GetUltimate())) { + ++deviceSymbols; } - ++hostSymbols; } + ++hostSymbols; } } return hostSymbols > 0 && deviceSymbols > 0;