Add generic type attribute mapping infrastructure, use it in GpuToX

Remapping memory spaces is an operation often needed in type
conversions, most often when lowering to LLVM or to/from SPIR-V (a
future commit), and such remappings may become more common as dialects
take advantage of the more generic memory space infrastructure.

Currently, memory space remappings are handled by running a
special-purpose conversion pass, before the main conversion, that
rewrites the address space attributes. This commit replaces that
approach with a notion of type attribute conversions on
TypeConverter, which is then used to convert memory space attributes.
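
For illustration, a minimal sketch of registering such a conversion
(mirroring the populateGpuMemorySpaceAttributeConversions helper added
below, with the ROCDL space numbering used in this patch; `converter`
is an assumed local TypeConverter):

  // Map #gpu.address_space memory spaces on memref types to numeric (i64)
  // address spaces. A callback returning an Attribute is implicitly
  // wrapped into an AttributeConversionResult.
  converter.addTypeAttributeConversion(
      [](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
        unsigned space = 0;
        switch (memorySpaceAttr.getValue()) {
        case gpu::AddressSpace::Global:
          space = 1;
          break;
        case gpu::AddressSpace::Workgroup:
          space = 3;
          break;
        case gpu::AddressSpace::Private:
          space = 5;
          break;
        }
        return IntegerAttr::get(
            IntegerType::get(memorySpaceAttr.getContext(), 64), space);
      });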

Then, we use this infrastructure throughout the *ToLLVM conversions.
This loosens the requirement on the inputs to those passes from "all
address spaces must be integers" to "all memory spaces must be
convertible to integer spaces", which reduces the coupling between
portions of MLIR.

On top of that, this change leads to the removal of most of the calls
to getMemorySpaceAsInt(), bringing us closer to removing it.
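
The typical replacement pattern (a sketch of what getElementPtrType in
ConvertToLLVMPattern now does, per the diff below) is:

  // Before: unconditionally assumed an integer memory space:
  //   unsigned space = type.getMemorySpaceAsInt();
  // After: ask the type converter, and bail out gracefully when the
  // memory space is not convertible to an integer.
  FailureOr<unsigned> addressSpace =
      getTypeConverter()->getMemRefAddressSpace(type);
  if (failed(addressSpace))
    return {};
  return getTypeConverter()->getPointerType(structElementType, *addressSpace);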

(A rework of the SPIR-V conversions to use this new system will be in
a follow-up commit.)

As a note, one long-term motivation for this change is that I would
eventually like to add an allocaMemorySpace key to MLIR data layouts
and then call getMemRefAddressSpace(allocaMemorySpace) in the
relevant *ToLLVM passes in order to ensure all alloca()s, whether
incoming or produced during the LLVM lowering, have the correct address
space for a given target.

I expect that the type attribute conversion system may be useful in
other contexts.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D142159
Krzysztof Drewniak
2023-01-19 21:56:04 +00:00
parent 3c565c2466
commit 499abb243c
24 changed files with 411 additions and 384 deletions


@@ -147,6 +147,11 @@ public:
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
const DataLayout &layout);
/// Return the LLVM address space corresponding to the memory space of the
/// memref type `type` or failure if the memory space cannot be converted to
/// an integer.
FailureOr<unsigned> getMemRefAddressSpace(BaseMemRefType type);
/// Check if a memref type can be converted to a bare pointer.
static bool canConvertToBarePtr(BaseMemRefType type);


@@ -25,7 +25,8 @@ def AMDGPU_Dialect : Dialect {
let dependentDialects = [
"arith::ArithDialect"
"arith::ArithDialect",
"gpu::GPUDialect"
];
let useDefaultAttributePrinterParser = 1;
}


@@ -61,23 +61,6 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
}
namespace gpu {
/// A function that maps a MemorySpace enum to a target-specific integer value.
using MemorySpaceMapping =
std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
/// Populates type conversion rules for lowering memory space attributes to
/// numeric values.
void populateMemorySpaceAttributeTypeConversions(
TypeConverter &typeConverter, const MemorySpaceMapping &mapping);
/// Populates patterns to lower memory space attributes to numeric values.
void populateMemorySpaceLoweringPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Populates legality rules for lowering memory space attributes to numeric
/// values.
void populateLowerMemorySpaceOpLegality(ConversionTarget &target);
/// Returns the default annotation name for GPU binary blobs.
std::string getDefaultGpuBinaryAnnotation();


@@ -37,23 +37,4 @@ def GpuMapParallelLoopsPass
let dependentDialects = ["mlir::gpu::GPUDialect"];
}
def GPULowerMemorySpaceAttributesPass
: Pass<"gpu-lower-memory-space-attributes"> {
let summary = "Assign numeric values to memref memory space symbolic placeholders";
let description = [{
Updates all memref types that have a memory space attribute
that is a `gpu::AddressSpaceAttr`. These attributes are
changed to `IntegerAttr`'s using a mapping that is given in the
options.
}];
let options = [
Option<"privateAddrSpace", "private", "unsigned", "5",
"private address space numeric value">,
Option<"workgroupAddrSpace", "workgroup", "unsigned", "3",
"workgroup address space numeric value">,
Option<"globalAddrSpace", "global", "unsigned", "1",
"global address space numeric value">
];
}
#endif // MLIR_DIALECT_GPU_PASSES


@@ -21,6 +21,7 @@
namespace mlir {
// Forward declarations.
class Attribute;
class Block;
class ConversionPatternRewriter;
class MLIRContext;
@@ -87,6 +88,34 @@ public:
SmallVector<Type, 4> argTypes;
};
/// The general result of a type attribute conversion callback, allowing
/// for early termination. The default constructor creates the N/A case.
class AttributeConversionResult {
public:
constexpr AttributeConversionResult() : impl() {}
AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
static AttributeConversionResult result(Attribute attr);
static AttributeConversionResult na();
static AttributeConversionResult abort();
bool hasResult() const;
bool isNa() const;
bool isAbort() const;
Attribute getResult() const;
private:
AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
llvm::PointerIntPair<Attribute, 2> impl;
// Note that na is 0 so that we can use PointerIntPair's default
// constructor.
static constexpr unsigned naTag = 0;
static constexpr unsigned resultTag = 1;
static constexpr unsigned abortTag = 2;
};
/// Register a conversion function. A conversion function must be convertible
/// to any of the following forms (where `T` is a class derived from `Type`):
/// * std::optional<Type>(T)
@@ -156,6 +185,34 @@ public:
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// Register a conversion function for attributes within types. Type
/// converters may call this function in order to allow hooking into the
/// translation of attributes that exist within types. For example, a type
/// converter for the `memref` type could use these conversions to convert
/// memory spaces or layouts in an extensible way.
///
/// The conversion functions take a non-null Type or subclass of Type and a
/// non-null Attribute (or subclass of Attribute), and return an
/// `AttributeConversionResult`. This result can either contain an `Attribute`,
/// which may be `nullptr`, representing the conversion's success,
/// `AttributeConversionResult::na()` (the default empty value), indicating
/// that the conversion function did not apply and that further conversion
/// functions should be checked, or `AttributeConversionResult::abort()`,
/// indicating that the conversion process should be aborted.
///
/// Registered conversion functions are called in the reverse of the order in
/// which they were registered.
template <
typename FnT,
typename T =
typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
typename A =
typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
void addTypeAttributeConversion(FnT &&callback) {
registerTypeAttributeConversion(
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
}
/// Convert the given type. This function should return failure if no valid
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
@@ -226,6 +283,12 @@ public:
resultType, inputs);
}
/// Convert the attribute `attr`, which appears within the type `type`, using
/// the registered conversion functions. If no applicable conversion has been
/// registered, return std::nullopt. Note that the empty attribute/`nullptr`
/// is a valid return value for this function.
std::optional<Attribute> convertTypeAttribute(Type type, Attribute attr);
private:
/// The signature of the callback used to convert a type. If the new set of
/// types is empty, the type is removed and any usages of the existing value
@@ -237,6 +300,10 @@ private:
using MaterializationCallbackFn = std::function<std::optional<Value>(
OpBuilder &, Type, ValueRange, Location)>;
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
std::function<AttributeConversionResult(Type, Attribute)>;
/// Attempt to materialize a conversion using one of the provided
/// materialization functions.
Value materializeConversion(
@@ -311,6 +378,32 @@ private:
};
}
/// Generate a wrapper for the given type attribute conversion callback. The
/// callback may take any subclass of `Attribute` and the wrapper will check
/// for the target attribute to be of the expected class before calling the
/// callback.
template <typename T, typename A, typename FnT>
TypeAttributeConversionCallbackFn
wrapTypeAttributeConversion(FnT &&callback) {
return [callback = std::forward<FnT>(callback)](
Type type, Attribute attr) -> AttributeConversionResult {
if (T derivedType = type.dyn_cast<T>()) {
if (A derivedAttr = attr.dyn_cast_or_null<A>())
return callback(derivedType, derivedAttr);
}
return AttributeConversionResult::na();
};
}
/// Register a type attribute conversion, clearing caches.
void
registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
typeAttributeConversions.emplace_back(std::move(callback));
// Clear the cached type conversions, in case a cached type contains an
// attribute that this new conversion affects.
cachedDirectConversions.clear();
cachedMultiConversions.clear();
}
/// The set of registered conversion functions.
SmallVector<ConversionCallbackFn, 4> conversions;
@@ -319,6 +412,9 @@ private:
SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
/// The list of registered type attribute conversion functions.
SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
/// A set of cached conversions to avoid recomputing in the common case.
/// Direct 1-1 conversions are the most common, so this cache stores the
/// successful 1-1 conversions as well as all failed conversions.

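To make the three result kinds concrete, here is an illustrative sketch
(hypothetical callbacks, not part of this patch): the first mirrors the
integer identity conversion that the LLVMTypeConverter registers below;
the second aborts on string-typed memory spaces; any attribute class
that neither callback accepts falls through as N/A, because the wrapper
checks the attribute class before invoking the callback.

  // Keep integer memory spaces untouched.
  converter.addTypeAttributeConversion(
      [](BaseMemRefType type, IntegerAttr memorySpace) {
        return memorySpace;
      });
  // Explicitly abort the conversion on string-typed memory spaces.
  converter.addTypeAttributeConversion(
      [](BaseMemRefType type, StringAttr memorySpace) {
        return TypeConverter::AttributeConversionResult::abort();
      });
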

@@ -8,6 +8,7 @@
#include "GPUOpsLowering.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
@@ -474,3 +475,18 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
rewriter.replaceOp(op, result);
return success();
}
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
return IntegerAttr::get(IntegerType::get(ctx, 64), space);
}
void mlir::populateGpuMemorySpaceAttributeConversions(
TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
typeConverter.addTypeAttributeConversion(
[mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
unsigned addressSpace = mapping(memorySpace);
return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
addressSpace);
});
}


@@ -112,6 +112,14 @@ public:
}
};
/// A function that maps a MemorySpace enum to a target-specific integer value.
using MemorySpaceMapping =
std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
/// Populates memory space attribute conversion rules for lowering
/// gpu.address_space to integer values.
void populateGpuMemorySpaceAttributeConversions(
TypeConverter &typeConverter, const MemorySpaceMapping &mapping);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_


@@ -241,38 +241,26 @@ struct LowerGpuOpsToNVVMOpsPass
return signalPassFailure();
}
// MemRef conversion for GPU to NVVM lowering.
{
RewritePatternSet patterns(m.getContext());
TypeConverter typeConverter;
typeConverter.addConversion([](Type t) { return t; });
// NVVM uses alloca in the default address space to represent private
// memory allocations, so drop private annotations. NVVM uses address
// space 3 for shared memory. NVVM uses the default address space to
// represent global memory.
gpu::populateMemorySpaceAttributeTypeConversions(
typeConverter, [](gpu::AddressSpace space) -> unsigned {
switch (space) {
case gpu::AddressSpace::Global:
return static_cast<unsigned>(
NVVM::NVVMMemorySpace::kGlobalMemorySpace);
case gpu::AddressSpace::Workgroup:
return static_cast<unsigned>(
NVVM::NVVMMemorySpace::kSharedMemorySpace);
case gpu::AddressSpace::Private:
return 0;
}
llvm_unreachable("unknown address space enum value");
return 0;
});
gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns);
ConversionTarget target(getContext());
gpu::populateLowerMemorySpaceOpLegality(target);
if (failed(applyFullConversion(m, target, std::move(patterns))))
return signalPassFailure();
}
LLVMTypeConverter converter(m.getContext(), options);
// NVVM uses alloca in the default address space to represent private
// memory allocations, so drop private annotations. NVVM uses address
// space 3 for shared memory. NVVM uses the default address space to
// represent global memory.
populateGpuMemorySpaceAttributeConversions(
converter, [](gpu::AddressSpace space) -> unsigned {
switch (space) {
case gpu::AddressSpace::Global:
return static_cast<unsigned>(
NVVM::NVVMMemorySpace::kGlobalMemorySpace);
case gpu::AddressSpace::Workgroup:
return static_cast<unsigned>(
NVVM::NVVMMemorySpace::kSharedMemorySpace);
case gpu::AddressSpace::Private:
return 0;
}
llvm_unreachable("unknown address space enum value");
return 0;
});
// Lowering for MMAMatrixType.
converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
return convertMMAToLLVMType(type);


@@ -132,33 +132,21 @@ struct LowerGpuOpsToROCDLOpsPass
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
}
// Apply memory space lowering. The target uses 3 for workgroup memory and 5
// for private memory.
{
RewritePatternSet patterns(ctx);
TypeConverter typeConverter;
typeConverter.addConversion([](Type t) { return t; });
gpu::populateMemorySpaceAttributeTypeConversions(
typeConverter, [](gpu::AddressSpace space) {
switch (space) {
case gpu::AddressSpace::Global:
return 1;
case gpu::AddressSpace::Workgroup:
return 3;
case gpu::AddressSpace::Private:
return 5;
}
llvm_unreachable("unknown address space enum value");
return 0;
});
ConversionTarget target(getContext());
gpu::populateLowerMemorySpaceOpLegality(target);
gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns);
if (failed(applyFullConversion(m, target, std::move(patterns))))
return signalPassFailure();
}
LLVMTypeConverter converter(ctx, options);
populateGpuMemorySpaceAttributeConversions(
converter, [](gpu::AddressSpace space) {
switch (space) {
case gpu::AddressSpace::Global:
return 1;
case gpu::AddressSpace::Workgroup:
return 3;
case gpu::AddressSpace::Private:
return 5;
}
llvm_unreachable("unknown address space enum value");
return 0;
});
RewritePatternSet llvmPatterns(ctx);
mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);


@@ -112,8 +112,10 @@ bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
auto elementType = type.getElementType();
auto structElementType = typeConverter->convertType(elementType);
return getTypeConverter()->getPointerType(structElementType,
type.getMemorySpaceAsInt());
auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
if (failed(addressSpace))
return {};
return getTypeConverter()->getPointerType(structElementType, *addressSpace);
}
void ConvertToLLVMPattern::getMemRefDescriptorSizes(


@@ -158,6 +158,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
// Integer memory spaces map to themselves.
addTypeAttributeConversion(
[](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
}
/// Returns the MLIR context.
@@ -318,8 +322,17 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
if (!elementType)
return {};
LLVM::LLVMPointerType ptrTy =
getPointerType(elementType, type.getMemorySpaceAsInt());
FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
if (failed(addressSpace)) {
emitError(UnknownLoc::get(type.getContext()),
"conversion of memref memory space ")
<< type.getMemorySpace()
<< " to integer address space "
"failed. Consider adding memory space conversions.";
return {};
}
auto ptrTy = getPointerType(elementType, *addressSpace);
auto indexTy = getIndexType();
SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
@@ -337,7 +350,7 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
const DataLayout &layout) {
// Compute the descriptor size given that of its components indicated above.
unsigned space = type.getMemorySpaceAsInt();
unsigned space = *getMemRefAddressSpace(type);
return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
(1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
}
@@ -369,7 +382,7 @@ unsigned
LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
const DataLayout &layout) {
// Compute the descriptor size given that of its components indicated above.
unsigned space = type.getMemorySpaceAsInt();
unsigned space = *getMemRefAddressSpace(type);
return layout.getTypeSize(getIndexType()) +
llvm::divideCeil(getPointerBitwidth(space), 8);
}
@@ -381,6 +394,21 @@ Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
getUnrankedMemRefDescriptorFields());
}
FailureOr<unsigned>
LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) {
if (!type.getMemorySpace()) // Default memory space -> 0.
return 0;
std::optional<Attribute> converted =
convertTypeAttribute(type, type.getMemorySpace());
if (!converted)
return failure();
if (!(*converted)) // Conversion to default is 0.
return 0;
if (auto explicitSpace = converted->dyn_cast_or_null<IntegerAttr>())
return explicitSpace.getInt();
return failure();
}
// Check if a memref type can be converted to a bare pointer.
bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
if (type.isa<UnrankedMemRefType>())
@@ -412,7 +440,10 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
Type elementType = convertType(type.getElementType());
if (!elementType)
return {};
return getPointerType(elementType, type.getMemorySpaceAsInt());
FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
if (failed(addressSpace))
return {};
return getPointerType(elementType, *addressSpace);
}
/// Convert an n-D vector type to an LLVM vector type:


@@ -59,11 +59,12 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
MemRefType memRefType, Type elementPtrType,
LLVMTypeConverter &typeConverter) {
auto allocatedPtrTy = allocatedPtr.getType().cast<LLVM::LLVMPointerType>();
if (allocatedPtrTy.getAddressSpace() != memRefType.getMemorySpaceAsInt())
unsigned memrefAddrSpace = *typeConverter.getMemRefAddressSpace(memRefType);
if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
loc,
typeConverter.getPointerType(allocatedPtrTy.getElementType(),
memRefType.getMemorySpaceAsInt()),
memrefAddrSpace),
allocatedPtr);
if (!typeConverter.useOpaquePointers())


@@ -96,8 +96,10 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
auto allocaOp = cast<memref::AllocaOp>(op);
auto elementType =
typeConverter->convertType(allocaOp.getType().getElementType());
auto elementPtrType = getTypeConverter()->getPointerType(
elementType, allocaOp.getType().getMemorySpaceAsInt());
unsigned addrSpace =
*getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
auto elementPtrType =
getTypeConverter()->getPointerType(elementType, addrSpace);
auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
loc, elementPtrType, elementType, sizeBytes,
@@ -400,10 +402,11 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
ConversionPatternRewriter &rewriter) const override {
Type operandType = dimOp.getSource().getType();
if (operandType.isa<UnrankedMemRefType>()) {
rewriter.replaceOp(
dimOp, {extractSizeOfUnrankedMemRef(
operandType, dimOp, adaptor.getOperands(), rewriter)});
FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
operandType, dimOp, adaptor.getOperands(), rewriter);
if (failed(extractedSize))
return failure();
rewriter.replaceOp(dimOp, {*extractedSize});
return success();
}
if (operandType.isa<MemRefType>()) {
@@ -416,15 +419,23 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
}
private:
Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
FailureOr<Value>
extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = dimOp.getLoc();
auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
auto scalarMemRefType =
MemRefType::get({}, unrankedMemRefType.getElementType());
unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
FailureOr<unsigned> maybeAddressSpace =
getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
if (failed(maybeAddressSpace)) {
dimOp.emitOpError("memref memory space must be convertible to an integer "
"address space");
return failure();
}
unsigned addressSpace = *maybeAddressSpace;
// Extract pointer to the underlying ranked descriptor and bitcast it to a
// memref<element_type> descriptor pointer to minimize the number of GEP
@@ -455,8 +466,9 @@ private:
Value sizePtr = rewriter.create<LLVM::GEPOp>(
loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
idxPlusOne);
return rewriter.create<LLVM::LoadOp>(
loc, getTypeConverter()->getIndexType(), sizePtr);
return rewriter
.create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
.getResult();
}
std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
@@ -670,10 +682,14 @@ struct GlobalMemrefOpLowering
}
uint64_t alignment = global.getAlignment().value_or(0);
FailureOr<unsigned> addressSpace =
getTypeConverter()->getMemRefAddressSpace(type);
if (failed(addressSpace))
return global.emitOpError(
"memory space cannot be converted to an integer address space");
auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
global, arrayTy, global.getConstant(), linkage, global.getSymName(),
initialValue, alignment, type.getMemorySpaceAsInt());
initialValue, alignment, *addressSpace);
if (!global.isExternal() && global.isUninitialized()) {
Block *blk = new Block();
newGlobal.getInitializerRegion().push_back(blk);
@@ -701,7 +717,10 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
Operation *op) const override {
auto getGlobalOp = cast<memref::GetGlobalOp>(op);
MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
unsigned memSpace = type.getMemorySpaceAsInt();
// This is called after a type conversion, which would have failed if this
// call fails.
unsigned memSpace = *getTypeConverter()->getMemRefAddressSpace(type);
Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
Type resTy = getTypeConverter()->getPointerType(arrayTy, memSpace);
@@ -1097,8 +1116,9 @@ static void extractPointersAndOffset(Location loc,
return;
}
unsigned memorySpace =
operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
// These will all cause assert()s on unconvertible types.
unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
operandType.cast<UnrankedMemRefType>());
Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
Type llvmElementType = typeConverter.convertType(elementType);
LLVM::LLVMPointerType elementPtrType =
@@ -1298,7 +1318,8 @@ private:
// Extract address space and element type.
auto targetType =
reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
unsigned addressSpace = targetType.getMemorySpaceAsInt();
unsigned addressSpace =
*getTypeConverter()->getMemRefAddressSpace(targetType);
Type elementType = targetType.getElementType();
// Create the unranked memref descriptor that holds the ranked one. The
@@ -1564,14 +1585,14 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
// Field 1: Copy the allocated pointer, used for malloc/free.
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
unsigned sourceMemorySpace =
*getTypeConverter()->getMemRefAddressSpace(srcMemRefType);
Value bitcastPtr;
if (getTypeConverter()->useOpaquePointers())
bitcastPtr = allocatedPtr;
else
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(targetElementTy,
srcMemRefType.getMemorySpaceAsInt()),
loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),
allocatedPtr);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
@@ -1587,9 +1608,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
bitcastPtr = alignedPtr;
} else {
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(targetElementTy,
srcMemRefType.getMemorySpaceAsInt()),
loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),
alignedPtr);
}


@@ -572,16 +572,24 @@ struct NVGPUAsyncCopyLowering
Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
adaptor.getDstIndices(), rewriter);
auto i8Ty = IntegerType::get(op.getContext(), 8);
auto dstPointerType =
LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
FailureOr<unsigned> dstAddressSpace =
getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
if (failed(dstAddressSpace))
return rewriter.notifyMatchFailure(
loc, "destination memref address space not convertible to integer");
auto dstPointerType = LLVM::LLVMPointerType::get(i8Ty, *dstAddressSpace);
dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();
FailureOr<unsigned> srcAddressSpace =
getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
if (failed(srcAddressSpace))
return rewriter.notifyMatchFailure(
loc, "source memref address space not convertible to integer");
Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
adaptor.getSrcIndices(), rewriter);
auto srcPointerType =
LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
auto srcPointerType = LLVM::LLVMPointerType::get(i8Ty, *srcAddressSpace);
scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
// The intrinsic takes a global pointer, so we need an address space cast.
auto srcPointerGlobalType = LLVM::LLVMPointerType::get(


@@ -742,6 +742,8 @@ convertTransferReadToLoads(vector::TransferReadOp op,
if (failed(warpMatrixInfo))
return failure();
Attribute memorySpace =
op.getSource().getType().cast<MemRefType>().getMemorySpace();
bool isLdMatrixCompatible =
isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;


@@ -8,6 +8,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -92,21 +93,25 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
}
// Check that the last stride is unit and the address space is zero.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
LLVMTypeConverter &converter) {
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(memRefType, strides, offset);
if (failed(successStrides) || strides.back() != 1 ||
memRefType.getMemorySpaceAsInt() != 0)
FailureOr<unsigned> addressSpace =
converter.getMemRefAddressSpace(memRefType);
if (failed(successStrides) || strides.back() != 1 || failed(addressSpace) ||
*addressSpace != 0)
return failure();
return success();
}
// Add an index vector component to a base pointer.
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType memRefType, Value llvmMemref, Value base,
Value index, uint64_t vLen) {
assert(succeeded(isMemRefTypeSupported(memRefType)) &&
assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
"unsupported memref type");
auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
@@ -116,8 +121,10 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
// Casts a strided element pointer to a vector pointer. The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, MemRefType memRefType, Type vt) {
auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
Value ptr, MemRefType memRefType, Type vt,
LLVMTypeConverter &converter) {
unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}
@@ -245,7 +252,8 @@ public:
.template cast<VectorType>();
Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
*this->getTypeConverter());
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
return success();
@@ -264,7 +272,7 @@ public:
MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
assert(memRefType && "The base should be bufferized");
if (failed(isMemRefTypeSupported(memRefType)))
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return failure();
auto loc = gather->getLoc();
@@ -283,8 +291,8 @@ public:
if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
auto vType = gather.getVectorType();
// Resolve address.
Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr,
adaptor.getIndexVec(),
Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
memRefType, base, ptr, adaptor.getIndexVec(),
/*vLen=*/vType.getDimSize(0));
// Replace with the gather intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -293,11 +301,14 @@ public:
return success();
}
auto callback = [align, memRefType, base, ptr, loc, &rewriter](
Type llvm1DVectorTy, ValueRange vectorOperands) {
LLVMTypeConverter &typeConverter = *this->getTypeConverter();
auto callback = [align, memRefType, base, ptr, loc, &rewriter,
&typeConverter](Type llvm1DVectorTy,
ValueRange vectorOperands) {
// Resolve address.
Value ptrs = getIndexedPtrs(
rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0],
rewriter, loc, typeConverter, memRefType, base, ptr,
/*index=*/vectorOperands[0],
LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
// Create the gather intrinsic.
return rewriter.create<LLVM::masked_gather>(
@@ -323,7 +334,7 @@ public:
auto loc = scatter->getLoc();
MemRefType memRefType = scatter.getMemRefType();
if (failed(isMemRefTypeSupported(memRefType)))
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return failure();
// Resolve alignment.
@@ -335,9 +346,9 @@ public:
VectorType vType = scatter.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptrs =
getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr,
adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
Value ptrs = getIndexedPtrs(
rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(


@@ -13,6 +13,7 @@
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -48,7 +49,16 @@ void AMDGPUDialect::initialize() {
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
MemRefType bufferType = op.getMemref().getType().template cast<MemRefType>();
if (bufferType.getMemorySpaceAsInt() != 0)
Attribute memorySpace = bufferType.getMemorySpace();
bool isGlobal = false;
if (!memorySpace)
isGlobal = true;
else if (auto intMemorySpace = memorySpace.dyn_cast<IntegerAttr>())
isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
else if (auto gpuMemorySpace = memorySpace.dyn_cast<gpu::AddressSpaceAttr>())
isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
if (!isGlobal)
return op.emitOpError(
"Buffer ops must operate on a memref in global memory");
if (!bufferType.hasRank())


@@ -11,6 +11,8 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
LINK_LIBS PUBLIC
MLIRArithDialect
# Needed for GPU address space enum definition
MLIRGPUOps
MLIRIR
MLIRSideEffectInterfaces
)


@@ -52,7 +52,6 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/SerializeToBlob.cpp
Transforms/SerializeToCubin.cpp
Transforms/SerializeToHsaco.cpp
Transforms/LowerMemorySpaceAttributes.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU


@@ -1,179 +0,0 @@
//===- LowerMemorySpaceAttributes.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// Implementation of a pass that rewrites the IR so that uses of
/// `gpu::AddressSpaceAttr` in memref memory space annotations are replaced
/// with caller-specified numeric values.
///
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
namespace mlir {
#define GEN_PASS_DEF_GPULOWERMEMORYSPACEATTRIBUTESPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::gpu;
//===----------------------------------------------------------------------===//
// Conversion Target
//===----------------------------------------------------------------------===//
/// Returns true if the given `type` is considered as legal during memory space
/// attribute lowering.
static bool isLegalType(Type type) {
if (auto memRefType = type.dyn_cast<BaseMemRefType>()) {
return !memRefType.getMemorySpace()
.isa_and_nonnull<gpu::AddressSpaceAttr>();
}
return true;
}
/// Returns true if the given `attr` is considered legal during memory space
/// attribute lowering.
static bool isLegalAttr(Attribute attr) {
if (auto typeAttr = attr.dyn_cast<TypeAttr>())
return isLegalType(typeAttr.getValue());
return true;
}
/// Returns true if the given `op` is legal during memory space attribute
/// lowering.
static bool isLegalOp(Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
isLegalType);
}
auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
return attr.getValue();
});
return llvm::all_of(op->getOperandTypes(), isLegalType) &&
llvm::all_of(op->getResultTypes(), isLegalType) &&
llvm::all_of(attrs, isLegalAttr);
}
void gpu::populateLowerMemorySpaceOpLegality(ConversionTarget &target) {
target.markUnknownOpDynamicallyLegal(isLegalOp);
}
//===----------------------------------------------------------------------===//
// Type Converter
//===----------------------------------------------------------------------===//
IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
return IntegerAttr::get(IntegerType::get(ctx, 64), space);
}
void mlir::gpu::populateMemorySpaceAttributeTypeConversions(
TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
typeConverter.addConversion([mapping](Type type) {
return type.replace([mapping](Attribute attr) -> std::optional<Attribute> {
auto memorySpaceAttr = attr.dyn_cast_or_null<gpu::AddressSpaceAttr>();
if (!memorySpaceAttr)
return std::nullopt;
auto newValue = wrapNumericMemorySpace(
attr.getContext(), mapping(memorySpaceAttr.getValue()));
return newValue;
});
});
}
namespace {
/// Converts any op that has operands/results/attributes with numeric MemRef
/// memory spaces.
struct LowerMemRefAddressSpacePattern final : public ConversionPattern {
LowerMemRefAddressSpacePattern(MLIRContext *context, TypeConverter &converter)
: ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
SmallVector<NamedAttribute> newAttrs;
newAttrs.reserve(op->getAttrs().size());
for (auto attr : op->getAttrs()) {
if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) {
auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
} else {
newAttrs.push_back(attr);
}
}
SmallVector<Type> newResults;
(void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
newResults, newAttrs, op->getSuccessors());
for (Region &region : op->getRegions()) {
Region *newRegion = state.addRegion();
rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
(void)getTypeConverter()->convertSignatureArgs(
newRegion->getArgumentTypes(), result);
rewriter.applySignatureConversion(newRegion, result);
}
Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
return success();
}
};
} // namespace
void mlir::gpu::populateMemorySpaceLoweringPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<LowerMemRefAddressSpacePattern>(patterns.getContext(),
typeConverter);
}
namespace {
class LowerMemorySpaceAttributesPass
: public mlir::impl::GPULowerMemorySpaceAttributesPassBase<
LowerMemorySpaceAttributesPass> {
public:
using Base::Base;
void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();
ConversionTarget target(getContext());
populateLowerMemorySpaceOpLegality(target);
TypeConverter typeConverter;
typeConverter.addConversion([](Type t) { return t; });
populateMemorySpaceAttributeTypeConversions(
typeConverter, [this](AddressSpace space) -> unsigned {
switch (space) {
case AddressSpace::Global:
return globalAddrSpace;
case AddressSpace::Workgroup:
return workgroupAddrSpace;
case AddressSpace::Private:
return privateAddrSpace;
}
llvm_unreachable("unknown address space enum value");
return 0;
});
RewritePatternSet patterns(context);
populateMemorySpaceLoweringPatterns(typeConverter, patterns);
if (failed(applyFullConversion(op, target, std::move(patterns))))
return signalPassFailure();
}
};
} // namespace


@@ -3053,6 +3053,54 @@ auto TypeConverter::convertBlockSignature(Block *block)
return conversion;
}
//===----------------------------------------------------------------------===//
// Type attribute conversion
//===----------------------------------------------------------------------===//
TypeConverter::AttributeConversionResult
TypeConverter::AttributeConversionResult::result(Attribute attr) {
return AttributeConversionResult(attr, resultTag);
}
TypeConverter::AttributeConversionResult
TypeConverter::AttributeConversionResult::na() {
return AttributeConversionResult(nullptr, naTag);
}
TypeConverter::AttributeConversionResult
TypeConverter::AttributeConversionResult::abort() {
return AttributeConversionResult(nullptr, abortTag);
}
bool TypeConverter::AttributeConversionResult::hasResult() const {
return impl.getInt() == resultTag;
}
bool TypeConverter::AttributeConversionResult::isNa() const {
return impl.getInt() == naTag;
}
bool TypeConverter::AttributeConversionResult::isAbort() const {
return impl.getInt() == abortTag;
}
Attribute TypeConverter::AttributeConversionResult::getResult() const {
assert(hasResult() && "Cannot get result from N/A or abort");
return impl.getPointer();
}
std::optional<Attribute> TypeConverter::convertTypeAttribute(Type type,
Attribute attr) {
for (TypeAttributeConversionCallbackFn &fn :
llvm::reverse(typeAttributeConversions)) {
AttributeConversionResult res = fn(type, attr);
if (res.hasResult())
return res.getResult();
if (res.isAbort())
return std::nullopt;
}
return std::nullopt;
}
//===----------------------------------------------------------------------===//
// FunctionOpInterfaceSignatureConversion
//===----------------------------------------------------------------------===//


@@ -0,0 +1,48 @@
// RUN: mlir-opt %s -split-input-file -convert-gpu-to-rocdl | FileCheck %s --check-prefixes=CHECK,ROCDL
// RUN: mlir-opt %s -split-input-file -convert-gpu-to-nvvm | FileCheck %s --check-prefixes=CHECK,NVVM
gpu.module @kernel {
gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, #gpu.address_space<private>>) {
%c0 = arith.constant 0 : index
memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<private>>
gpu.return
}
}
// CHECK-LABEL: llvm.func @private
// CHECK: llvm.store
// ROCDL-SAME: : !llvm.ptr<f32, 5>
// NVVM-SAME: : !llvm.ptr<f32>
// -----
gpu.module @kernel {
gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space<workgroup>>) {
%c0 = arith.constant 0 : index
memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<workgroup>>
gpu.return
}
}
// CHECK-LABEL: llvm.func @workgroup
// CHECK: llvm.store
// CHECK-SAME: : !llvm.ptr<f32, 3>
// -----
gpu.module @kernel {
gpu.func @nested_memref(%arg0: memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>) -> f32 {
%c0 = arith.constant 0 : index
%inner = memref.load %arg0[%c0] : memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>
%value = memref.load %inner[%c0] : memref<4xf32, #gpu.address_space<global>>
gpu.return %value : f32
}
}
// CHECK-LABEL: llvm.func @nested_memref
// CHECK: llvm.load
// CHECK-SAME: : !llvm.ptr<{{.*}}, 1>
// CHECK: [[value:%.+]] = llvm.load
// CHECK-SAME: : !llvm.ptr<f32, 1>
// CHECK: llvm.return [[value]]


@@ -0,0 +1,14 @@
// RUN: mlir-opt %s -finalize-memref-to-llvm -split-input-file 2>&1 | FileCheck %s
// Since the error is at an unknown location, we use FileCheck instead of
// -verify-diagnostics here
// CHECK: conversion of memref memory space "foo" to integer address space failed. Consider adding memory space conversions.
// CHECK-LABEL: @bad_address_space
func.func @bad_address_space(%a: memref<2xindex, "foo">) {
%c0 = arith.constant 0 : index
// CHECK: memref.store
memref.store %c0, %a[%c0] : memref<2xindex, "foo">
return
}
// -----


@@ -1,55 +0,0 @@
// RUN: mlir-opt %s -split-input-file -gpu-lower-memory-space-attributes | FileCheck %s
// RUN: mlir-opt %s -split-input-file -gpu-lower-memory-space-attributes="private=0 global=0" | FileCheck %s --check-prefix=CUDA
gpu.module @kernel {
gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, #gpu.address_space<private>>) {
%c0 = arith.constant 0 : index
memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<private>>
gpu.return
}
}
// CHECK: gpu.func @private
// CHECK-SAME: private(%{{.+}}: memref<4xf32, 5>)
// CHECK: memref.store
// CHECK-SAME: : memref<4xf32, 5>
// CUDA: gpu.func @private
// CUDA-SAME: private(%{{.+}}: memref<4xf32>)
// CUDA: memref.store
// CUDA-SAME: : memref<4xf32>
// -----
gpu.module @kernel {
gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space<workgroup>>) {
%c0 = arith.constant 0 : index
memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<workgroup>>
gpu.return
}
}
// CHECK: gpu.func @workgroup
// CHECK-SAME: workgroup(%{{.+}}: memref<4xf32, 3>)
// CHECK: memref.store
// CHECK-SAME: : memref<4xf32, 3>
// -----
gpu.module @kernel {
gpu.func @nested_memref(%arg0: memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>) {
%c0 = arith.constant 0 : index
memref.load %arg0[%c0] : memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>
gpu.return
}
}
// CHECK: gpu.func @nested_memref
// CHECK-SAME: (%{{.+}}: memref<4xmemref<4xf32, 1>, 1>)
// CHECK: memref.load
// CHECK-SAME: : memref<4xmemref<4xf32, 1>, 1>
// CUDA: gpu.func @nested_memref
// CUDA-SAME: (%{{.+}}: memref<4xmemref<4xf32>>)
// CUDA: memref.load
// CUDA-SAME: : memref<4xmemref<4xf32>>