cuda module: use typed_kwargs
This officially only ever accepted string or array of strings.
This commit is contained in:
parent
6f7e745052
commit
cf35d9b4ce
|
@ -8,18 +8,26 @@ import re
|
||||||
|
|
||||||
from ..mesonlib import version_compare
|
from ..mesonlib import version_compare
|
||||||
from ..compilers.cuda import CudaCompiler
|
from ..compilers.cuda import CudaCompiler
|
||||||
|
from ..interpreter.type_checking import NoneType
|
||||||
|
|
||||||
from . import NewExtensionModule, ModuleInfo
|
from . import NewExtensionModule, ModuleInfo
|
||||||
|
|
||||||
from ..interpreterbase import (
|
from ..interpreterbase import (
|
||||||
flatten, permittedKwargs, noKwargs,
|
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs,
|
||||||
InvalidArguments
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if T.TYPE_CHECKING:
|
if T.TYPE_CHECKING:
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from . import ModuleState
|
from . import ModuleState
|
||||||
from ..compilers import Compiler
|
from ..compilers import Compiler
|
||||||
|
|
||||||
|
class ArchFlagsKwargs(TypedDict):
|
||||||
|
detected: T.Optional[T.List[str]]
|
||||||
|
|
||||||
|
|
||||||
|
DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True)
|
||||||
|
|
||||||
class CudaModule(NewExtensionModule):
|
class CudaModule(NewExtensionModule):
|
||||||
|
|
||||||
INFO = ModuleInfo('CUDA', '0.50.0', unstable=True)
|
INFO = ModuleInfo('CUDA', '0.50.0', unstable=True)
|
||||||
|
@ -87,18 +95,18 @@ class CudaModule(NewExtensionModule):
|
||||||
|
|
||||||
return driver_version
|
return driver_version
|
||||||
|
|
||||||
@permittedKwargs(['detected'])
|
@typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW)
|
||||||
def nvcc_arch_flags(self, state: 'ModuleState',
|
def nvcc_arch_flags(self, state: 'ModuleState',
|
||||||
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
|
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
|
||||||
kwargs: T.Dict[str, T.Any]) -> T.List[str]:
|
kwargs: ArchFlagsKwargs) -> T.List[str]:
|
||||||
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
|
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
|
||||||
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
|
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@permittedKwargs(['detected'])
|
@typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW)
|
||||||
def nvcc_arch_readable(self, state: 'ModuleState',
|
def nvcc_arch_readable(self, state: 'ModuleState',
|
||||||
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
|
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
|
||||||
kwargs: T.Dict[str, T.Any]) -> T.List[str]:
|
kwargs: ArchFlagsKwargs) -> T.List[str]:
|
||||||
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
|
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
|
||||||
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
|
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
|
||||||
return ret
|
return ret
|
||||||
|
@ -110,10 +118,10 @@ class CudaModule(NewExtensionModule):
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _detected_cc_from_compiler(c):
|
def _detected_cc_from_compiler(c) -> T.List[str]:
|
||||||
if isinstance(c, CudaCompiler):
|
if isinstance(c, CudaCompiler):
|
||||||
return c.detected_cc
|
return [c.detected_cc]
|
||||||
return ''
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _version_from_compiler(c):
|
def _version_from_compiler(c):
|
||||||
|
@ -123,7 +131,7 @@ class CudaModule(NewExtensionModule):
|
||||||
return c
|
return c
|
||||||
return 'unknown'
|
return 'unknown'
|
||||||
|
|
||||||
def _validate_nvcc_arch_args(self, args, kwargs):
|
def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
|
||||||
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
|
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
|
||||||
|
|
||||||
if len(args) < 1:
|
if len(args) < 1:
|
||||||
|
@ -141,8 +149,7 @@ class CudaModule(NewExtensionModule):
|
||||||
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
|
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
|
||||||
arch_list = arch_list[0] if len(arch_list) == 1 else arch_list
|
arch_list = arch_list[0] if len(arch_list) == 1 else arch_list
|
||||||
|
|
||||||
detected = kwargs.get('detected', self._detected_cc_from_compiler(compiler))
|
detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler)
|
||||||
detected = flatten([detected])
|
|
||||||
detected = [self._break_arch_string(a) for a in detected]
|
detected = [self._break_arch_string(a) for a in detected]
|
||||||
detected = flatten(detected)
|
detected = flatten(detected)
|
||||||
if not set(detected).isdisjoint({'All', 'Common', 'Auto'}):
|
if not set(detected).isdisjoint({'All', 'Common', 'Auto'}):
|
||||||
|
|
Loading…
Reference in New Issue