diff --git a/mesonbuild/modules/cuda.py b/mesonbuild/modules/cuda.py index b52288ab5..690053868 100644 --- a/mesonbuild/modules/cuda.py +++ b/mesonbuild/modules/cuda.py @@ -13,14 +13,13 @@ from ..interpreter.type_checking import NoneType from . import NewExtensionModule, ModuleInfo from ..interpreterbase import ( - ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, + ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args, ) if T.TYPE_CHECKING: from typing_extensions import TypedDict from . import ModuleState - from ..compilers import Compiler class ArchFlagsKwargs(TypedDict): detected: T.Optional[T.List[str]] @@ -95,17 +94,19 @@ class CudaModule(NewExtensionModule): return driver_version + @typed_pos_args('cuda.nvcc_arch_flags', (str, CudaCompiler), varargs=str) @typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW) def nvcc_arch_flags(self, state: 'ModuleState', - args: T.Tuple[T.Union[Compiler, CudaCompiler, str]], + args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]], kwargs: ArchFlagsKwargs) -> T.List[str]: nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) ret = self._nvcc_arch_flags(*nvcc_arch_args)[0] return ret + @typed_pos_args('cuda.nvcc_arch_readable', (str, CudaCompiler), varargs=str) @typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW) def nvcc_arch_readable(self, state: 'ModuleState', - args: T.Tuple[T.Union[Compiler, CudaCompiler, str]], + args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]], kwargs: ArchFlagsKwargs) -> T.List[str]: nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) ret = self._nvcc_arch_flags(*nvcc_arch_args)[1] @@ -123,21 +124,15 @@ class CudaModule(NewExtensionModule): return [c.detected_cc] return [] - def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs): - argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!') + def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs): - if len(args) < 1: - raise argerror + compiler = args[0] + if isinstance(compiler, CudaCompiler): + cuda_version = compiler.version else: - compiler = args[0] - if isinstance(compiler, CudaCompiler): - cuda_version = compiler.version - elif isinstance(compiler, str): - cuda_version = compiler - else: - raise argerror + cuda_version = compiler - arch_list = [] if len(args) <= 1 else flatten(args[1:]) + arch_list = args[1] arch_list = [self._break_arch_string(a) for a in arch_list] arch_list = flatten(arch_list) if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}):