cuda module: use typed_kwargs

This officially only ever accepted string or array of strings.
This commit is contained in:
Eli Schwartz 2023-10-15 19:01:34 -04:00
parent 6f7e745052
commit cf35d9b4ce
No known key found for this signature in database
GPG Key ID: CEB167EFB5722BD6
1 changed files with 19 additions and 12 deletions

View File

@ -8,18 +8,26 @@ import re
from ..mesonlib import version_compare from ..mesonlib import version_compare
from ..compilers.cuda import CudaCompiler from ..compilers.cuda import CudaCompiler
from ..interpreter.type_checking import NoneType
from . import NewExtensionModule, ModuleInfo from . import NewExtensionModule, ModuleInfo
from ..interpreterbase import ( from ..interpreterbase import (
flatten, permittedKwargs, noKwargs, ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs,
InvalidArguments
) )
if T.TYPE_CHECKING: if T.TYPE_CHECKING:
from typing_extensions import TypedDict
from . import ModuleState from . import ModuleState
from ..compilers import Compiler from ..compilers import Compiler
class ArchFlagsKwargs(TypedDict):
detected: T.Optional[T.List[str]]
DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True)
class CudaModule(NewExtensionModule): class CudaModule(NewExtensionModule):
INFO = ModuleInfo('CUDA', '0.50.0', unstable=True) INFO = ModuleInfo('CUDA', '0.50.0', unstable=True)
@ -87,18 +95,18 @@ class CudaModule(NewExtensionModule):
return driver_version return driver_version
@permittedKwargs(['detected']) @typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW)
def nvcc_arch_flags(self, state: 'ModuleState', def nvcc_arch_flags(self, state: 'ModuleState',
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]], args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
kwargs: T.Dict[str, T.Any]) -> T.List[str]: kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0] ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
return ret return ret
@permittedKwargs(['detected']) @typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW)
def nvcc_arch_readable(self, state: 'ModuleState', def nvcc_arch_readable(self, state: 'ModuleState',
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]], args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
kwargs: T.Dict[str, T.Any]) -> T.List[str]: kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs) nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1] ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
return ret return ret
@ -110,10 +118,10 @@ class CudaModule(NewExtensionModule):
return s return s
@staticmethod @staticmethod
def _detected_cc_from_compiler(c): def _detected_cc_from_compiler(c) -> T.List[str]:
if isinstance(c, CudaCompiler): if isinstance(c, CudaCompiler):
return c.detected_cc return [c.detected_cc]
return '' return []
@staticmethod @staticmethod
def _version_from_compiler(c): def _version_from_compiler(c):
@ -123,7 +131,7 @@ class CudaModule(NewExtensionModule):
return c return c
return 'unknown' return 'unknown'
def _validate_nvcc_arch_args(self, args, kwargs): def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!') argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
if len(args) < 1: if len(args) < 1:
@ -141,8 +149,7 @@ class CudaModule(NewExtensionModule):
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''') raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
arch_list = arch_list[0] if len(arch_list) == 1 else arch_list arch_list = arch_list[0] if len(arch_list) == 1 else arch_list
detected = kwargs.get('detected', self._detected_cc_from_compiler(compiler)) detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler)
detected = flatten([detected])
detected = [self._break_arch_string(a) for a in detected] detected = [self._break_arch_string(a) for a in detected]
detected = flatten(detected) detected = flatten(detected)
if not set(detected).isdisjoint({'All', 'Common', 'Auto'}): if not set(detected).isdisjoint({'All', 'Common', 'Auto'}):