cuda module: inline single-shot function to get compiler version

It is pretty trivial and more confusing when standalone, especially the
use of a sentinel "unknown" string as a standin for "this isn't one of
the allowed object types". Much easier to directly raise an error in the
fallthrough/else.
This commit is contained in:
Eli Schwartz 2024-01-03 23:29:19 -05:00
parent 8ff25c0bca
commit 1b15176168
No known key found for this signature in database
GPG Key ID: CEB167EFB5722BD6
1 changed files with 5 additions and 10 deletions

View File

@ -123,14 +123,6 @@ class CudaModule(NewExtensionModule):
return [c.detected_cc] return [c.detected_cc]
return [] return []
@staticmethod
def _version_from_compiler(c):
if isinstance(c, CudaCompiler):
return c.version
if isinstance(c, str):
return c
return 'unknown'
def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs): def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!') argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
@ -138,8 +130,11 @@ class CudaModule(NewExtensionModule):
raise argerror raise argerror
else: else:
compiler = args[0] compiler = args[0]
cuda_version = self._version_from_compiler(compiler) if isinstance(compiler, CudaCompiler):
if cuda_version == 'unknown': cuda_version = compiler.version
elif isinstance(compiler, str):
cuda_version = compiler
else:
raise argerror raise argerror
arch_list = [] if len(args) <= 1 else flatten(args[1:]) arch_list = [] if len(args) <= 1 else flatten(args[1:])