cuda module: use typed_kwargs
This officially only ever accepted string or array of strings.
This commit is contained in:
parent
6f7e745052
commit
cf35d9b4ce
|
@ -8,18 +8,26 @@ import re
|
|||
|
||||
from ..mesonlib import version_compare
|
||||
from ..compilers.cuda import CudaCompiler
|
||||
from ..interpreter.type_checking import NoneType
|
||||
|
||||
from . import NewExtensionModule, ModuleInfo
|
||||
|
||||
from ..interpreterbase import (
|
||||
flatten, permittedKwargs, noKwargs,
|
||||
InvalidArguments
|
||||
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs,
|
||||
)
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from . import ModuleState
|
||||
from ..compilers import Compiler
|
||||
|
||||
class ArchFlagsKwargs(TypedDict):
|
||||
detected: T.Optional[T.List[str]]
|
||||
|
||||
|
||||
DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True)
|
||||
|
||||
class CudaModule(NewExtensionModule):
|
||||
|
||||
INFO = ModuleInfo('CUDA', '0.50.0', unstable=True)
|
||||
|
@ -87,18 +95,18 @@ class CudaModule(NewExtensionModule):
|
|||
|
||||
return driver_version
|
||||
|
||||
@permittedKwargs(['detected'])
|
||||
@typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW)
|
||||
def nvcc_arch_flags(self, state: 'ModuleState',
|
||||
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
|
||||
kwargs: T.Dict[str, T.Any]) -> T.List[str]:
|
||||
kwargs: ArchFlagsKwargs) -> T.List[str]:
|
||||
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
|
||||
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
|
||||
return ret
|
||||
|
||||
@permittedKwargs(['detected'])
|
||||
@typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW)
|
||||
def nvcc_arch_readable(self, state: 'ModuleState',
|
||||
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
|
||||
kwargs: T.Dict[str, T.Any]) -> T.List[str]:
|
||||
kwargs: ArchFlagsKwargs) -> T.List[str]:
|
||||
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
|
||||
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
|
||||
return ret
|
||||
|
@ -110,10 +118,10 @@ class CudaModule(NewExtensionModule):
|
|||
return s
|
||||
|
||||
@staticmethod
|
||||
def _detected_cc_from_compiler(c):
|
||||
def _detected_cc_from_compiler(c) -> T.List[str]:
|
||||
if isinstance(c, CudaCompiler):
|
||||
return c.detected_cc
|
||||
return ''
|
||||
return [c.detected_cc]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _version_from_compiler(c):
|
||||
|
@ -123,7 +131,7 @@ class CudaModule(NewExtensionModule):
|
|||
return c
|
||||
return 'unknown'
|
||||
|
||||
def _validate_nvcc_arch_args(self, args, kwargs):
|
||||
def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
|
||||
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
|
||||
|
||||
if len(args) < 1:
|
||||
|
@ -141,8 +149,7 @@ class CudaModule(NewExtensionModule):
|
|||
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
|
||||
arch_list = arch_list[0] if len(arch_list) == 1 else arch_list
|
||||
|
||||
detected = kwargs.get('detected', self._detected_cc_from_compiler(compiler))
|
||||
detected = flatten([detected])
|
||||
detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler)
|
||||
detected = [self._break_arch_string(a) for a in detected]
|
||||
detected = flatten(detected)
|
||||
if not set(detected).isdisjoint({'All', 'Common', 'Auto'}):
|
||||
|
|
Loading…
Reference in New Issue