cuda module: inline single-shot function to get compiler version
It is pretty trivial and more confusing when standalone, especially the use of a sentinel "unknown" string as a standin for "this isn't one of the allowed object types". Much easier to directly raise an error in the fallthrough/else.
This commit is contained in:
parent
8ff25c0bca
commit
1b15176168
|
@ -123,14 +123,6 @@ class CudaModule(NewExtensionModule):
|
||||||
return [c.detected_cc]
|
return [c.detected_cc]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _version_from_compiler(c):
|
|
||||||
if isinstance(c, CudaCompiler):
|
|
||||||
return c.version
|
|
||||||
if isinstance(c, str):
|
|
||||||
return c
|
|
||||||
return 'unknown'
|
|
||||||
|
|
||||||
def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
|
def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
|
||||||
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
|
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
|
||||||
|
|
||||||
|
@ -138,8 +130,11 @@ class CudaModule(NewExtensionModule):
|
||||||
raise argerror
|
raise argerror
|
||||||
else:
|
else:
|
||||||
compiler = args[0]
|
compiler = args[0]
|
||||||
cuda_version = self._version_from_compiler(compiler)
|
if isinstance(compiler, CudaCompiler):
|
||||||
if cuda_version == 'unknown':
|
cuda_version = compiler.version
|
||||||
|
elif isinstance(compiler, str):
|
||||||
|
cuda_version = compiler
|
||||||
|
else:
|
||||||
raise argerror
|
raise argerror
|
||||||
|
|
||||||
arch_list = [] if len(args) <= 1 else flatten(args[1:])
|
arch_list = [] if len(args) <= 1 else flatten(args[1:])
|
||||||
|
|
Loading…
Reference in New Issue