[mlir][LLVM] Add LLVMAddrSpaceAttrInterface and NVVMMemorySpaceAttr (#157339)

This patch introduces the `LLVMAddrSpaceAttrInterface` interface for defining
LLVM-compatible address space attributes.
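
Any attribute implementing the interface exposes its numeric address space
uniformly through `getAddressSpace()`. A minimal consumer-side sketch (the
helper function is illustrative, not part of the patch):

```cpp
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include <optional>

// Illustrative helper: extract the numeric LLVM address space from any
// attribute implementing LLVMAddrSpaceAttrInterface.
static std::optional<unsigned> getLLVMAddressSpace(mlir::Attribute attr) {
  if (auto as = mlir::dyn_cast<mlir::LLVM::LLVMAddrSpaceAttrInterface>(attr))
    return as.getAddressSpace();
  return std::nullopt;
}
```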

To test this interface, this patch also:
- Adds `NVVMMemorySpaceAttr`, implementing both `LLVMAddrSpaceAttrInterface`
  and `MemorySpaceAttrInterface`
- Converts the NVVM memory space constants from a plain C++ enum to an MLIR
  enum
- Updates all NVVM memory space references to use the new attribute system
- Adds support for NVVM memory spaces in the ptr dialect translation

Example:
```mlir
llvm.func @nvvm_ptr_address_space(
    !ptr.ptr<#nvvm.memory_space<global>>,
    !ptr.ptr<#nvvm.memory_space<shared>>,
    !ptr.ptr<#nvvm.memory_space<constant>>,
    !ptr.ptr<#nvvm.memory_space<local>>,
    !ptr.ptr<#nvvm.memory_space<tensor>>,
    !ptr.ptr<#nvvm.memory_space<shared_cluster>>
  ) -> !ptr.ptr<#nvvm.memory_space<generic>>
```
Translating the above code to LLVM produces:
```llvm
declare ptr @nvvm_ptr_address_space(ptr addrspace(1), ptr addrspace(3), ptr addrspace(4), ptr addrspace(5), ptr addrspace(6), ptr addrspace(7))
```
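
Each mnemonic maps directly to its NVVM address-space number (generic = 0,
global = 1, shared = 3, constant = 4, local = 5, tensor = 6,
shared_cluster = 7), which is what the translation emits as `addrspace(N)`.
For reference, a builder-side sketch; the `NVVMMemorySpaceAttr::get` and
`ptr::PtrType::get` builders and the headers are assumed from the generated
dialect code, so treat this as illustrative rather than exact:

```cpp
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"

// Sketch: construct !ptr.ptr<#nvvm.memory_space<global>> programmatically.
void buildGlobalPtrType(mlir::MLIRContext &ctx) {
  auto space = mlir::NVVM::NVVMMemorySpaceAttr::get(
      &ctx, mlir::NVVM::NVVMMemorySpace::Global);
  // Assumed generated builder taking the memory space attribute.
  auto ptrTy = mlir::ptr::PtrType::get(&ctx, space);
  (void)ptrTy;
  // space.getAddressSpace() == 1, matching "ptr addrspace(1)" above.
}
```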


To convert uses of the old memory space enumerators to the new enum class, use:
```bash
grep -r . -e "NVVMMemorySpace::kGenericMemorySpace" -l | xargs sed -i -e "s/NVVMMemorySpace::kGenericMemorySpace/NVVMMemorySpace::Generic/g"
grep -r . -e "NVVMMemorySpace::kGlobalMemorySpace" -l | xargs sed -i -e "s/NVVMMemorySpace::kGlobalMemorySpace/NVVMMemorySpace::Global/g"
grep -r . -e "NVVMMemorySpace::kSharedMemorySpace" -l | xargs sed -i -e "s/NVVMMemorySpace::kSharedMemorySpace/NVVMMemorySpace::Shared/g"
grep -r . -e "NVVMMemorySpace::kConstantMemorySpace" -l | xargs sed -i -e "s/NVVMMemorySpace::kConstantMemorySpace/NVVMMemorySpace::Constant/g"
grep -r . -e "NVVMMemorySpace::kLocalMemorySpace" -l | xargs sed -i -e "s/NVVMMemorySpace::kLocalMemorySpace/NVVMMemorySpace::Local/g"
grep -r . -e "NVVMMemorySpace::kTensorMemorySpace" -l | xargs sed -i -e "s/NVVMMemorySpace::kTensorMemorySpace/NVVMMemorySpace::Tensor/g"
grep -r . -e "NVVMMemorySpace::kSharedClusterMemorySpace" -l | xargs sed -i -e "s/NVVMMemorySpace::kSharedClusterMemorySpace/NVVMMemorySpace::SharedCluster/g"
```
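
The textual rename is not always sufficient on its own: the new enum is a
scoped MLIR enum, so the old implicit conversions to `unsigned` disappear, and
call sites that relied on them need an explicit cast (or the comparison
operators this patch adds). The before/after pattern recurring throughout the
diff:

```cpp
// Before: unscoped enumerator, implicitly convertible to unsigned.
g.setAddrSpace(mlir::NVVM::NVVMMemorySpace::kSharedMemorySpace);
// After: scoped MLIR enum; the conversion must be spelled out.
g.setAddrSpace(static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared));
```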

NOTE: A future patch will add support for ROCDL; it was not added here to keep
this patch small.
Author: Fabian Mora
Date: 2025-09-14 09:05:28 -04:00
Committed by: GitHub
Parent: 9ee1f159dc
Commit: 48babe1931

18 changed files with 220 additions and 77 deletions

View File

@@ -3211,7 +3211,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
     if (global.getDataAttr() &&
         *global.getDataAttr() == cuf::DataAttribute::Shared)
-      g.setAddrSpace(mlir::NVVM::NVVMMemorySpace::kSharedMemorySpace);
+      g.setAddrSpace(
+          static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared));
     rewriter.eraseOp(global);
     return mlir::success();

View File

@@ -221,7 +221,8 @@ static mlir::Value createAddressOfOp(mlir::ConversionPatternRewriter &rewriter,
                                      gpu::GPUModuleOp gpuMod,
                                      std::string &sharedGlobalName) {
   auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(
-      rewriter.getContext(), mlir::NVVM::NVVMMemorySpace::kSharedMemorySpace);
+      rewriter.getContext(),
+      static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared));
   if (auto g = gpuMod.lookupSymbol<fir::GlobalOp>(sharedGlobalName))
     return mlir::LLVM::AddressOfOp::create(rewriter, loc, llvmPtrTy,
                                            g.getSymName());

View File

@@ -30,6 +30,7 @@ class LLVM_Attr<string name, string attrMnemonic,
 def LLVM_AddressSpaceAttr :
     LLVM_Attr<"AddressSpace", "address_space", [
+      LLVM_LLVMAddrSpaceAttrInterface,
       DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>
     ]> {
   let summary = "LLVM address space";

View File

@@ -93,6 +93,14 @@ public:
 using cconv::CConv;
 using linkage::Linkage;
 using tailcallkind::TailCallKind;
 
+namespace detail {
+/// Checks whether the given type is an LLVM type that can be loaded or stored.
+bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering,
+                          std::optional<int64_t> alignment,
+                          const ::mlir::DataLayout *dataLayout,
+                          function_ref<InFlightDiagnostic()> emitError);
+} // namespace detail
+
 } // namespace LLVM
 } // namespace mlir

View File

@@ -533,6 +533,24 @@ def LLVM_DIRecursiveTypeAttrInterface
   ];
 }
 
+def LLVM_LLVMAddrSpaceAttrInterface :
+    AttrInterface<"LLVMAddrSpaceAttrInterface"> {
+  let description = [{
+    An interface for attributes that represent LLVM address spaces.
+    Implementing attributes should provide access to the address space value
+    as an unsigned integer.
+  }];
+  let cppNamespace = "::mlir::LLVM";
+  let methods = [
+    InterfaceMethod<
+      /*description=*/"Returns the address space as an unsigned integer.",
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getAddressSpace",
+      /*args=*/(ins)
+    >
+  ];
+}
+
 def LLVM_TargetAttrInterface
     : AttrInterface<"TargetAttrInterface", [DLTIQueryInterface]> {
   let description = [{

View File

@@ -19,6 +19,7 @@
 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h"
+#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -30,31 +31,23 @@
 namespace mlir {
 namespace NVVM {
 
+/// Utility functions to compare NVVMMemorySpace with unsigned values.
+inline bool operator==(unsigned as, NVVMMemorySpace memSpace) {
+  return as == static_cast<unsigned>(memSpace);
+}
+inline bool operator==(NVVMMemorySpace memSpace, unsigned as) {
+  return static_cast<unsigned>(memSpace) == as;
+}
+inline bool operator!=(unsigned as, NVVMMemorySpace memSpace) {
+  return as != static_cast<unsigned>(memSpace);
+}
+inline bool operator!=(NVVMMemorySpace memSpace, unsigned as) {
+  return static_cast<unsigned>(memSpace) != as;
+}
+
 // Shared memory has 128-bit alignment
 constexpr int kSharedMemoryAlignmentBit = 128;
 
-/// NVVM memory space identifiers.
-enum NVVMMemorySpace {
-  /// Generic memory space identifier.
-  kGenericMemorySpace = 0,
-  /// Global memory space identifier.
-  kGlobalMemorySpace = 1,
-  /// Shared memory space identifier.
-  kSharedMemorySpace = 3,
-  /// Constant memory space identifier.
-  kConstantMemorySpace = 4,
-  /// Local memory space identifier.
-  kLocalMemorySpace = 5,
-  /// Tensor memory space identifier.
-  /// Tensor memory is available only in arch-accelerated
-  /// variants from sm100 onwards.
-  kTensorMemorySpace = 6,
-  /// Distributed shared memory space identifier.
-  /// Distributed shared memory is available only in sm90+.
-  kSharedClusterMemorySpace = 7,
-};
-
 /// A pair type of LLVM's Intrinsic ID and args (which are llvm values).
 /// This type is returned by the getIntrinsicIDAndArgs() methods.
 using IDArgPair =
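
The comparison overloads above let verifiers and lowerings keep comparing raw
address-space integers against the scoped enum without casts, as in this
sketch mirroring the Tcgen05 lowerings further down:

```cpp
// Sketch: operator== makes the comparison compile without a static_cast.
unsigned as = llvm::cast<mlir::LLVM::LLVMPointerType>(addr.getType())
                  .getAddressSpace();
bool isShared = (as == mlir::NVVM::NVVMMemorySpace::Shared);
```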

View File

@@ -17,6 +17,7 @@ include "mlir/IR/EnumAttr.td"
 include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td"
+include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -192,6 +193,40 @@ def CacheEvictionPriorityAttr : EnumAttr<NVVM_Dialect, CacheEvictionPriority,
   let assemblyFormat = "$value";
 }
 
+// Memory Space enum definitions
+
+/// Generic memory space identifier.
+def MemSpaceGeneric : I32EnumCase<"Generic", 0, "generic">;
+/// Global memory space identifier.
+def MemSpaceGlobal : I32EnumCase<"Global", 1, "global">;
+/// Shared memory space identifier.
+def MemSpaceShared : I32EnumCase<"Shared", 3, "shared">;
+/// Constant memory space identifier.
+def MemSpaceConstant : I32EnumCase<"Constant", 4, "constant">;
+/// Local memory space identifier.
+def MemSpaceLocal : I32EnumCase<"Local", 5, "local">;
+/// Tensor memory space identifier.
+/// Tensor memory is available only in arch-accelerated
+/// variants from sm100 onwards.
+def MemSpaceTensor : I32EnumCase<"Tensor", 6, "tensor">;
+/// Distributed shared memory space identifier.
+/// Distributed shared memory is available only in sm90+.
+def MemSpaceSharedCluster : I32EnumCase<"SharedCluster", 7, "shared_cluster">;
+
+def NVVMMemorySpace : I32Enum<"NVVMMemorySpace", "NVVM Memory Space",
+                              [MemSpaceGeneric, MemSpaceGlobal, MemSpaceShared,
+                               MemSpaceConstant, MemSpaceLocal, MemSpaceTensor,
+                               MemSpaceSharedCluster]> {
+  let cppNamespace = "::mlir::NVVM";
+}
+
+def NVVMMemorySpaceAttr :
+    EnumAttr<NVVM_Dialect, NVVMMemorySpace, "memory_space", [
+      DeclareAttrInterfaceMethods<LLVM_LLVMAddrSpaceAttrInterface>,
+      DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>
+    ]> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
@@ -3592,7 +3627,7 @@ def NVVM_MapaOp: NVVM_Op<"mapa",
   string llvmBuilder = [{
     int addrSpace = llvm::cast<LLVMPointerType>(op.getA().getType()).getAddressSpace();
-    bool isSharedMemory = addrSpace == NVVM::NVVMMemorySpace::kSharedMemorySpace;
+    bool isSharedMemory = addrSpace == static_cast<int>(NVVM::NVVMMemorySpace::Shared);
     auto intId = isSharedMemory? llvm::Intrinsic::nvvm_mapa_shared_cluster : llvm::Intrinsic::nvvm_mapa;
     $res = createIntrinsicCall(builder, intId, {$a, $b});

View File

@@ -451,16 +451,14 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
       converter, [](gpu::AddressSpace space) -> unsigned {
         switch (space) {
         case gpu::AddressSpace::Global:
-          return static_cast<unsigned>(
-              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
         case gpu::AddressSpace::Workgroup:
-          return static_cast<unsigned>(
-              NVVM::NVVMMemorySpace::kSharedMemorySpace);
+          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
         case gpu::AddressSpace::Private:
           return 0;
         }
         llvm_unreachable("unknown address space enum value");
-        return 0;
+        return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
       });
 
   // Lowering for MMAMatrixType.
   converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
@@ -648,7 +646,7 @@ void mlir::populateGpuToNVVMConversionPatterns(
       GPUFuncOpLoweringOptions{
           /*allocaAddrSpace=*/0,
           /*workgroupAddrSpace=*/
-          static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
+          static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
           StringAttr::get(&converter.getContext(),
                           NVVM::NVVMDialect::getKernelFuncAttrName()),
           StringAttr::get(&converter.getContext(),

View File

@@ -405,16 +405,14 @@ struct ConvertNVGPUToNVVMPass
       converter, [](gpu::AddressSpace space) -> unsigned {
         switch (space) {
         case gpu::AddressSpace::Global:
-          return static_cast<unsigned>(
-              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
         case gpu::AddressSpace::Workgroup:
-          return static_cast<unsigned>(
-              NVVM::NVVMMemorySpace::kSharedMemorySpace);
+          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
         case gpu::AddressSpace::Private:
           return 0;
         }
         llvm_unreachable("unknown address space enum value");
-        return 0;
+        return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
       });
 
   /// device-side async tokens cannot be materialized in nvvm. We just
   /// convert them to a dummy i32 type in order to easily drop them during
@@ -677,7 +675,7 @@ struct NVGPUAsyncCopyLowering
                                        adaptor.getSrcIndices());
     // Intrinsics takes a global pointer so we need an address space cast.
     auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
-        op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+        op->getContext(), static_cast<unsigned>(NVVM::NVVMMemorySpace::Global));
     scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
     int64_t dstElements = adaptor.getDstElements().getZExtValue();
     int64_t sizeInBytes =

View File

@@ -71,16 +71,14 @@ void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
       llvmTypeConverter, [](AddressSpace space) -> unsigned {
         switch (space) {
         case AddressSpace::Global:
-          return static_cast<unsigned>(
-              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
         case AddressSpace::Workgroup:
-          return static_cast<unsigned>(
-              NVVM::NVVMMemorySpace::kSharedMemorySpace);
+          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
         case AddressSpace::Private:
           return 0;
         }
         llvm_unreachable("unknown address space enum value");
-        return 0;
+        return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
       });
 
   // Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now.
   // TODO: We should have a single to_nvvm_type_converter.

View File

@@ -57,10 +57,10 @@ void LLVMDialect::registerAttributes() {
 //===----------------------------------------------------------------------===//
 
 /// Checks whether the given type is an LLVM type that can be loaded or stored.
-static bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering,
-                                 std::optional<int64_t> alignment,
-                                 const ::mlir::DataLayout *dataLayout,
-                                 function_ref<InFlightDiagnostic()> emitError) {
+bool LLVM::detail::isValidLoadStoreImpl(
+    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+    const ::mlir::DataLayout *dataLayout,
+    function_ref<InFlightDiagnostic()> emitError) {
   if (!isLoadableType(type)) {
     if (emitError)
       emitError() << "type must be LLVM type with size, but got " << type;
@@ -87,14 +87,16 @@ bool AddressSpaceAttr::isValidLoad(
     Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
     const ::mlir::DataLayout *dataLayout,
     function_ref<InFlightDiagnostic()> emitError) const {
-  return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
+  return detail::isValidLoadStoreImpl(type, ordering, alignment, dataLayout,
+                                      emitError);
 }
 
 bool AddressSpaceAttr::isValidStore(
     Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
     const ::mlir::DataLayout *dataLayout,
     function_ref<InFlightDiagnostic()> emitError) const {
-  return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
+  return detail::isValidLoadStoreImpl(type, ordering, alignment, dataLayout,
+                                      emitError);
 }
 
 bool AddressSpaceAttr::isValidAtomicOp(

View File

@@ -703,14 +703,14 @@ const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
 //===----------------------------------------------------------------------===//
 
 /// Check whether type is a compatible ptr type. These are pointer-like types
-/// with no element type, no metadata, and using the LLVM AddressSpaceAttr
-/// memory space.
+/// with no element type, no metadata, and using the LLVM
+/// LLVMAddrSpaceAttrInterface memory space.
 static bool isCompatiblePtrType(Type type) {
   auto ptrTy = dyn_cast<PtrLikeTypeInterface>(type);
   if (!ptrTy)
     return false;
   return !ptrTy.hasPtrMetadata() && ptrTy.getElementType() == nullptr &&
-         isa<AddressSpaceAttr>(ptrTy.getMemorySpace());
+         isa<LLVMAddrSpaceAttrInterface>(ptrTy.getMemorySpace());
 }
 
 bool mlir::LLVM::isCompatibleOuterType(Type type) {

View File

@@ -800,8 +800,8 @@ inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
 LogicalResult NVVM::WMMALoadOp::verify() {
   unsigned addressSpace =
       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
-  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
-      addressSpace != NVVM::kSharedMemorySpace)
+  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
+      addressSpace != NVVMMemorySpace::Shared)
     return emitOpError("expected source pointer in memory "
                        "space 0, 1, 3");
@@ -821,8 +821,8 @@ LogicalResult NVVM::WMMALoadOp::verify() {
 LogicalResult NVVM::WMMAStoreOp::verify() {
   unsigned addressSpace =
       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
-  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
-      addressSpace != NVVM::kSharedMemorySpace)
+  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
+      addressSpace != NVVMMemorySpace::Shared)
     return emitOpError("expected operands to be a source pointer in memory "
                        "space 0, 1, 3");
@@ -1339,8 +1339,8 @@ LogicalResult NVVM::PrefetchOp::verify() {
     return emitOpError("cannot specify both tensormap and cache level");
 
   if (getTensormap()) {
-    if (addressSpace != MemSpace::kGenericMemorySpace &&
-        addressSpace != MemSpace::kConstantMemorySpace) {
+    if (addressSpace != MemSpace::Generic &&
+        addressSpace != MemSpace::Constant) {
      return emitOpError(
           "prefetch tensormap requires a generic or constant pointer");
     }
@@ -1350,15 +1350,14 @@ LogicalResult NVVM::PrefetchOp::verify() {
           "prefetch tensormap does not support eviction priority");
     }
 
-    if (getInParamSpace() && addressSpace != MemSpace::kGenericMemorySpace) {
+    if (getInParamSpace() && addressSpace != MemSpace::Generic) {
       return emitOpError(
           "in_param_space can only be specified for a generic pointer");
     }
   } else if (cacheLevel) {
-    if (addressSpace != MemSpace::kGenericMemorySpace &&
-        addressSpace != MemSpace::kGlobalMemorySpace &&
-        addressSpace != MemSpace::kLocalMemorySpace) {
+    if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
+        addressSpace != MemSpace::Local) {
       return emitOpError("prefetch to cache level requires a generic, global, "
                          "or local pointer");
     }
@@ -1370,7 +1369,7 @@ LogicalResult NVVM::PrefetchOp::verify() {
             "cache level is L1");
       }
 
-      if (addressSpace != MemSpace::kGenericMemorySpace) {
+      if (addressSpace != MemSpace::Generic) {
         return emitOpError(
             "prefetch to uniform cache requires a generic pointer");
       }
@@ -1381,7 +1380,7 @@ LogicalResult NVVM::PrefetchOp::verify() {
       return emitOpError(
           "cache eviction priority supported only for cache level L2");
 
-    if (addressSpace != MemSpace::kGlobalMemorySpace)
+    if (addressSpace != MemSpace::Global)
       return emitOpError("cache eviction priority requires a global pointer");
 
     if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
@@ -1796,7 +1795,7 @@ Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
   auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
   unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                     .getAddressSpace();
-  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
+  bool isShared = as == NVVMMemorySpace::Shared;
   bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
 
   llvm::Intrinsic::ID id;
@@ -1845,7 +1844,7 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
   auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
   unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                     .getAddressSpace();
-  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
+  bool isShared = as == NVVMMemorySpace::Shared;
   bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
   bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
@@ -2051,18 +2050,18 @@ PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
     }
   }
 
-  switch (addressSpace) {
-  case MemSpace::kGenericMemorySpace:
+  switch (static_cast<MemSpace>(addressSpace)) {
+  case MemSpace::Generic:
     return *cacheLevel == CacheLevel::L1
                ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
                : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
-  case MemSpace::kGlobalMemorySpace:
+  case MemSpace::Global:
     return *cacheLevel == CacheLevel::L1
                ? NVVM::IDArgPair(
                      {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
                : NVVM::IDArgPair(
                      {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
-  case MemSpace::kLocalMemorySpace:
+  case MemSpace::Local:
     return *cacheLevel == CacheLevel::L1
                ? NVVM::IDArgPair(
                      {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
@@ -2185,6 +2184,66 @@ LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM Address Space Attr
+//===----------------------------------------------------------------------===//
+
+unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
+  return static_cast<unsigned>(getValue());
+}
+
+bool NVVMMemorySpaceAttr::isValidLoad(
+    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+    const ::mlir::DataLayout *dataLayout,
+    function_ref<InFlightDiagnostic()> emitError) const {
+  return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
+                                            dataLayout, emitError);
+}
+
+bool NVVMMemorySpaceAttr::isValidStore(
+    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+    const ::mlir::DataLayout *dataLayout,
+    function_ref<InFlightDiagnostic()> emitError) const {
+  return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
+                                            dataLayout, emitError);
+}
+
+bool NVVMMemorySpaceAttr::isValidAtomicOp(
+    ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
+    std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
+    function_ref<InFlightDiagnostic()> emitError) const {
+  // TODO: update this method once `ptr.atomic_rmw` is implemented.
+  assert(false && "unimplemented, see TODO in the source.");
+  return false;
+}
+
+bool NVVMMemorySpaceAttr::isValidAtomicXchg(
+    Type type, ptr::AtomicOrdering successOrdering,
+    ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
+    const ::mlir::DataLayout *dataLayout,
+    function_ref<InFlightDiagnostic()> emitError) const {
+  // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
+  assert(false && "unimplemented, see TODO in the source.");
+  return false;
+}
+
+bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
+    Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
+  // TODO: update this method once the `ptr.addrspace_cast` op is added to the
+  // dialect.
+  assert(false && "unimplemented, see TODO in the source.");
+  return false;
+}
+
+bool NVVMMemorySpaceAttr::isValidPtrIntCast(
+    Type intLikeTy, Type ptrLikeTy,
+    function_ref<InFlightDiagnostic()> emitError) const {
+  // TODO: update this method once the int-cast ops are added to the `ptr`
+  // dialect.
+  assert(false && "unimplemented, see TODO in the source.");
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//

View File

@@ -53,16 +53,14 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
       llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
         switch (space) {
         case gpu::AddressSpace::Global:
-          return static_cast<unsigned>(
-              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
         case gpu::AddressSpace::Workgroup:
-          return static_cast<unsigned>(
-              NVVM::NVVMMemorySpace::kSharedMemorySpace);
+          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
         case gpu::AddressSpace::Private:
           return 0;
         }
         llvm_unreachable("unknown address space enum value");
-        return 0;
+        return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
       });
 
   llvmTypeConverter.addConversion(
       [&](nvgpu::DeviceAsyncTokenType type) -> Type {

View File

@@ -253,8 +253,8 @@ getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
 /// Return the intrinsic ID associated with st.bulk for the given address type.
 static llvm::Intrinsic::ID
 getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
-  bool isSharedMemory =
-      addrType.getAddressSpace() == NVVM::NVVMMemorySpace::kSharedMemorySpace;
+  bool isSharedMemory = addrType.getAddressSpace() ==
+                        static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
   return isSharedMemory ? llvm::Intrinsic::nvvm_st_bulk_shared_cta
                         : llvm::Intrinsic::nvvm_st_bulk;
 }

View File

@@ -152,8 +152,9 @@ private:
   /// Translates the given ptr type.
   llvm::Type *translate(PtrLikeTypeInterface type) {
-    auto memSpace = dyn_cast<LLVM::AddressSpaceAttr>(type.getMemorySpace());
-    assert(memSpace && "expected pointer with the LLVM address space");
+    auto memSpace =
+        dyn_cast<LLVM::LLVMAddrSpaceAttrInterface>(type.getMemorySpace());
+    assert(memSpace && "expected pointer with an LLVM address space");
     assert(!type.hasPtrMetadata() && "expected pointer without metadata");
     return llvm::PointerType::get(context, memSpace.getAddressSpace());
   }

View File

@@ -620,6 +620,16 @@ func.func @prefetch_tensormap(%gen_ptr: !llvm.ptr, %const_ptr: !llvm.ptr<4>) {
   return
 }
 
+// CHECK-LABEL: @nvvm_address_space
+func.func private @nvvm_address_space(
+    !ptr.ptr<#nvvm.memory_space<global>>,
+    !ptr.ptr<#nvvm.memory_space<shared>>,
+    !ptr.ptr<#nvvm.memory_space<constant>>,
+    !ptr.ptr<#nvvm.memory_space<local>>,
+    !ptr.ptr<#nvvm.memory_space<tensor>>,
+    !ptr.ptr<#nvvm.memory_space<shared_cluster>>
+  ) -> !ptr.ptr<#nvvm.memory_space<generic>>
+
 // -----
 
 // Just check these don't emit errors.

View File

@@ -233,3 +233,25 @@ llvm.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#llvm.addr
   %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#llvm.address_space<0>>>, i32
   llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
 }
+
+// CHECK-LABEL: declare ptr @nvvm_ptr_address_space(ptr addrspace(1), ptr addrspace(3), ptr addrspace(4), ptr addrspace(5), ptr addrspace(6), ptr addrspace(7))
+llvm.func @nvvm_ptr_address_space(
+    !ptr.ptr<#nvvm.memory_space<global>>,
+    !ptr.ptr<#nvvm.memory_space<shared>>,
+    !ptr.ptr<#nvvm.memory_space<constant>>,
+    !ptr.ptr<#nvvm.memory_space<local>>,
+    !ptr.ptr<#nvvm.memory_space<tensor>>,
+    !ptr.ptr<#nvvm.memory_space<shared_cluster>>
+  ) -> !ptr.ptr<#nvvm.memory_space<generic>>
+
+// CHECK-LABEL: define void @llvm_ops_with_ptr_nvvm_values
+// CHECK-SAME: (ptr %[[ARG:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = load ptr addrspace(1), ptr %[[ARG]], align 8
+// CHECK-NEXT: store ptr addrspace(1) %[[V0]], ptr %[[ARG]], align 8
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @llvm_ops_with_ptr_nvvm_values(%arg0: !llvm.ptr) {
+  %1 = llvm.load %arg0 : !llvm.ptr -> !ptr.ptr<#nvvm.memory_space<global>>
+  llvm.store %1, %arg0 : !ptr.ptr<#nvvm.memory_space<global>>, !llvm.ptr
+  llvm.return
+}