mirror of
				https://github.com/intel/intel-graphics-compiler.git
				synced 2025-10-30 08:18:26 +08:00 
			
		
		
		
	Rewrite TargetExtTy retyper using ValueMapTypeRemapper
This change rewrites the TargetExtTy retyper to use the ValueMapTypeRemapper infrastructure, significantly improving the overall design and maintainability of the code. The change also removes unused cases added for additional safety if earlier retyping logic fails. Two additional test cases are added, covering more complex retyping scenarios.
This commit is contained in:
		 Michal Paszkowski
					Michal Paszkowski
				
			
				
					committed by
					
						 igcbot
						igcbot
					
				
			
			
				
	
			
			
			 igcbot
						igcbot
					
				
			
						parent
						
							ba8538b4e6
						
					
				
				
					commit
					d7a41cf31b
				
			| @ -0,0 +1,22 @@ | ||||
| ;=========================== begin_copyright_notice ============================ | ||||
| ; | ||||
| ; Copyright (C) 2025 Intel Corporation | ||||
| ; | ||||
| ; SPDX-License-Identifier: MIT | ||||
| ; | ||||
| ;============================ end_copyright_notice ============================= | ||||
|  | ||||
| ; REQUIRES: llvm-16-plus | ||||
| ; RUN: igc_opt --opaque-pointers -igc-preprocess-spvir -S < %s | FileCheck %s | ||||
|  | ||||
| ; This test verifies that PreprocessSPVIR retypes TargetExtTys in function | ||||
| ; declarations. | ||||
|  | ||||
| define spir_func void @foo(i64 %in) { | ||||
|   %img = call spir_func target("spirv.SampledImage", void, 1, 0, 0, 0, 0, 0, 0) @_Z90__spirv_ConvertHandleToSampledImageINTEL_RPU3AS140__spirv_SampledImage__void_1_0_0_0_0_0_0m(i64 %in) | ||||
|   ret void | ||||
| } | ||||
|  | ||||
| declare spir_func target("spirv.SampledImage", void, 1, 0, 0, 0, 0, 0, 0) @_Z90__spirv_ConvertHandleToSampledImageINTEL_RPU3AS140__spirv_SampledImage__void_1_0_0_0_0_0_0m(i64) | ||||
|  | ||||
| ; CHECK: declare spir_func ptr @_Z90__spirv_ConvertHandleToSampledImageINTEL_RPU3AS140__spirv_SampledImage__void_1_0_0_0_0_0_0m(i64) | ||||
| @ -0,0 +1,27 @@ | ||||
| ;=========================== begin_copyright_notice ============================ | ||||
| ; | ||||
| ; Copyright (C) 2025 Intel Corporation | ||||
| ; | ||||
| ; SPDX-License-Identifier: MIT | ||||
| ; | ||||
| ;============================ end_copyright_notice ============================= | ||||
|  | ||||
| ; REQUIRES: llvm-16-plus | ||||
| ; RUN: igc_opt --opaque-pointers -igc-preprocess-spvir -S < %s | FileCheck %s | ||||
|  | ||||
| ; This test verifies PreprocessSPVIR retypes TargetExtTy constant | ||||
| ; zeroinitializer arguments to pointer null. | ||||
|  | ||||
| declare spir_func void @UseEvent(target("spirv.Event") %A) | ||||
|  | ||||
| define spir_kernel void @TestKernel() { | ||||
| entry: | ||||
|   call spir_func void @UseEvent(target("spirv.Event") zeroinitializer) | ||||
|   ret void | ||||
| } | ||||
|  | ||||
| ; CHECK-LABEL: define spir_kernel void @TestKernel( | ||||
| ; CHECK: call spir_func void @UseEvent( | ||||
| ; CHECK-SAME: ptr null | ||||
| ; CHECK-NOT: target("spirv.Event") | ||||
| ; CHECK: ret void | ||||
| @ -0,0 +1,30 @@ | ||||
| ;=========================== begin_copyright_notice ============================ | ||||
| ; | ||||
| ; Copyright (C) 2025 Intel Corporation | ||||
| ; | ||||
| ; SPDX-License-Identifier: MIT | ||||
| ; | ||||
| ;============================ end_copyright_notice ============================= | ||||
|  | ||||
| ; REQUIRES: llvm-16-plus | ||||
| ; RUN: igc_opt --opaque-pointers -igc-preprocess-spvir -S < %s | FileCheck %s | ||||
|  | ||||
| ; This test verifies PreprocessSPVIR retypes TargetExtTy inside structs used by | ||||
| ; GEPs. | ||||
|  | ||||
| ; CHECK-NOT: target("spirv.Event") | ||||
|  | ||||
| %"class.sycl::_V1::device_event" = type { target("spirv.Event") } | ||||
| ; CHECK: %"class.sycl::_V1::device_event" = type { ptr } | ||||
|  | ||||
| define internal spir_func void @_ZN4sycl3_V112device_event4waitEv(ptr addrspace(4) align 8 %arg) { | ||||
|   %a = alloca ptr addrspace(4), align 8 | ||||
|   %b = addrspacecast ptr %a to ptr addrspace(4) | ||||
|   store ptr addrspace(4) %arg, ptr addrspace(4) %b, align 8 | ||||
|   %c = load ptr addrspace(4), ptr addrspace(4) %b, align 8 | ||||
| ; CHECK: %d = getelementptr inbounds %"class.sycl::_V1::device_event", ptr addrspace(4) %c, i32 0, i32 0 | ||||
|   %d = getelementptr inbounds %"class.sycl::_V1::device_event", ptr addrspace(4) %c, i32 0, i32 0 | ||||
| ; CHECK: %e = getelementptr inbounds %"class.sycl::_V1::device_event", ptr addrspace(4) %c, i32 0, i32 0 | ||||
|   %e = getelementptr inbounds %"class.sycl::_V1::device_event", ptr addrspace(4) %c, i32 0, i32 0 | ||||
|   ret void | ||||
| } | ||||
| @ -0,0 +1,49 @@ | ||||
| ;=========================== begin_copyright_notice ============================ | ||||
| ; | ||||
| ; Copyright (C) 2025 Intel Corporation | ||||
| ; | ||||
| ; SPDX-License-Identifier: MIT | ||||
| ; | ||||
| ;============================ end_copyright_notice ============================= | ||||
|  | ||||
| ; REQUIRES: llvm-16-plus | ||||
| ; RUN: igc_opt --opaque-pointers -igc-preprocess-spvir -S < %s | FileCheck %s | ||||
|  | ||||
| ; Verify that the PreprocessSPVIR pass retypes struct members of TargetExtTy | ||||
| ; type as pointers. This is needed so that the following passes can perform | ||||
| ; pointer optimizations correctly and the IR can be linked with builtins module | ||||
| ; coming from Clang (LLVM 16 Clang does not support TargetExtTy). | ||||
|  | ||||
| %struct.SamplerHolder = type { target("spirv.Sampler"), i32 } | ||||
| %struct.Wrapper = type { %struct.SamplerHolder, i64 } | ||||
|  | ||||
| define spir_kernel void @Kernel(target("spirv.Sampler") %arg) { | ||||
| entry: | ||||
|   %holder = alloca %struct.SamplerHolder, align 8 | ||||
|   %wrapper = alloca %struct.Wrapper, align 8 | ||||
|   %holder.sampler.field.gep = getelementptr inbounds %struct.SamplerHolder, ptr %holder, i32 0, i32 0 | ||||
|   store target("spirv.Sampler") %arg, ptr %holder.sampler.field.gep, align 8 | ||||
|   %wrapper.holder.gep = getelementptr inbounds %struct.Wrapper, ptr %wrapper, i32 0, i32 0 | ||||
|   %wrapper.holder.sampler.gep = getelementptr inbounds %struct.SamplerHolder, ptr %wrapper.holder.gep, i32 0, i32 0 | ||||
|   %loaded.sampler = load target("spirv.Sampler"), ptr %holder.sampler.field.gep, align 8 | ||||
|   store target("spirv.Sampler") %loaded.sampler, ptr %wrapper.holder.sampler.gep, align 8 | ||||
|   call spir_func void @Helper(target("spirv.Sampler") %loaded.sampler) | ||||
|   ret void | ||||
| } | ||||
|  | ||||
| define internal spir_func void @Helper(target("spirv.Sampler") %S) { | ||||
| entry: | ||||
|   ret void | ||||
| } | ||||
|  | ||||
| ; Check that struct fields are retyped, including nested/wrapped structs. | ||||
| ; CHECK: %struct.SamplerHolder = type { ptr addrspace(2), i32 } | ||||
| ; CHECK: %struct.Wrapper = type { %struct.SamplerHolder, i64 } | ||||
|  | ||||
| ; CHECK-LABEL: define spir_kernel void @Kernel( | ||||
| ; CHECK-SAME: ptr addrspace(2) %arg | ||||
| ; CHECK: %holder = alloca %struct.SamplerHolder | ||||
| ; CHECK: %wrapper = alloca %struct.Wrapper | ||||
|  | ||||
| ; CHECK-NOT: target("spirv.Sampler") | ||||
|  | ||||
| @ -0,0 +1,30 @@ | ||||
| ;=========================== begin_copyright_notice ============================ | ||||
| ; | ||||
| ; Copyright (C) 2025 Intel Corporation | ||||
| ; | ||||
| ; SPDX-License-Identifier: MIT | ||||
| ; | ||||
| ;============================ end_copyright_notice ============================= | ||||
|  | ||||
| ; REQUIRES: llvm-16-plus | ||||
| ; RUN: igc_opt --opaque-pointers -igc-preprocess-spvir -S < %s | FileCheck %s | ||||
|  | ||||
| ; This test verifies PreprocessSPVIR retypes TargetExtTy inside structs used in | ||||
| ; function argument sret attributes. | ||||
|  | ||||
| %union.anon = type { ptr addrspace(1) } | ||||
| %"class.sycl::_V1::multi_ptr" = type { ptr addrspace(3) } | ||||
| %"class.sycl::_V1::device_event" = type { target("spirv.Event") } | ||||
|  | ||||
| ; CHECK-NOT: target("spirv.Event") | ||||
| ; CHECK: %"class.sycl::_V1::device_event" = type { ptr } | ||||
|  | ||||
| define void @f(ptr addrspace(4) noalias sret(%"class.sycl::_V1::device_event") align 8 %arg1, ptr addrspace(4) align 1 %arg2, ptr byval(%"class.sycl::_V1::multi_ptr") align 8 %arg3, ptr byval(%union.anon) align 8 %arg4, i64 %arg5) { | ||||
|   call spir_func void @g(ptr addrspace(4) noalias sret(%"class.sycl::_V1::device_event") align 8 %arg1, ptr addrspace(4) align 1 %arg2, ptr byval(%"class.sycl::_V1::multi_ptr") align 8 %arg3, ptr byval(%union.anon) align 8 %arg4, i64 1, i64 1) | ||||
| ; CHECK: call spir_func void @g(ptr addrspace(4) noalias sret(%"class.sycl::_V1::device_event") align 8 %arg1, ptr addrspace(4) align 1 %arg2, ptr byval(%"class.sycl::_V1::multi_ptr") align 8 %arg3, ptr byval(%union.anon) align 8 %arg4, i64 1, i64 1) | ||||
|   ret void | ||||
| } | ||||
|  | ||||
| define void @g(ptr addrspace(4) noalias sret(%"class.sycl::_V1::device_event") align 8 %arg1, ptr addrspace(4) align 1 %arg2, ptr byval(%"class.sycl::_V1::multi_ptr") align 8 %arg3, ptr byval(%union.anon) align 8 %arg4, i64 %arg5, i64 %arg6) { | ||||
|   ret void | ||||
| } | ||||
| @ -26,26 +26,23 @@ SPDX-License-Identifier: MIT | ||||
| #include "Compiler/CodeGenPublicEnums.h" | ||||
| #include "Probe/Assertion.h" | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <limits> | ||||
|  | ||||
| using namespace llvm; | ||||
|  | ||||
| namespace IGC { | ||||
| bool isTargetExtTy(const llvm::Type *Ty) { | ||||
| bool isTargetExtTy(const Type *Ty) { | ||||
| #if LLVM_VERSION_MAJOR >= 16 | ||||
|   return Ty->isTargetExtTy(); | ||||
| #endif | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| bool isImageBuiltinType(const llvm::Type *BuiltinTy) { | ||||
| bool isImageBuiltinType(const Type *BuiltinTy) { | ||||
|   if (BuiltinTy->isPointerTy() && !IGCLLVM::isOpaquePointerTy(BuiltinTy)) | ||||
|     BuiltinTy = IGCLLVM::getNonOpaquePtrEltTy(BuiltinTy); | ||||
|  | ||||
|   if (const StructType *StructTy = dyn_cast<StructType>(BuiltinTy); StructTy && StructTy->isOpaque()) { | ||||
|     StringRef BuiltinName = StructTy->getName(); | ||||
|     llvm::SmallVector<llvm::StringRef, 3> Buffer; | ||||
|     SmallVector<StringRef, 3> Buffer; | ||||
|     BuiltinName.split(Buffer, "."); | ||||
|     if (Buffer.size() < 2) | ||||
|       return false; | ||||
| @ -67,8 +64,8 @@ bool isImageBuiltinType(const llvm::Type *BuiltinTy) { | ||||
| } | ||||
|  | ||||
| #if LLVM_VERSION_MAJOR >= 16 | ||||
| static bool isNonOpenCLBuiltinType(const llvm::Type *Ty) { | ||||
|   const llvm::TargetExtType *TET = dyn_cast<llvm::TargetExtType>(Ty); | ||||
| static bool isNonOpenCLBuiltinType(const Type *Ty) { | ||||
|   const TargetExtType *TET = dyn_cast<TargetExtType>(Ty); | ||||
|   if (!TET) | ||||
|     return false; | ||||
|  | ||||
| @ -76,633 +73,326 @@ static bool isNonOpenCLBuiltinType(const llvm::Type *Ty) { | ||||
|   return Name.starts_with("spirv.CooperativeMatrixKHR") || Name.starts_with("spirv.JointMatrixINTEL"); | ||||
| } | ||||
|  | ||||
| static bool isAnyArgOpenCLTargetExtTy(const llvm::Function &F) { | ||||
|   for (const llvm::Argument &A : F.args()) { | ||||
|     const Type *ArgTy = A.getType(); | ||||
|     if (isTargetExtTy(ArgTy) && !isNonOpenCLBuiltinType(ArgTy)) | ||||
|       return true; | ||||
|   } | ||||
| static bool isOpenCLTargetExtType(const Type *Ty) { return isTargetExtTy(Ty) && !isNonOpenCLBuiltinType(Ty); } | ||||
|  | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| static bool isDeclarationWithOpenCLTargetExtTyRet(const llvm::Function &F) { | ||||
|   const Type *RetTy = F.getReturnType(); | ||||
|   return F.isDeclaration() && isTargetExtTy(RetTy) && !isNonOpenCLBuiltinType(RetTy); | ||||
| } | ||||
|  | ||||
| static unsigned getAddressSpaceForTargetExtTy(const llvm::TargetExtType *TargetExtTy) { | ||||
|   StringRef TyName = TargetExtTy->getName(); | ||||
|   if (TyName.startswith("spirv.Queue")) | ||||
|     return ADDRESS_SPACE_PRIVATE; | ||||
|   else if (TyName.startswith("spirv.Image")) | ||||
|     return ADDRESS_SPACE_GLOBAL; | ||||
|   else if (TyName.startswith("spirv.Sampler")) | ||||
|     return ADDRESS_SPACE_CONSTANT; | ||||
|  | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| static void retypeLocalTargetExtAllocasLoadsStores(Function &F) { | ||||
|   SmallVector<AllocaInst *, 8> Allocas; | ||||
|   SmallVector<LoadInst *, 16> Loads; | ||||
|   SmallVector<StoreInst *, 16> Stores; | ||||
|  | ||||
|   for (BasicBlock &BB : F) { | ||||
|     for (Instruction &I : BB) { | ||||
|       if (auto *AI = dyn_cast<AllocaInst>(&I)) { | ||||
|         Type *AllocTy = AI->getAllocatedType(); | ||||
|         if (isTargetExtTy(AllocTy) && !isNonOpenCLBuiltinType(AllocTy)) | ||||
|           Allocas.push_back(AI); | ||||
|       } else if (auto *LI = dyn_cast<LoadInst>(&I)) { | ||||
|         Type *ValTy = LI->getType(); | ||||
|         if (isTargetExtTy(ValTy) && !isNonOpenCLBuiltinType(ValTy)) | ||||
|           Loads.push_back(LI); | ||||
|       } else if (auto *SI = dyn_cast<StoreInst>(&I)) { | ||||
|         Value *Val = SI->getValueOperand(); | ||||
|         Type *ValTy = Val->getType(); | ||||
|         if (isTargetExtTy(ValTy) && !isNonOpenCLBuiltinType(ValTy)) | ||||
|           Stores.push_back(SI); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   DenseMap<Value *, Value *> RetypedValueMap; | ||||
|   LLVMContext &C = F.getContext(); | ||||
|  | ||||
|   // Retype allocas (preserve original names). | ||||
|   for (AllocaInst *AI : Allocas) { | ||||
|     std::string OldName = AI->hasName() ? std::string(AI->getName()) : std::string(); | ||||
|     const auto *TET = cast<TargetExtType>(AI->getAllocatedType()); | ||||
|     Type *PtrElemTy = PointerType::get(C, getAddressSpaceForTargetExtTy(cast<TargetExtType>(TET))); | ||||
|  | ||||
|     // Create new alloca without name to avoid temporary suffixing. | ||||
|     AllocaInst *NewAI = new AllocaInst(PtrElemTy, AI->getAddressSpace(), nullptr, AI->getAlign(), "", AI); | ||||
|     NewAI->setDebugLoc(AI->getDebugLoc()); | ||||
|  | ||||
|     // Copy metadata | ||||
|     SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; | ||||
|     AI->getAllMetadata(MDs); | ||||
|     for (auto &MDPair : MDs) | ||||
|       NewAI->setMetadata(MDPair.first, MDPair.second); | ||||
|  | ||||
|     RetypedValueMap[AI] = NewAI; | ||||
|     AI->replaceAllUsesWith(NewAI); | ||||
|     AI->eraseFromParent(); | ||||
|  | ||||
|     if (!OldName.empty()) | ||||
|       NewAI->setName(OldName); | ||||
|   } | ||||
|  | ||||
|   struct RetypedLoadInfo { | ||||
|     LoadInst *Old; | ||||
|     LoadInst *New; | ||||
|     std::string OldName; | ||||
|   }; | ||||
|   SmallVector<RetypedLoadInfo, 16> RetypedLoads; | ||||
|  | ||||
|   // Create replacement loads (do NOT RAUW directly due to type change). | ||||
|   for (LoadInst *LI : Loads) { | ||||
|     std::string OldName = LI->hasName() ? std::string(LI->getName()) : std::string(); | ||||
|     const auto *TET = cast<TargetExtType>(LI->getType()); | ||||
|     Type *NewValTy = PointerType::get(C, getAddressSpaceForTargetExtTy(cast<TargetExtType>(TET))); | ||||
|  | ||||
|     Value *PtrOp = LI->getPointerOperand(); | ||||
|     PointerType *DesiredPtrTy = PointerType::get(NewValTy, PtrOp->getType()->getPointerAddressSpace()); | ||||
|     if (PtrOp->getType() != DesiredPtrTy) | ||||
|       PtrOp = new BitCastInst(PtrOp, DesiredPtrTy, PtrOp->getName() + ".retycast", LI); | ||||
|  | ||||
|     // Create unnamed new load before old one. | ||||
|     LoadInst *NewLoad = new LoadInst(NewValTy, PtrOp, "", LI); | ||||
|     NewLoad->setAlignment(LI->getAlign()); | ||||
|     NewLoad->setDebugLoc(LI->getDebugLoc()); | ||||
|  | ||||
|     SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; | ||||
|     LI->getAllMetadata(MDs); | ||||
|     for (auto &MDPair : MDs) | ||||
|       NewLoad->setMetadata(MDPair.first, MDPair.second); | ||||
|  | ||||
|     RetypedValueMap[LI] = NewLoad; | ||||
|  | ||||
|     // Update non-store users immediately where the type already matches. | ||||
|     SmallVector<Use *, 8> UsesToUpdate; | ||||
|     for (Use &U : LI->uses()) | ||||
|       UsesToUpdate.push_back(&U); | ||||
|  | ||||
|     for (Use *U : UsesToUpdate) { | ||||
|       User *Usr = U->getUser(); | ||||
|       if (isa<StoreInst>(Usr)) | ||||
|         continue; // Store handled later. | ||||
|       if (U->get()->getType() == NewLoad->getType()) | ||||
|         *U = NewLoad; | ||||
|     } | ||||
|  | ||||
|     RetypedLoads.push_back({LI, NewLoad, OldName}); | ||||
|   } | ||||
|  | ||||
|   // Retype stores whose value operand was retyped. | ||||
|   for (StoreInst *SI : Stores) { | ||||
|     Value *OldVal = SI->getValueOperand(); | ||||
|     auto *TET = cast<TargetExtType>(OldVal->getType()); | ||||
|     Type *NewValTy = PointerType::get(C, getAddressSpaceForTargetExtTy(cast<TargetExtType>(TET))); | ||||
|  | ||||
|     Value *NewVal = RetypedValueMap.lookup(OldVal); | ||||
|     if (!NewVal) | ||||
|       continue; // Producer not transformed, leave store as-is. | ||||
|  | ||||
|     Value *PtrOp = SI->getPointerOperand(); | ||||
|     PointerType *DesiredPtrTy = PointerType::get(NewValTy, PtrOp->getType()->getPointerAddressSpace()); | ||||
|     if (PtrOp->getType() != DesiredPtrTy) | ||||
|       PtrOp = new BitCastInst(PtrOp, DesiredPtrTy, PtrOp->getName() + ".retycast", SI); | ||||
|  | ||||
|     StoreInst *NewStore = new StoreInst(NewVal, PtrOp, SI); | ||||
|     NewStore->setAlignment(SI->getAlign()); | ||||
|     NewStore->setDebugLoc(SI->getDebugLoc()); | ||||
|  | ||||
|     SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; | ||||
|     SI->getAllMetadata(MDs); | ||||
|     for (auto &MDPair : MDs) | ||||
|       NewStore->setMetadata(MDPair.first, MDPair.second); | ||||
|  | ||||
|     SI->eraseFromParent(); | ||||
|   } | ||||
|  | ||||
|   // Now safely remove old loads and restore original names. | ||||
|   for (auto &RL : RetypedLoads) { | ||||
|     if (RL.Old->use_empty()) { | ||||
|       RL.Old->eraseFromParent(); | ||||
|       if (!RL.OldName.empty()) | ||||
|         RL.New->setName(RL.OldName); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| static Function *cloneFunctionWithPtrArgsandRetInsteadTargetExtTy(Function &F, StringRef NameSuffix) { | ||||
|   Module &M = *F.getParent(); | ||||
|   LLVMContext &C = M.getContext(); | ||||
|  | ||||
|   SmallVector<Type *, 8> ParamTys; | ||||
|   ParamTys.reserve(F.arg_size()); | ||||
|   for (Argument &Arg : F.args()) { | ||||
|     Type *T = Arg.getType(); | ||||
|     if (isTargetExtTy(T) && !isNonOpenCLBuiltinType(T)) { | ||||
|       T = PointerType::get(C, getAddressSpaceForTargetExtTy(cast<TargetExtType>(T))); | ||||
|     } | ||||
|     ParamTys.push_back(T); | ||||
|   } | ||||
|  | ||||
|   Type *NewRetTy = F.getReturnType(); | ||||
|   if (F.isDeclaration() && isTargetExtTy(NewRetTy) && !isNonOpenCLBuiltinType(NewRetTy)) | ||||
|     NewRetTy = PointerType::get(C, getAddressSpaceForTargetExtTy(cast<TargetExtType>(NewRetTy))); | ||||
|  | ||||
|   FunctionType *NewFTy = FunctionType::get(NewRetTy, ParamTys, F.isVarArg()); | ||||
|  | ||||
|   Function *NewF = Function::Create(NewFTy, F.getLinkage(), F.getAddressSpace(), F.getName() + NameSuffix, &M); | ||||
|   NewF->copyAttributesFrom(&F); | ||||
|  | ||||
|   ValueToValueMapTy VMap; | ||||
|   auto NI = NewF->arg_begin(); | ||||
|   for (Argument &OI : F.args()) { | ||||
|     NI->setName(OI.getName()); | ||||
|     VMap[&OI] = &*NI++; | ||||
|   } | ||||
|  | ||||
|   SmallVector<ReturnInst *, 8> Rets; | ||||
|   CloneFunctionInto(NewF, &F, VMap, CloneFunctionChangeType::LocalChangesOnly, Rets); | ||||
|  | ||||
|   return NewF; | ||||
| } | ||||
|  | ||||
| static void replaceFunctionAtCallsites(Function &OldF, Function &NewF) { | ||||
|   // Helper for finding an earlier load in the same basic block that already loads the retyped (pointer) value from the | ||||
|   // same address. The non-retyped loads still remain since it is not possible to retype callsites (their users) before | ||||
|   // retyping function signatures. | ||||
|   auto findEarlierRetypedLoad = [](LoadInst *OldLI) -> LoadInst * { | ||||
|     if (!OldLI) | ||||
|       return nullptr; | ||||
|     Value *PtrOp = OldLI->getPointerOperand(); | ||||
|     for (Instruction *I = OldLI->getPrevNode(); I; I = I->getPrevNode()) { | ||||
|       if (auto *LI = dyn_cast<LoadInst>(I)) { | ||||
|         if (LI->getPointerOperand() == PtrOp && LI->getType()->isPointerTy()) | ||||
|           return LI; | ||||
|       } | ||||
|     } | ||||
|     return nullptr; | ||||
|   }; | ||||
|  | ||||
|   SmallVector<User *, 16> Uses(OldF.users()); | ||||
|  | ||||
|   for (User *U : Uses) { | ||||
|     auto *CB = dyn_cast<CallBase>(U); | ||||
|     if (!CB) | ||||
|       continue; | ||||
|  | ||||
|     IRBuilder<> IRB(CB); | ||||
|     SmallVector<Value *, 8> NewArgs; | ||||
|     NewArgs.reserve(CB->arg_size()); | ||||
|  | ||||
|     SmallVector<LoadInst *, 4> DeadTargetExtLoads; | ||||
|  | ||||
|     unsigned Idx = 0; | ||||
|     for (Value *Actual : CB->args()) { | ||||
|       const Argument &Formal = *std::next(NewF.arg_begin(), Idx++); | ||||
|       Value *V = Actual; | ||||
|  | ||||
|       if (Formal.getType()->isPointerTy()) { | ||||
|         unsigned FormalAS = Formal.getType()->getPointerAddressSpace(); | ||||
|  | ||||
|         if (isTargetExtTy(V->getType())) { | ||||
|           if (auto *C = dyn_cast<Constant>(V)) { | ||||
|             if (C->isNullValue()) { | ||||
|               V = ConstantPointerNull::get(PointerType::get(CB->getContext(), FormalAS)); | ||||
|               NewArgs.push_back(V); | ||||
|               continue; | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|  | ||||
|         if (!V->getType()->isPointerTy()) { | ||||
|           if (auto *LI = dyn_cast<LoadInst>(V)) { | ||||
|             if (isTargetExtTy(LI->getType())) { | ||||
|               // Try to find the already inserted retyped (pointer) load. | ||||
|               if (LoadInst *Retyped = findEarlierRetypedLoad(LI)) { | ||||
|                 V = Retyped; | ||||
|                 DeadTargetExtLoads.push_back(LI); | ||||
|                 // Cast to the formal parameter type if needed. | ||||
|                 if (V->getType()->getPointerAddressSpace() != FormalAS) | ||||
|                   V = IRB.CreateAddrSpaceCast(V, Formal.getType()); | ||||
|                 else if (V->getType() != Formal.getType()) | ||||
|                   V = IRB.CreateBitCast(V, Formal.getType()); | ||||
|                 NewArgs.push_back(V); | ||||
|                 continue; | ||||
|               } | ||||
|               // Fallback, use the address (pointer operand). | ||||
|               // TODO: Remove this path. This should not be needed if the instruction-level retyping is done correctly. | ||||
|               // Consider adding an assert. | ||||
|               Value *PtrOp = LI->getPointerOperand(); | ||||
|               if (PtrOp->getType()->getPointerAddressSpace() != FormalAS) { | ||||
|                 PtrOp = IRB.CreateAddrSpaceCast(PtrOp, Formal.getType()); | ||||
|               } else if (PtrOp->getType() != Formal.getType()) { | ||||
|                 PtrOp = IRB.CreateBitCast(PtrOp, Formal.getType()); | ||||
|               } | ||||
|               NewArgs.push_back(PtrOp); | ||||
|               continue; | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|  | ||||
|         if (!V->getType()->isPointerTy()) { | ||||
|           // TODO: Remove this path. This should not be needed if the instruction-level retyping is done correctly. | ||||
|           // Consider adding an assert. | ||||
|           IRBuilder<> EntryB(&*CB->getFunction()->getEntryBlock().begin()); | ||||
|           AllocaInst *Tmp = EntryB.CreateAlloca(V->getType(), nullptr, V->getName() + ".addr"); | ||||
|           EntryB.CreateStore(V, Tmp); | ||||
|           V = Tmp; | ||||
|         } | ||||
|  | ||||
|         if (V->getType()->getPointerAddressSpace() != FormalAS) { | ||||
|           V = IRB.CreateAddrSpaceCast(V, Formal.getType()); | ||||
|         } else if (V->getType() != Formal.getType()) { | ||||
|           V = IRB.CreateBitCast(V, Formal.getType()); | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       NewArgs.push_back(V); | ||||
|     } | ||||
|  | ||||
|     CallBase *NewCall = IRB.CreateCall(NewF.getFunctionType(), &NewF, NewArgs); | ||||
|     NewCall->copyMetadata(*CB); | ||||
|     if (CB->getType() != NewCall->getType()) | ||||
|       CB->mutateType(NewCall->getType()); | ||||
|     NewCall->setCallingConv(CB->getCallingConv()); | ||||
|     NewCall->setAttributes(CB->getAttributes()); | ||||
|  | ||||
|     CB->replaceAllUsesWith(NewCall); | ||||
|     CB->eraseFromParent(); | ||||
|  | ||||
|     // Remove now-dead TargetExtTy loads. | ||||
|     for (LoadInst *L : DeadTargetExtLoads) { | ||||
|       if (L->use_empty()) | ||||
|         L->eraseFromParent(); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| static bool structContainsOpenCLTargetExtTy(StructType *ST) { | ||||
| static bool isStructWithOpenCLTargetExtTyInside(const Type *Ty) { | ||||
|   const StructType *ST = dyn_cast<StructType>(Ty); | ||||
|   if (!ST || ST->isOpaque()) | ||||
|     return false; | ||||
|   for (Type *Elt : ST->elements()) { | ||||
|     if (isTargetExtTy(Elt) && !isNonOpenCLBuiltinType(Elt)) | ||||
|  | ||||
|   for (Type *EltTy : ST->elements()) { | ||||
|     if (isOpenCLTargetExtType(EltTy)) | ||||
|       return true; | ||||
|     if (auto *NestedST = dyn_cast<StructType>(Elt)) | ||||
|       if (structContainsOpenCLTargetExtTy(NestedST)) | ||||
|  | ||||
|     if (auto *NestedST = dyn_cast<StructType>(EltTy)) | ||||
|       if (isStructWithOpenCLTargetExtTyInside(NestedST)) | ||||
|         return true; | ||||
|   } | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| static Type *retypeElementIfNeeded(Type *EltTy, DenseMap<StructType *, StructType *> &StructMap); | ||||
|  | ||||
| static StructType *getOrCreateRetypedStruct(StructType *Old, DenseMap<StructType *, StructType *> &StructMap) { | ||||
|   if (auto It = StructMap.find(Old); It != StructMap.end()) | ||||
|     return It->second; | ||||
|  | ||||
|   LLVMContext &C = Old->getContext(); | ||||
|   StructType *New = StructType::create(C, Old->getName()); | ||||
|   StructMap[Old] = New; // Insert early to break cycles. | ||||
|  | ||||
|   SmallVector<Type *, 8> NewElts; | ||||
|   NewElts.reserve(Old->getNumElements()); | ||||
|   for (Type *Elt : Old->elements()) { | ||||
|     NewElts.push_back(retypeElementIfNeeded(Elt, StructMap)); | ||||
|   } | ||||
|   New->setBody(NewElts, Old->isPacked()); | ||||
|   return New; | ||||
| static bool checkIfNeedsRetyping(const Type *Ty) { | ||||
|   return isOpenCLTargetExtType(Ty) || isStructWithOpenCLTargetExtTyInside(Ty); | ||||
| } | ||||
|  | ||||
| static Type *retypeElementIfNeeded(Type *EltTy, DenseMap<StructType *, StructType *> &StructMap) { | ||||
|   LLVMContext &C = EltTy->getContext(); | ||||
| static bool isAnyArgOpenCLTargetExtTy(const Function &F) { | ||||
|   for (const Argument &A : F.args()) { | ||||
|     const Type *ArgTy = A.getType(); | ||||
|     if (checkIfNeedsRetyping(ArgTy)) | ||||
|       return true; | ||||
|  | ||||
|   if (isTargetExtTy(EltTy) && !isNonOpenCLBuiltinType(EltTy)) | ||||
|     return PointerType::get(C, getAddressSpaceForTargetExtTy(cast<TargetExtType>(EltTy))); | ||||
|  | ||||
|   if (auto *ST = dyn_cast<StructType>(EltTy)) { | ||||
|     if (structContainsOpenCLTargetExtTy(ST)) | ||||
|       return getOrCreateRetypedStruct(ST, StructMap); | ||||
|     if (A.hasStructRetAttr()) { | ||||
|       if (checkIfNeedsRetyping(A.getParamStructRetType())) | ||||
|         return true; | ||||
|     } | ||||
|   } | ||||
|   return EltTy; | ||||
|  | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| static void buildStructRetypeMap(Module &M, DenseMap<StructType *, StructType *> &StructMap) { | ||||
|   std::vector<StructType *> Structs = M.getIdentifiedStructTypes(); | ||||
|   for (StructType *ST : Structs) { | ||||
|     if (!ST || ST->isOpaque()) | ||||
| namespace { | ||||
| class OpenCLTargetExtTypeMapper : public ValueMapTypeRemapper { | ||||
| public: | ||||
|   OpenCLTargetExtTypeMapper(Function &F, DenseMap<StructType *, StructType *> &TETtoRetypedStructs) | ||||
|       : Fn(F), Ctx(F.getContext()), TETtoRetypedStructs(TETtoRetypedStructs) {} | ||||
|  | ||||
|   Type *remapType(Type *SrcTy) override { | ||||
|     if (!SrcTy) | ||||
|       return SrcTy; | ||||
|  | ||||
|     if (auto *FTy = dyn_cast<FunctionType>(SrcTy)) | ||||
|       return remapFunctionType(FTy); | ||||
|  | ||||
|     if (auto *TET = dyn_cast<TargetExtType>(SrcTy)) | ||||
|       return remapTargetExtType(TET); | ||||
|  | ||||
|     if (auto *ST = dyn_cast<StructType>(SrcTy)) | ||||
|       return remapStructType(ST); | ||||
|  | ||||
|     // Possibly no need to retype, otherwise new cases need to be added (above). | ||||
|     return SrcTy; | ||||
|   } | ||||
|  | ||||
| private: | ||||
|   Function &Fn; | ||||
|   LLVMContext &Ctx; | ||||
|   DenseMap<StructType *, StructType *> &TETtoRetypedStructs; | ||||
|  | ||||
|   FunctionType *remapFunctionType(FunctionType *FTy) { | ||||
|     SmallVector<Type *, 6> NewParamTys; | ||||
|     NewParamTys.reserve(FTy->getNumParams()); | ||||
|     bool AnyChange = false; | ||||
|  | ||||
|     for (Type *ParamTy : FTy->params()) { | ||||
|       Type *NewParamTy = remapType(ParamTy); | ||||
|       if (NewParamTy != ParamTy) | ||||
|         AnyChange = true; | ||||
|       NewParamTys.push_back(NewParamTy); | ||||
|     } | ||||
|  | ||||
|     Type *RetType = FTy->getReturnType(); | ||||
|     Type *NewRetTy = remapType(RetType); | ||||
|     if (NewRetTy != RetType) | ||||
|       AnyChange = true; | ||||
|  | ||||
|     if (!AnyChange) { | ||||
|       return FTy; | ||||
|     } | ||||
|     return FunctionType::get(NewRetTy, NewParamTys, FTy->isVarArg()); | ||||
|   } | ||||
|  | ||||
|   Type *remapTargetExtType(TargetExtType *TET) { | ||||
|     if (isNonOpenCLBuiltinType(TET)) | ||||
|       return TET; | ||||
|  | ||||
|     StringRef TyName = TET->getName(); | ||||
|     unsigned AS = ADDRESS_SPACE_PRIVATE; | ||||
|     if (TyName.startswith("spirv.Image")) | ||||
|       AS = ADDRESS_SPACE_GLOBAL; | ||||
|     else if (TyName.startswith("spirv.Sampler")) | ||||
|       AS = ADDRESS_SPACE_CONSTANT; | ||||
|  | ||||
|     return PointerType::get(Ctx, AS); | ||||
|   } | ||||
|  | ||||
|   Type *remapStructType(StructType *StructTy) { | ||||
|     if (!StructTy || StructTy->isOpaque()) { | ||||
|       return StructTy; | ||||
|     } | ||||
|  | ||||
|     // Scan first to avoid unnecessary retyping/cloning. | ||||
|     if (!isStructWithOpenCLTargetExtTyInside(StructTy)) | ||||
|       return StructTy; | ||||
|  | ||||
|     return cloneStructRetyped(StructTy); | ||||
|   } | ||||
|  | ||||
|   StructType *cloneStructRetyped(StructType *Old) { | ||||
|     // Reuse mapping if already retyped. | ||||
|     auto It = TETtoRetypedStructs.find(Old); | ||||
|     if (It != TETtoRetypedStructs.end()) { | ||||
|       return It->second; | ||||
|     } | ||||
|  | ||||
|     std::string OrigName = Old->getName().str(); | ||||
|     Old->setName(OrigName + ".preretype"); | ||||
|  | ||||
|     // Early insert placeholder to break cycles. | ||||
|     StructType *NewST = StructType::create(Ctx, OrigName); | ||||
|     TETtoRetypedStructs[Old] = NewST; | ||||
|  | ||||
|     SmallVector<Type *, 8> NewElems; | ||||
|     NewElems.reserve(Old->getNumElements()); | ||||
|     bool AnyChange = false; | ||||
|     for (Type *Elt : Old->elements()) { | ||||
|       Type *NewElt = remapType(Elt); | ||||
|       if (NewElt != Elt) | ||||
|         AnyChange = true; | ||||
|       NewElems.push_back(NewElt); | ||||
|     } | ||||
|  | ||||
|     if (!AnyChange) { | ||||
|       // No change, reuse original and discard temp. | ||||
|       TETtoRetypedStructs[Old] = Old; | ||||
|       return Old; | ||||
|     } | ||||
|  | ||||
|     NewST->setBody(NewElems, Old->isPacked()); | ||||
|     return NewST; | ||||
|   } | ||||
| }; | ||||
| } // namespace | ||||
|  | ||||
| void retypeOpenCLTargetExtTyAsPointers(Module *M) { | ||||
|   struct FunctionSignatureChange { | ||||
|     FunctionType *NewFuncTy; | ||||
|     AttributeList NewAttrs; | ||||
|   }; | ||||
|  | ||||
|   // Global mapping between TargetExtTy structs and their retyped variant (they are shared between functions). | ||||
|   DenseMap<StructType *, StructType *> TETtoRetypedStructs; | ||||
|   DenseMap<Function *, FunctionSignatureChange> PendingSigChanges; | ||||
|  | ||||
|   // Remap bodies and collect function signature changes. | ||||
|   for (Function &F : *M) { | ||||
|     bool ArgsOrRetTypeNeedsRetyping = isAnyArgOpenCLTargetExtTy(F) || checkIfNeedsRetyping(F.getReturnType()); | ||||
|  | ||||
|     // Need to process declarations that have TargetExtTy return/args, skip others. | ||||
|     if (F.isDeclaration() && !ArgsOrRetTypeNeedsRetyping) | ||||
|       continue; | ||||
|     if (structContainsOpenCLTargetExtTy(ST)) | ||||
|       (void)getOrCreateRetypedStruct(ST, StructMap); | ||||
|   } | ||||
| } | ||||
|  | ||||
| static void replaceStructTypeUsesInFunction(Function &F, const DenseMap<StructType *, StructType *> &StructMap) { | ||||
|   SmallVector<Instruction *, 32> ToErase; | ||||
|  | ||||
|   for (BasicBlock &BB : F) { | ||||
|     for (Instruction &I : BB) { | ||||
|       if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { | ||||
|         Type *SrcTy = GEP->getSourceElementType(); | ||||
|         auto *OldST = dyn_cast<StructType>(SrcTy); | ||||
|         if (!OldST) | ||||
|           continue; | ||||
|         auto It = StructMap.find(OldST); | ||||
|         if (It == StructMap.end()) | ||||
|           continue; | ||||
|  | ||||
|         SmallVector<Value *, 8> Indices(GEP->idx_begin(), GEP->idx_end()); | ||||
|         GetElementPtrInst *NewGEP = | ||||
|             GetElementPtrInst::Create(It->second, GEP->getPointerOperand(), Indices, GEP->getName(), GEP); | ||||
|         NewGEP->setIsInBounds(GEP->isInBounds()); | ||||
|         NewGEP->setDebugLoc(GEP->getDebugLoc()); | ||||
|         SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; | ||||
|         GEP->getAllMetadata(MDs); | ||||
|         for (auto &P : MDs) | ||||
|           NewGEP->setMetadata(P.first, P.second); | ||||
|  | ||||
|         GEP->replaceAllUsesWith(NewGEP); | ||||
|         ToErase.push_back(GEP); | ||||
|         continue; | ||||
|       } | ||||
|  | ||||
|       if (auto *AI = dyn_cast<AllocaInst>(&I)) { | ||||
|         auto *OldST = dyn_cast<StructType>(AI->getAllocatedType()); | ||||
|         if (!OldST) | ||||
|           continue; | ||||
|         auto It = StructMap.find(OldST); | ||||
|         if (It == StructMap.end()) | ||||
|           continue; | ||||
|  | ||||
|         AllocaInst *NewAI = | ||||
|             new AllocaInst(It->second, AI->getAddressSpace(), AI->getArraySize(), AI->getAlign(), AI->getName(), AI); | ||||
|         NewAI->setDebugLoc(AI->getDebugLoc()); | ||||
|         SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; | ||||
|         AI->getAllMetadata(MDs); | ||||
|         for (auto &P : MDs) | ||||
|           NewAI->setMetadata(P.first, P.second); | ||||
|  | ||||
|         AI->replaceAllUsesWith(NewAI); | ||||
|         ToErase.push_back(AI); | ||||
|         continue; | ||||
|       } | ||||
|  | ||||
|       if (auto *LI = dyn_cast<LoadInst>(&I)) { | ||||
|         auto *OldST = dyn_cast<StructType>(LI->getType()); | ||||
|         if (!OldST) | ||||
|           continue; | ||||
|         auto It = StructMap.find(OldST); | ||||
|         if (It == StructMap.end()) | ||||
|           continue; | ||||
|  | ||||
|         LoadInst *NewLoad = new LoadInst(It->second, LI->getPointerOperand(), LI->getName(), LI); | ||||
|         NewLoad->setAlignment(LI->getAlign()); | ||||
|         NewLoad->setDebugLoc(LI->getDebugLoc()); | ||||
|         SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; | ||||
|         LI->getAllMetadata(MDs); | ||||
|         for (auto &P : MDs) | ||||
|           NewLoad->setMetadata(P.first, P.second); | ||||
|  | ||||
|         LI->replaceAllUsesWith(NewLoad); | ||||
|         ToErase.push_back(LI); | ||||
|         continue; | ||||
|       } | ||||
|  | ||||
|       if (auto *SI = dyn_cast<StoreInst>(&I)) { | ||||
|         auto *OldST = dyn_cast<StructType>(SI->getValueOperand()->getType()); | ||||
|         if (!OldST) | ||||
|           continue; | ||||
|         auto It = StructMap.find(OldST); | ||||
|         if (It == StructMap.end()) | ||||
|           continue; | ||||
|  | ||||
|         // Only handle simple constant zero/undef cases safely. | ||||
|         Value *V = SI->getValueOperand(); | ||||
|         Value *NewV = nullptr; | ||||
|         if (auto *C = dyn_cast<Constant>(V)) { | ||||
|           if (C->isNullValue()) | ||||
|             NewV = Constant::getNullValue(It->second); | ||||
|           else if (isa<UndefValue>(C)) | ||||
|             NewV = UndefValue::get(It->second); | ||||
|     // Scan function to see if it uses TargetExtTy. | ||||
|     bool UsesTargetExt = ArgsOrRetTypeNeedsRetyping; | ||||
|     if (!UsesTargetExt && !F.isDeclaration()) { | ||||
|       for (BasicBlock &BB : F) { | ||||
|         for (Instruction &I : BB) { | ||||
|           Type *Ty = I.getType(); | ||||
|           if (checkIfNeedsRetyping(Ty)) { | ||||
|             UsesTargetExt = true; | ||||
|             break; | ||||
|           } | ||||
|           // Also scan operand types (structs carrying TargetExt). | ||||
|           for (Value *Op : I.operands()) { | ||||
|             if (auto *OpTy = Op->getType(); checkIfNeedsRetyping(OpTy)) { | ||||
|               UsesTargetExt = true; | ||||
|               break; | ||||
|             } | ||||
|           } | ||||
|           if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { | ||||
|             Type *SrcElemTy = GEP->getSourceElementType(); | ||||
|             if (checkIfNeedsRetyping(SrcElemTy)) { | ||||
|               UsesTargetExt = true; | ||||
|               break; | ||||
|             } | ||||
|           } | ||||
|           if (UsesTargetExt) | ||||
|             break; | ||||
|         } | ||||
|         if (!NewV) | ||||
|           continue; // TODO: Support complex producer chains. | ||||
|  | ||||
|         StoreInst *NewStore = new StoreInst(NewV, SI->getPointerOperand(), SI); | ||||
|         NewStore->setAlignment(SI->getAlign()); | ||||
|         NewStore->setDebugLoc(SI->getDebugLoc()); | ||||
|         SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; | ||||
|         SI->getAllMetadata(MDs); | ||||
|         for (auto &P : MDs) | ||||
|           NewStore->setMetadata(P.first, P.second); | ||||
|  | ||||
|         ToErase.push_back(SI); | ||||
|         continue; | ||||
|         if (UsesTargetExt) | ||||
|           break; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   for (Instruction *I : ToErase) | ||||
|     I->eraseFromParent(); | ||||
| } | ||||
|  | ||||
| static AttributeList rewriteAttrListWithStructMap(LLVMContext &C, const AttributeList &AL, unsigned NumParams, | ||||
|                                                   const DenseMap<StructType *, StructType *> &StructMap) { | ||||
|   // Function and return attribute sets are preserved as-is. | ||||
|   AttributeSet FnSet = AL.getFnAttrs(); | ||||
|   AttributeSet RetSet = AL.getRetAttrs(); | ||||
|  | ||||
|   SmallVector<AttributeSet, 16> ParamSets; | ||||
|   ParamSets.reserve(NumParams); | ||||
|  | ||||
|   for (unsigned I = 0; I < NumParams; ++I) { | ||||
|     AttributeSet AS = AL.getParamAttrs(I); | ||||
|     if (!AS.hasAttributes()) { | ||||
|       ParamSets.emplace_back(); | ||||
|     // If neither args/return/instructions use OpenCL TargetExtTy, skip. | ||||
|     if (!UsesTargetExt) { | ||||
|       continue; | ||||
|     } | ||||
|  | ||||
|     SmallVector<llvm::Attribute, 8> NewAttrs; | ||||
|     for (llvm::Attribute A : AS) { | ||||
|       if (A.isTypeAttribute()) { | ||||
|         Type *Ty = A.getValueAsType(); | ||||
|         if (auto *OldST = dyn_cast<StructType>(Ty)) { | ||||
|           if (auto It = StructMap.find(OldST); It != StructMap.end()) { | ||||
|             switch (A.getKindAsEnum()) { | ||||
|             case llvm::Attribute::ByVal: | ||||
|               A = llvm::Attribute::get(C, llvm::Attribute::ByVal, It->second); | ||||
|               break; | ||||
|             case llvm::Attribute::StructRet: | ||||
|               A = llvm::Attribute::get(C, llvm::Attribute::StructRet, It->second); | ||||
|               break; | ||||
|             case llvm::Attribute::ByRef: | ||||
|               A = llvm::Attribute::get(C, llvm::Attribute::ByRef, It->second); | ||||
|               break; | ||||
|             default: | ||||
|               break; | ||||
|     OpenCLTargetExtTypeMapper Mapper(F, TETtoRetypedStructs); | ||||
|     ValueToValueMapTy VM; | ||||
|  | ||||
|     // Handle constants of target extension types. | ||||
|     for (BasicBlock &BB : F) { | ||||
|       for (Instruction &I : BB) { | ||||
|         for (Use &U : I.operands()) { | ||||
|           if (Constant *C = dyn_cast<Constant>(U.get())) { | ||||
|             Type *Ty = C->getType(); | ||||
|             if (checkIfNeedsRetyping(Ty)) { | ||||
|               Type *NewTy = Mapper.remapType(Ty); | ||||
|               if (NewTy != Ty) { | ||||
|                 VM[C] = Constant::getNullValue(NewTy); | ||||
|               } | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|       NewAttrs.push_back(A); | ||||
|     } | ||||
|     ParamSets.push_back(AttributeSet::get(C, NewAttrs)); | ||||
|  | ||||
|     RemapFunction(F, VM, RF_IgnoreMissingLocals | RF_ReuseAndMutateDistinctMDs, &Mapper); | ||||
|  | ||||
|     // We only need to replace function whose signature changes. | ||||
|     if (ArgsOrRetTypeNeedsRetyping) { | ||||
|       // Remap function argument and return types. | ||||
|       FunctionType *NewFTy = cast<FunctionType>(Mapper.remapType(F.getFunctionType())); | ||||
|  | ||||
|       // Remap types used in attributes. | ||||
|       AttributeList OldAttrs = F.getAttributes(); | ||||
|       SmallVector<AttributeSet, 8> NewArgAttrs; | ||||
|       NewArgAttrs.reserve(OldAttrs.getNumAttrSets()); | ||||
|       bool AttrsChanged = false; | ||||
|  | ||||
|       for (const Argument &Arg : F.args()) { | ||||
|         AttributeSet Attrs = OldAttrs.getParamAttrs(Arg.getArgNo()); | ||||
|         if (Attrs.hasAttribute(llvm::Attribute::StructRet)) { | ||||
|           Type *OldSRetTy = Arg.getParamStructRetType(); | ||||
|           Type *NewSRetTy = Mapper.remapType(OldSRetTy); | ||||
|           if (NewSRetTy != OldSRetTy) { | ||||
|             AttrBuilder AB(M->getContext(), Attrs); | ||||
|             AB.removeAttribute(llvm::Attribute::StructRet); | ||||
|             AB.addStructRetAttr(NewSRetTy); | ||||
|             Attrs = AttributeSet::get(M->getContext(), AB); | ||||
|             AttrsChanged = true; | ||||
|           } | ||||
|         } | ||||
|         NewArgAttrs.push_back(Attrs); | ||||
|       } | ||||
|  | ||||
|       AttributeList NewAttrs = | ||||
|           AttrsChanged ? AttributeList::get(M->getContext(), OldAttrs.getFnAttrs(), OldAttrs.getRetAttrs(), NewArgAttrs) | ||||
|                        : OldAttrs; | ||||
|  | ||||
|       PendingSigChanges.try_emplace(&F, FunctionSignatureChange{NewFTy, NewAttrs}); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return AttributeList::get(C, FnSet, RetSet, ParamSets); | ||||
| } | ||||
|   // Replace functions with changed signatures. | ||||
|   for (auto &KV : PendingSigChanges) { | ||||
|     Function *OldF = KV.first; | ||||
|     FunctionSignatureChange Change = KV.second; | ||||
|  | ||||
| static void rewriteFunctionSRetByValAttrs(Function &F, const DenseMap<StructType *, StructType *> &StructMap) { | ||||
|   AttributeList Old = F.getAttributes(); | ||||
|   AttributeList New = rewriteAttrListWithStructMap(F.getContext(), Old, F.arg_size(), StructMap); | ||||
|   if (Old != New) | ||||
|     F.setAttributes(New); | ||||
| } | ||||
|     // Preserve original name to restore after erasing OldF. | ||||
|     std::string OldName = OldF->getName().str(); | ||||
|  | ||||
| static void rewriteCallSRetByValAttrs(CallBase &CB, const DenseMap<StructType *, StructType *> &StructMap) { | ||||
|   AttributeList Old = CB.getAttributes(); | ||||
|   AttributeList New = rewriteAttrListWithStructMap(CB.getContext(), Old, CB.arg_size(), StructMap); | ||||
|   if (Old != New) | ||||
|     CB.setAttributes(New); | ||||
| } | ||||
|     // Create new function with same linkage & addr space (temporary unique name). | ||||
|     Function *NewF = Function::Create(Change.NewFuncTy, OldF->getLinkage(), OldF->getAddressSpace(), OldName, M); | ||||
|  | ||||
| static void retypeStructsWithTargetExt(Module &M) { | ||||
|   // Build struct retyping map for the whole module. | ||||
|   DenseMap<StructType *, StructType *> StructMap; | ||||
|   buildStructRetypeMap(M, StructMap); | ||||
|   if (StructMap.empty()) | ||||
|     return; | ||||
|     // Set remapped attributes. | ||||
|     NewF->setAttributes(Change.NewAttrs); | ||||
|  | ||||
|   // Update instructions first (alloca/GEP/etc). | ||||
|   for (Function &F : M) { | ||||
|     if (F.isDeclaration()) | ||||
|       continue; | ||||
|     replaceStructTypeUsesInFunction(F, StructMap); | ||||
|   } | ||||
|     // Copy calling convention, comdat. | ||||
|     NewF->setCallingConv(OldF->getCallingConv()); | ||||
|     if (OldF->getComdat()) | ||||
|       NewF->setComdat(OldF->getComdat()); | ||||
|  | ||||
|   // Fix sret/byval/byref attributes on function defs and call sites. | ||||
|   for (Function &F : M) { | ||||
|     rewriteFunctionSRetByValAttrs(F, StructMap); | ||||
|     for (BasicBlock &BB : F) { | ||||
|       for (Instruction &I : BB) { | ||||
|         if (auto *CB = dyn_cast<CallBase>(&I)) | ||||
|           rewriteCallSRetByValAttrs(*CB, StructMap); | ||||
|     // Transfer debug subprogram (if any). | ||||
|     if (OldF->getSubprogram()) { | ||||
|       NewF->setSubprogram(OldF->getSubprogram()); | ||||
|       OldF->setSubprogram(nullptr); | ||||
|     } | ||||
|  | ||||
|     // Copy all function-level metadata (except dbg already handled via Subprogram). | ||||
|     { | ||||
|       SmallVector<std::pair<unsigned, MDNode *>, 8> MDs; | ||||
|       OldF->getAllMetadata(MDs); | ||||
|       unsigned DbgKind = M->getContext().getMDKindID("dbg"); | ||||
|       for (auto &MDPair : MDs) { | ||||
|         if (MDPair.first == DbgKind) | ||||
|           continue; // Avoid duplicating debug info. | ||||
|         NewF->setMetadata(MDPair.first, MDPair.second); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // Old identified struct types should become unused and disappear from IR, no explicit erasure is required. | ||||
| } | ||||
|  | ||||
| void retypeOpenCLTargetExtTyAsPointers(Module *M) { | ||||
|   // Step 1: Retype local allocas/loads/stores of OpenCL TargetExtTy to use pointer types instead. | ||||
|   for (Function &F : *M) { | ||||
|     if (F.isDeclaration()) | ||||
|       continue; | ||||
|     retypeLocalTargetExtAllocasLoadsStores(F); | ||||
|   } | ||||
|  | ||||
|   constexpr StringLiteral TempSuffix = ".__retype_tmp"; | ||||
|   SmallVector<Function *, 8> RetypedFuncs; | ||||
|  | ||||
|   // Step 2: Retype function signatures of functions that use OpenCL TargetExtTy in args or return type to use pointer | ||||
|   // types instead. | ||||
|   for (Function &F : *M) { | ||||
|     if (!isAnyArgOpenCLTargetExtTy(F) && !isDeclarationWithOpenCLTargetExtTyRet(F)) | ||||
|       continue; | ||||
|  | ||||
|     if (Function *NewF = cloneFunctionWithPtrArgsandRetInsteadTargetExtTy(F, TempSuffix)) | ||||
|       RetypedFuncs.push_back(NewF); | ||||
|   } | ||||
|  | ||||
|   // Schedule replacement in call-site order to keep IR stable. | ||||
|   DenseMap<const Function *, size_t> FirstUseIndex; | ||||
|   size_t InstIndex = 0; | ||||
|   for (Function &F : *M) { | ||||
|     for (BasicBlock &BB : F) { | ||||
|       for (Instruction &I : BB) { | ||||
|         if (auto *CB = dyn_cast<CallBase>(&I)) | ||||
|           if (const Function *Callee = CB->getCalledFunction()) | ||||
|             if (FirstUseIndex.find(Callee) == FirstUseIndex.end()) | ||||
|               FirstUseIndex[Callee] = InstIndex; | ||||
|         ++InstIndex; | ||||
|       } | ||||
|     // Move body (for definitions). | ||||
|     if (!OldF->isDeclaration()) { | ||||
|       NewF->splice(NewF->begin(), OldF); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   auto idxOf = [&](Function *NewF) { | ||||
|     StringRef OrigName = NewF->getName().drop_back(TempSuffix.size()); | ||||
|     const Function *OldF = M->getFunction(OrigName); | ||||
|     auto It = FirstUseIndex.find(OldF); | ||||
|     return It == FirstUseIndex.end() ? std::numeric_limits<size_t>::max() : It->second; | ||||
|   }; | ||||
|     // Map arguments. | ||||
|     auto OldIt = OldF->arg_begin(); | ||||
|     auto NewIt = NewF->arg_begin(); | ||||
|     for (; OldIt != OldF->arg_end(); ++OldIt, ++NewIt) { | ||||
|       NewIt->takeName(&*OldIt); | ||||
|       OldIt->replaceAllUsesWith(&*NewIt); | ||||
|     } | ||||
|  | ||||
|   std::stable_sort(RetypedFuncs.begin(), RetypedFuncs.end(), | ||||
|                    [&](Function *A, Function *B) { return idxOf(A) < idxOf(B); }); | ||||
|  | ||||
|   for (Function *NewF : RetypedFuncs) { | ||||
|     std::string OriginalName = NewF->getName().drop_back(TempSuffix.size()).str(); | ||||
|     Function *OldF = M->getFunction(OriginalName); | ||||
|  | ||||
|     replaceFunctionAtCallsites(*OldF, *NewF); | ||||
|  | ||||
|     // Keep the original symbol name by swapping. | ||||
|     Constant *NewFBitCast = ConstantExpr::getBitCast(NewF, OldF->getType()); | ||||
|     OldF->replaceAllUsesWith(NewFBitCast); | ||||
|  | ||||
|     OldF->setName(OriginalName + ".old"); | ||||
|     NewF->setName(OriginalName); | ||||
|     // Redirect users then remove old. | ||||
|     OldF->replaceAllUsesWith(NewF); | ||||
|     OldF->eraseFromParent(); | ||||
|   } | ||||
|  | ||||
|   // Step 3: Retype struct types that contain OpenCL TargetExtTy fields and update users. | ||||
|   retypeStructsWithTargetExt(*M); | ||||
|     // Restore original name. | ||||
|     if (NewF->getName() != OldName) | ||||
|       NewF->setName(OldName); | ||||
|   } | ||||
| } | ||||
|  | ||||
| #endif // LLVM_VERSION_MAJOR >= 16 | ||||
|  | ||||
		Reference in New Issue
	
	Block a user