diff --git a/level_zero/core/source/module/module_imp.cpp b/level_zero/core/source/module/module_imp.cpp index fda2822772..d32de6b2c5 100644 --- a/level_zero/core/source/module/module_imp.cpp +++ b/level_zero/core/source/module/module_imp.cpp @@ -484,6 +484,7 @@ bool ModuleImp::initialize(const ze_module_desc_t *desc, NEO::Device *neoDevice) std::vector inputSpirVs; std::vector inputModuleSizes; std::vector specConstants; + const ze_module_constants_t *firstSpecConstants = nullptr; this->createBuildOptions(nullptr, buildOptions, internalBuildOptions); @@ -496,6 +497,9 @@ bool ModuleImp::initialize(const ze_module_desc_t *desc, NEO::Device *neoDevice) inputModuleSizes.push_back(inputSize); if (programExpDesc->pConstants) { specConstants.push_back(programExpDesc->pConstants[i]); + if (i == 0) { + firstSpecConstants = specConstants[0]; + } } if (programExpDesc->pBuildFlags) { this->createBuildOptions(programExpDesc->pBuildFlags[i], tmpBuildOptions, tmpInternalBuildOptions); @@ -503,12 +507,20 @@ bool ModuleImp::initialize(const ze_module_desc_t *desc, NEO::Device *neoDevice) internalBuildOptions = internalBuildOptions + tmpInternalBuildOptions; } } - - success = this->translationUnit->staticLinkSpirV(inputSpirVs, - inputModuleSizes, - buildOptions.c_str(), - internalBuildOptions.c_str(), - specConstants); + //If the user passed in only 1 SPIRV, then fallback to standard build + if (inputSpirVs.size() > 1) { + success = this->translationUnit->staticLinkSpirV(inputSpirVs, + inputModuleSizes, + buildOptions.c_str(), + internalBuildOptions.c_str(), + specConstants); + } else { + success = this->translationUnit->buildFromSpirV(reinterpret_cast(programExpDesc->pInputModules[0]), + inputModuleSizes[0], + buildOptions.c_str(), + internalBuildOptions.c_str(), + firstSpecConstants); + } } else { return false; } diff --git a/level_zero/core/test/unit_tests/sources/module/test_module.cpp b/level_zero/core/test/unit_tests/sources/module/test_module.cpp index c8864f5763..4ab6c80245 100644 --- a/level_zero/core/test/unit_tests/sources/module/test_module.cpp +++ b/level_zero/core/test/unit_tests/sources/module/test_module.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2020-2021 Intel Corporation + * Copyright (C) 2020-2022 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -777,29 +777,35 @@ struct ModuleStaticLinkFixture : public DeviceFixture { DeviceFixture::TearDown(); } - void loadMultipleModules() { + void loadModules(bool multiple) { std::string testFile; retrieveBinaryKernelFilenameNoRevision(testFile, binaryFilename + "_", ".spv"); srcModule1 = loadDataFromFile(testFile.c_str(), sizeModule1); - srcModule2 = loadDataFromFile(testFile.c_str(), sizeModule2); + if (multiple) { + srcModule2 = loadDataFromFile(testFile.c_str(), sizeModule2); + } ASSERT_NE(0u, sizeModule1); - ASSERT_NE(0u, sizeModule2); ASSERT_NE(nullptr, srcModule1); - ASSERT_NE(nullptr, srcModule2); + if (multiple) { + ASSERT_NE(0u, sizeModule2); + ASSERT_NE(nullptr, srcModule2); + } } - void setupExpProgramDesc(ze_module_format_t format) { + void setupExpProgramDesc(ze_module_format_t format, bool multiple) { combinedModuleDesc.format = format; combinedModuleDesc.pNext = &staticLinkModuleDesc; inputSizes.push_back(sizeModule1); inputSpirVs.push_back(reinterpret_cast(srcModule1.get())); - inputSizes.push_back(sizeModule2); - inputSpirVs.push_back(reinterpret_cast(srcModule2.get())); - - staticLinkModuleDesc.count = 2; + staticLinkModuleDesc.count = 1; + if (multiple) { + inputSizes.push_back(sizeModule2); + inputSpirVs.push_back(reinterpret_cast(srcModule2.get())); + staticLinkModuleDesc.count = 2; + } staticLinkModuleDesc.inputSizes = inputSizes.data(); staticLinkModuleDesc.pInputModules = inputSpirVs.data(); } @@ -811,9 +817,9 @@ struct ModuleStaticLinkFixture : public DeviceFixture { rootDeviceEnvironment->compilerInterface.reset(mockCompiler); mockTranslationUnit = new MockModuleTranslationUnit(device); - loadMultipleModules(); + loadModules(testMultiple); - setupExpProgramDesc(ZE_MODULE_FORMAT_IL_SPIRV); + setupExpProgramDesc(ZE_MODULE_FORMAT_IL_SPIRV, testMultiple); auto module = new Module(device, nullptr, ModuleType::User); module->translationUnit.reset(mockTranslationUnit); @@ -829,9 +835,9 @@ struct ModuleStaticLinkFixture : public DeviceFixture { rootDeviceEnvironment->compilerInterface.reset(mockCompiler); mockTranslationUnit = new MockModuleTranslationUnit(device); - loadMultipleModules(); + loadModules(testMultiple); - setupExpProgramDesc(ZE_MODULE_FORMAT_NATIVE); + setupExpProgramDesc(ZE_MODULE_FORMAT_NATIVE, testMultiple); auto module = new Module(device, nullptr, ModuleType::User); module->translationUnit.reset(mockTranslationUnit); @@ -865,9 +871,9 @@ struct ModuleStaticLinkFixture : public DeviceFixture { rootDeviceEnvironment->compilerInterface.reset(mockCompiler); mockTranslationUnit = new MockModuleTranslationUnit(device); - loadMultipleModules(); + loadModules(testMultiple); - setupExpProgramDesc(ZE_MODULE_FORMAT_IL_SPIRV); + setupExpProgramDesc(ZE_MODULE_FORMAT_IL_SPIRV, testMultiple); std::vector buildFlags; std::string module1BuildFlags("-ze-opt-disable"); @@ -884,6 +890,30 @@ struct ModuleStaticLinkFixture : public DeviceFixture { EXPECT_TRUE(success); module->destroy(); } + void runSprivLinkBuildWithOneModule() { + MockCompilerInterface *mockCompiler; + mockCompiler = new MockCompilerInterface(); + auto rootDeviceEnvironment = neoDevice->getExecutionEnvironment()->rootDeviceEnvironments[0].get(); + rootDeviceEnvironment->compilerInterface.reset(mockCompiler); + mockTranslationUnit = new MockModuleTranslationUnit(device); + + loadModules(testSingle); + + setupExpProgramDesc(ZE_MODULE_FORMAT_IL_SPIRV, testSingle); + + std::vector buildFlags; + std::string module1BuildFlags("-ze-opt-disable"); + buildFlags.push_back(const_cast(module1BuildFlags.c_str())); + + staticLinkModuleDesc.pBuildFlags = const_cast(buildFlags.data()); + + auto module = new Module(device, nullptr, ModuleType::User); + module->translationUnit.reset(mockTranslationUnit); + + bool success = module->initialize(&combinedModuleDesc, neoDevice); + EXPECT_TRUE(success); + module->destroy(); + } const std::string binaryFilename = "test_kernel"; const std::string kernelName = "test"; MockModuleTranslationUnit *mockTranslationUnit; @@ -894,6 +924,8 @@ struct ModuleStaticLinkFixture : public DeviceFixture { std::vector inputSizes; ze_module_desc_t combinedModuleDesc = {ZE_STRUCTURE_TYPE_MODULE_DESC}; ze_module_program_exp_desc_t staticLinkModuleDesc = {ZE_STRUCTURE_TYPE_MODULE_PROGRAM_EXP_DESC}; + bool testMultiple = true; + bool testSingle = false; }; using ModuleStaticLinkTests = Test; @@ -914,6 +946,10 @@ TEST_F(ModuleStaticLinkTests, givenMultipleModulesProvidedForSpirVStaticLinkAndB runSprivLinkBuildFlags(); } +TEST_F(ModuleStaticLinkTests, givenSingleModuleProvidedForSpirVStaticLinkAndBuildFlagsRequestedThenSuccessisReturned) { + runSprivLinkBuildWithOneModule(); +} + using ModuleLinkingTest = Test; HWTEST_F(ModuleLinkingTest, whenExternFunctionsAllocationIsPresentThenItsBeingAddedToResidencyContainer) {