cuda module: inline single-shot function to get compiler version
It is pretty trivial and more confusing when standalone, especially the use of a sentinel "unknown" string as a standin for "this isn't one of the allowed object types". Much easier to directly raise an error in the fallthrough/else.
This commit is contained in:
parent
8ff25c0bca
commit
1b15176168
|
@ -123,14 +123,6 @@ class CudaModule(NewExtensionModule):
|
|||
return [c.detected_cc]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _version_from_compiler(c):
|
||||
if isinstance(c, CudaCompiler):
|
||||
return c.version
|
||||
if isinstance(c, str):
|
||||
return c
|
||||
return 'unknown'
|
||||
|
||||
def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
|
||||
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
|
||||
|
||||
|
@ -138,8 +130,11 @@ class CudaModule(NewExtensionModule):
|
|||
raise argerror
|
||||
else:
|
||||
compiler = args[0]
|
||||
cuda_version = self._version_from_compiler(compiler)
|
||||
if cuda_version == 'unknown':
|
||||
if isinstance(compiler, CudaCompiler):
|
||||
cuda_version = compiler.version
|
||||
elif isinstance(compiler, str):
|
||||
cuda_version = compiler
|
||||
else:
|
||||
raise argerror
|
||||
|
||||
arch_list = [] if len(args) <= 1 else flatten(args[1:])
|
||||
|
|
Loading…
Reference in New Issue