cuda module: use typed_pos_args for most methods
The min_driver_version function has an extensive, informative custom error message, so leave that in place. The other two functions didn't have much information there, and it's fairly evident that the cuda compiler itself is the best thing to have here. Moreover, there was some fairly gnarly code to validate the allowed values, which we can greatly simplify by uplifting the typechecking parts to the dedicated decorators that are both really good at it, and have nicely formatted error messages complete with reference to the problematic functions.
This commit is contained in:
parent
1b15176168
commit
5899daf25b
|
@ -13,14 +13,13 @@ from ..interpreter.type_checking import NoneType
|
|||
from . import NewExtensionModule, ModuleInfo
|
||||
|
||||
from ..interpreterbase import (
|
||||
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs,
|
||||
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args,
|
||||
)
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from . import ModuleState
|
||||
from ..compilers import Compiler
|
||||
|
||||
class ArchFlagsKwargs(TypedDict):
|
||||
detected: T.Optional[T.List[str]]
|
||||
|
@ -95,17 +94,19 @@ class CudaModule(NewExtensionModule):
|
|||
|
||||
return driver_version
|
||||
|
||||
@typed_pos_args('cuda.nvcc_arch_flags', (str, CudaCompiler), varargs=str)
|
||||
@typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW)
|
||||
def nvcc_arch_flags(self, state: 'ModuleState',
|
||||
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
|
||||
args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]],
|
||||
kwargs: ArchFlagsKwargs) -> T.List[str]:
|
||||
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
|
||||
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
|
||||
return ret
|
||||
|
||||
@typed_pos_args('cuda.nvcc_arch_readable', (str, CudaCompiler), varargs=str)
|
||||
@typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW)
|
||||
def nvcc_arch_readable(self, state: 'ModuleState',
|
||||
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
|
||||
args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]],
|
||||
kwargs: ArchFlagsKwargs) -> T.List[str]:
|
||||
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
|
||||
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
|
||||
|
@ -123,21 +124,15 @@ class CudaModule(NewExtensionModule):
|
|||
return [c.detected_cc]
|
||||
return []
|
||||
|
||||
def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
|
||||
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
|
||||
def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs):
|
||||
|
||||
if len(args) < 1:
|
||||
raise argerror
|
||||
compiler = args[0]
|
||||
if isinstance(compiler, CudaCompiler):
|
||||
cuda_version = compiler.version
|
||||
else:
|
||||
compiler = args[0]
|
||||
if isinstance(compiler, CudaCompiler):
|
||||
cuda_version = compiler.version
|
||||
elif isinstance(compiler, str):
|
||||
cuda_version = compiler
|
||||
else:
|
||||
raise argerror
|
||||
cuda_version = compiler
|
||||
|
||||
arch_list = [] if len(args) <= 1 else flatten(args[1:])
|
||||
arch_list = args[1]
|
||||
arch_list = [self._break_arch_string(a) for a in arch_list]
|
||||
arch_list = flatten(arch_list)
|
||||
if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}):
|
||||
|
|
Loading…
Reference in New Issue