mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 10:55:58 +08:00
Add generic type attribute mapping infrastructure, use it in GpuToX
Remapping memory spaces is an operation often needed in type conversions, most often when going to LLVM or to/from SPIR-V (a future commit), and it is possible that such remappings may become more common in the future as dialects take advantage of the more generic memory space infrastructure.

Currently, memory space remappings are handled by running a special-purpose conversion pass before the main conversion that changes the address space attributes. In this commit, this approach is replaced by adding a notion of type attribute conversions to TypeConverter, which is then used to convert memory space attributes. Then, we use this infrastructure throughout the *ToLLVM conversions. This has the advantage of loosening the requirements on the inputs to those passes from "all address spaces must be integers" to "all memory spaces must be convertible to integer spaces", a looser requirement that reduces the coupling between portions of MLIR. On top of that, this change leads to the removal of most of the calls to getMemorySpaceAsInt(), bringing us closer to removing it. (A rework of the SPIR-V conversions to use this new system will be in a followup commit.)

As a note, one long-term motivation for this change is that I would eventually like to add an allocaMemorySpace key to MLIR data layouts and then call getMemRefAddressSpace(allocaMemorySpace) in the relevant *ToLLVM conversions in order to ensure all alloca()s, whether incoming or produced during the LLVM lowering, have the correct address space for a given target.

I expect that the type attribute conversion system may be useful in other contexts.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D142159
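For illustration, a minimal sketch of the new mechanism (the lambda body is a hypothetical mapping; addTypeAttributeConversion and getMemRefAddressSpace are the APIs added in this patch):

    LLVMTypeConverter converter(ctx, options);
    // Teach the converter to turn a GPU memory space attribute into the
    // numeric address space that LLVM expects.
    converter.addTypeAttributeConversion(
        [](BaseMemRefType type, gpu::AddressSpaceAttr memorySpace) {
          // e.g. map workgroup memory to address space 3, everything else to 0.
          unsigned space =
              memorySpace.getValue() == gpu::AddressSpace::Workgroup ? 3 : 0;
          return IntegerAttr::get(IntegerType::get(type.getContext(), 64),
                                  space);
        });
    // The *ToLLVM patterns then query the converted space:
    //   FailureOr<unsigned> space = converter.getMemRefAddressSpace(memrefTy);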
@@ -147,6 +147,11 @@ public:
   unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
                                            const DataLayout &layout);

+  /// Return the LLVM address space corresponding to the memory space of the
+  /// memref type `type` or failure if the memory space cannot be converted to
+  /// an integer.
+  FailureOr<unsigned> getMemRefAddressSpace(BaseMemRefType type);
+
   /// Check if a memref type can be converted to a bare pointer.
   static bool canConvertToBarePtr(BaseMemRefType type);
@@ -25,7 +25,8 @@ def AMDGPU_Dialect : Dialect {

   let dependentDialects = [
-    "arith::ArithDialect"
+    "arith::ArithDialect",
+    "gpu::GPUDialect"
   ];
   let useDefaultAttributePrinterParser = 1;
}
@@ -61,23 +61,6 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
 }

 namespace gpu {
-/// A function that maps a MemorySpace enum to a target-specific integer value.
-using MemorySpaceMapping =
-    std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
-
-/// Populates type conversion rules for lowering memory space attributes to
-/// numeric values.
-void populateMemorySpaceAttributeTypeConversions(
-    TypeConverter &typeConverter, const MemorySpaceMapping &mapping);
-
-/// Populates patterns to lower memory space attributes to numeric values.
-void populateMemorySpaceLoweringPatterns(TypeConverter &typeConverter,
-                                         RewritePatternSet &patterns);
-
-/// Populates legality rules for lowering memory space attriutes to numeric
-/// values.
-void populateLowerMemorySpaceOpLegality(ConversionTarget &target);
-
 /// Returns the default annotation name for GPU binary blobs.
 std::string getDefaultGpuBinaryAnnotation();
@@ -37,23 +37,4 @@ def GpuMapParallelLoopsPass
   let dependentDialects = ["mlir::gpu::GPUDialect"];
 }

-def GPULowerMemorySpaceAttributesPass
-    : Pass<"gpu-lower-memory-space-attributes"> {
-  let summary = "Assign numeric values to memref memory space symbolic placeholders";
-  let description = [{
-    Updates all memref types that have a memory space attribute
-    that is a `gpu::AddressSpaceAttr`. These attributes are
-    changed to `IntegerAttr`'s using a mapping that is given in the
-    options.
-  }];
-  let options = [
-    Option<"privateAddrSpace", "private", "unsigned", "5",
-           "private address space numeric value">,
-    Option<"workgroupAddrSpace", "workgroup", "unsigned", "3",
-           "workgroup address space numeric value">,
-    Option<"globalAddrSpace", "global", "unsigned", "1",
-           "global address space numeric value">
-  ];
-}
-
 #endif // MLIR_DIALECT_GPU_PASSES
@@ -21,6 +21,7 @@
 namespace mlir {

 // Forward declarations.
+class Attribute;
 class Block;
 class ConversionPatternRewriter;
 class MLIRContext;
@@ -87,6 +88,34 @@ public:
     SmallVector<Type, 4> argTypes;
   };

+  /// The general result of a type attribute conversion callback, allowing
+  /// for early termination. The default constructor creates the na case.
+  class AttributeConversionResult {
+  public:
+    constexpr AttributeConversionResult() : impl() {}
+    AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
+
+    static AttributeConversionResult result(Attribute attr);
+    static AttributeConversionResult na();
+    static AttributeConversionResult abort();
+
+    bool hasResult() const;
+    bool isNa() const;
+    bool isAbort() const;
+
+    Attribute getResult() const;
+
+  private:
+    AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
+
+    llvm::PointerIntPair<Attribute, 2> impl;
+    // Note that na is 0 so that we can use PointerIntPair's default
+    // constructor.
+    static constexpr unsigned naTag = 0;
+    static constexpr unsigned resultTag = 1;
+    static constexpr unsigned abortTag = 2;
+  };
+
   /// Register a conversion function. A conversion function must be convertible
   /// to any of the following forms(where `T` is a class derived from `Type`:
   ///   * std::optional<Type>(T)
@@ -156,6 +185,34 @@ public:
         wrapMaterialization<T>(std::forward<FnT>(callback)));
   }

+  /// Register a conversion function for attributes within types. Type
+  /// converters may call this function in order to allow hooking into the
+  /// translation of attributes that exist within types. For example, a type
+  /// converter for the `memref` type could use these conversions to convert
+  /// memory spaces or layouts in an extensible way.
+  ///
+  /// The conversion functions take a non-null Type or subclass of Type and a
+  /// non-null Attribute (or subclass of Attribute), and returns a
+  /// `AttributeConversionResult`. This result can either contain an `Attribute`,
+  /// which may be `nullptr`, representing the conversion's success,
+  /// `AttributeConversionResult::na()` (the default empty value), indicating
+  /// that the conversion function did not apply and that further conversion
+  /// functions should be checked, or `AttributeConversionResult::abort()`
+  /// indicating that the conversion process should be aborted.
+  ///
+  /// Registered conversion functions are called in the reverse of the order in
+  /// which they were registered.
+  template <
+      typename FnT,
+      typename T =
+          typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
+      typename A =
+          typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
+  void addTypeAttributeConversion(FnT &&callback) {
+    registerTypeAttributeConversion(
+        wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
+  }
+
   /// Convert the given type. This function should return failure if no valid
   /// conversion exists, success otherwise. If the new set of types is empty,
   /// the type is removed and any usages of the existing value are expected to
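As a hedged sketch of the contract described above (the "bank0"/"reserved" string memory spaces are hypothetical), one callback can produce all three outcomes:

    converter.addTypeAttributeConversion(
        [](BaseMemRefType type, StringAttr memorySpace)
            -> TypeConverter::AttributeConversionResult {
          MLIRContext *ctx = memorySpace.getContext();
          if (memorySpace.getValue() == "bank0")  // success: convert to 0
            return IntegerAttr::get(IntegerType::get(ctx, 64), 0);
          if (memorySpace.getValue() == "reserved")  // hard failure
            return TypeConverter::AttributeConversionResult::abort();
          // Not our attribute; let earlier-registered conversions try.
          return TypeConverter::AttributeConversionResult::na();
        });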
@@ -226,6 +283,12 @@ public:
                                  resultType, inputs);
   }

+  /// Convert an attribute present `attr` from within the type `type` using
+  /// the registered conversion functions. If no applicable conversion has been
+  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
+  /// is a valid return value for this function.
+  std::optional<Attribute> convertTypeAttribute(Type type, Attribute attr);
+
 private:
   /// The signature of the callback used to convert a type. If the new set of
   /// types is empty, the type is removed and any usages of the existing value
@@ -237,6 +300,10 @@ private:
   using MaterializationCallbackFn = std::function<std::optional<Value>(
       OpBuilder &, Type, ValueRange, Location)>;

+  /// The signature of the callback used to convert a type attribute.
+  using TypeAttributeConversionCallbackFn =
+      std::function<AttributeConversionResult(Type, Attribute)>;
+
   /// Attempt to materialize a conversion using one of the provided
   /// materialization functions.
   Value materializeConversion(
@@ -311,6 +378,32 @@ private:
    };
  }

+  /// Generate a wrapper for the given memory space conversion callback. The
+  /// callback may take any subclass of `Attribute` and the wrapper will check
+  /// for the target attribute to be of the expected class before calling the
+  /// callback.
+  template <typename T, typename A, typename FnT>
+  TypeAttributeConversionCallbackFn
+  wrapTypeAttributeConversion(FnT &&callback) {
+    return [callback = std::forward<FnT>(callback)](
+               Type type, Attribute attr) -> AttributeConversionResult {
+      if (T derivedType = type.dyn_cast<T>()) {
+        if (A derivedAttr = attr.dyn_cast_or_null<A>())
+          return callback(derivedType, derivedAttr);
+      }
+      return AttributeConversionResult::na();
+    };
+  }
+
+  /// Register a memory space conversion, clearing caches.
+  void
+  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
+    typeAttributeConversions.emplace_back(std::move(callback));
+    // Clear type conversions in case a memory space is lingering inside.
+    cachedDirectConversions.clear();
+    cachedMultiConversions.clear();
+  }
+
   /// The set of registered conversion functions.
   SmallVector<ConversionCallbackFn, 4> conversions;
@@ -319,6 +412,9 @@ private:
   SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
   SmallVector<MaterializationCallbackFn, 2> targetMaterializations;

+  /// The list of registered type attribute conversion functions.
+  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
+
   /// A set of cached conversions to avoid recomputing in the common case.
   /// Direct 1-1 conversions are the most common, so this cache stores the
   /// successful 1-1 conversions as well as all failed conversions.
@@ -8,6 +8,7 @@

 #include "GPUOpsLowering.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/STLExtras.h"
@@ -474,3 +475,18 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
   rewriter.replaceOp(op, result);
   return success();
 }
+
+static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
+  return IntegerAttr::get(IntegerType::get(ctx, 64), space);
+}
+
+void mlir::populateGpuMemorySpaceAttributeConversions(
+    TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
+  typeConverter.addTypeAttributeConversion(
+      [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
+        gpu::AddressSpace memorySpace = memorySpaceAttr.getValue();
+        unsigned addressSpace = mapping(memorySpace);
+        return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
+                                      addressSpace);
+      });
+}
@@ -112,6 +112,14 @@ public:
   }
 };

+/// A function that maps a MemorySpace enum to a target-specific integer value.
+using MemorySpaceMapping =
+    std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
+
+/// Populates memory space attribute conversion rules for lowering
+/// gpu.address_space to integer values.
+void populateGpuMemorySpaceAttributeConversions(
+    TypeConverter &typeConverter, const MemorySpaceMapping &mapping);
 } // namespace mlir

 #endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
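A usage sketch (condensed from the NVVM and ROCDL pass changes later in this diff):

    LLVMTypeConverter converter(ctx, options);
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return 1;  // ROCDL numbering; NVVM uses its own enum values
          case gpu::AddressSpace::Workgroup:
            return 3;
          case gpu::AddressSpace::Private:
            return 5;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });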
@@ -241,38 +241,26 @@ struct LowerGpuOpsToNVVMOpsPass
       return signalPassFailure();
   }

-  // MemRef conversion for GPU to NVVM lowering.
-  {
-    RewritePatternSet patterns(m.getContext());
-    TypeConverter typeConverter;
-    typeConverter.addConversion([](Type t) { return t; });
-    // NVVM uses alloca in the default address space to represent private
-    // memory allocations, so drop private annotations. NVVM uses address
-    // space 3 for shared memory. NVVM uses the default address space to
-    // represent global memory.
-    gpu::populateMemorySpaceAttributeTypeConversions(
-        typeConverter, [](gpu::AddressSpace space) -> unsigned {
-          switch (space) {
-          case gpu::AddressSpace::Global:
-            return static_cast<unsigned>(
-                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
-          case gpu::AddressSpace::Workgroup:
-            return static_cast<unsigned>(
-                NVVM::NVVMMemorySpace::kSharedMemorySpace);
-          case gpu::AddressSpace::Private:
-            return 0;
-          }
-          llvm_unreachable("unknown address space enum value");
-          return 0;
-        });
-    gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns);
-    ConversionTarget target(getContext());
-    gpu::populateLowerMemorySpaceOpLegality(target);
-    if (failed(applyFullConversion(m, target, std::move(patterns))))
-      return signalPassFailure();
-  }
-
   LLVMTypeConverter converter(m.getContext(), options);
+  // NVVM uses alloca in the default address space to represent private
+  // memory allocations, so drop private annotations. NVVM uses address
+  // space 3 for shared memory. NVVM uses the default address space to
+  // represent global memory.
+  populateGpuMemorySpaceAttributeConversions(
+      converter, [](gpu::AddressSpace space) -> unsigned {
+        switch (space) {
+        case gpu::AddressSpace::Global:
+          return static_cast<unsigned>(
+              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+        case gpu::AddressSpace::Workgroup:
+          return static_cast<unsigned>(
+              NVVM::NVVMMemorySpace::kSharedMemorySpace);
+        case gpu::AddressSpace::Private:
+          return 0;
+        }
+        llvm_unreachable("unknown address space enum value");
+        return 0;
+      });
   // Lowering for MMAMatrixType.
   converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
     return convertMMAToLLVMType(type);
@@ -132,33 +132,21 @@ struct LowerGpuOpsToROCDLOpsPass
     (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
   }

-  // Apply memory space lowering. The target uses 3 for workgroup memory and 5
-  // for private memory.
-  {
-    RewritePatternSet patterns(ctx);
-    TypeConverter typeConverter;
-    typeConverter.addConversion([](Type t) { return t; });
-    gpu::populateMemorySpaceAttributeTypeConversions(
-        typeConverter, [](gpu::AddressSpace space) {
-          switch (space) {
-          case gpu::AddressSpace::Global:
-            return 1;
-          case gpu::AddressSpace::Workgroup:
-            return 3;
-          case gpu::AddressSpace::Private:
-            return 5;
-          }
-          llvm_unreachable("unknown address space enum value");
-          return 0;
-        });
-    ConversionTarget target(getContext());
-    gpu::populateLowerMemorySpaceOpLegality(target);
-    gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns);
-    if (failed(applyFullConversion(m, target, std::move(patterns))))
-      return signalPassFailure();
-  }
-
   LLVMTypeConverter converter(ctx, options);
+  populateGpuMemorySpaceAttributeConversions(
+      converter, [](gpu::AddressSpace space) {
+        switch (space) {
+        case gpu::AddressSpace::Global:
+          return 1;
+        case gpu::AddressSpace::Workgroup:
+          return 3;
+        case gpu::AddressSpace::Private:
+          return 5;
+        }
+        llvm_unreachable("unknown address space enum value");
+        return 0;
+      });

   RewritePatternSet llvmPatterns(ctx);

   mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
@@ -112,8 +112,10 @@ bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
   auto elementType = type.getElementType();
   auto structElementType = typeConverter->convertType(elementType);
-  return getTypeConverter()->getPointerType(structElementType,
-                                            type.getMemorySpaceAsInt());
+  auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
+  if (failed(addressSpace))
+    return {};
+  return getTypeConverter()->getPointerType(structElementType, *addressSpace);
 }

 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
@@ -158,6 +158,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
+
+  // Integer memory spaces map to themselves.
+  addTypeAttributeConversion(
+      [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
 }

 /// Returns the MLIR context.
@@ -318,8 +322,17 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
   if (!elementType)
     return {};

-  LLVM::LLVMPointerType ptrTy =
-      getPointerType(elementType, type.getMemorySpaceAsInt());
+  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
+  if (failed(addressSpace)) {
+    emitError(UnknownLoc::get(type.getContext()),
+              "conversion of memref memory space ")
+        << type.getMemorySpace()
+        << " to integer address space "
+           "failed. Consider adding memory space conversions.";
+    return {};
+  }
+  auto ptrTy = getPointerType(elementType, *addressSpace);
+
   auto indexTy = getIndexType();

   SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
@@ -337,7 +350,7 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
 unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
                                                     const DataLayout &layout) {
   // Compute the descriptor size given that of its components indicated above.
-  unsigned space = type.getMemorySpaceAsInt();
+  unsigned space = *getMemRefAddressSpace(type);
   return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
          (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
 }
@@ -369,7 +382,7 @@ unsigned
 LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
                                                    const DataLayout &layout) {
   // Compute the descriptor size given that of its components indicated above.
-  unsigned space = type.getMemorySpaceAsInt();
+  unsigned space = *getMemRefAddressSpace(type);
   return layout.getTypeSize(getIndexType()) +
          llvm::divideCeil(getPointerBitwidth(space), 8);
 }
@@ -381,6 +394,21 @@ Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
                                getUnrankedMemRefDescriptorFields());
 }

+FailureOr<unsigned>
+LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) {
+  if (!type.getMemorySpace()) // Default memory space -> 0.
+    return 0;
+  Optional<Attribute> converted =
+      convertTypeAttribute(type, type.getMemorySpace());
+  if (!converted)
+    return failure();
+  if (!(*converted)) // Conversion to default is 0.
+    return 0;
+  if (auto explicitSpace = converted->dyn_cast_or_null<IntegerAttr>())
+    return explicitSpace.getInt();
+  return failure();
+}
+
 // Check if a memref type can be converted to a bare pointer.
 bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
   if (type.isa<UnrankedMemRefType>())
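To make the fallback chain concrete, these are the expected outcomes for a converter with only the built-in IntegerAttr identity rule registered (illustrative examples, not part of the patch):

    // memref<4xf32>          -> 0        (no memory space: default)
    // memref<4xf32, 3>       -> 3        (IntegerAttr maps to itself)
    // memref<4xf32, "foo">   -> failure  (no conversion rule applies)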
@@ -412,7 +440,10 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
   Type elementType = convertType(type.getElementType());
   if (!elementType)
     return {};
-  return getPointerType(elementType, type.getMemorySpaceAsInt());
+  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
+  if (failed(addressSpace))
+    return {};
+  return getPointerType(elementType, *addressSpace);
 }

 /// Convert an n-D vector type to an LLVM vector type:
@@ -59,11 +59,12 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                  MemRefType memRefType, Type elementPtrType,
                                  LLVMTypeConverter &typeConverter) {
   auto allocatedPtrTy = allocatedPtr.getType().cast<LLVM::LLVMPointerType>();
-  if (allocatedPtrTy.getAddressSpace() != memRefType.getMemorySpaceAsInt())
+  unsigned memrefAddrSpace = *typeConverter.getMemRefAddressSpace(memRefType);
+  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
     allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
         loc,
         typeConverter.getPointerType(allocatedPtrTy.getElementType(),
-                                     memRefType.getMemorySpaceAsInt()),
+                                     memrefAddrSpace),
         allocatedPtr);

   if (!typeConverter.useOpaquePointers())
@@ -96,8 +96,10 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
     auto allocaOp = cast<memref::AllocaOp>(op);
     auto elementType =
         typeConverter->convertType(allocaOp.getType().getElementType());
-    auto elementPtrType = getTypeConverter()->getPointerType(
-        elementType, allocaOp.getType().getMemorySpaceAsInt());
+    unsigned addrSpace =
+        *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
+    auto elementPtrType =
+        getTypeConverter()->getPointerType(elementType, addrSpace);

     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
         loc, elementPtrType, elementType, sizeBytes,
@@ -400,10 +402,11 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Type operandType = dimOp.getSource().getType();
     if (operandType.isa<UnrankedMemRefType>()) {
-      rewriter.replaceOp(
-          dimOp, {extractSizeOfUnrankedMemRef(
-                     operandType, dimOp, adaptor.getOperands(), rewriter)});
-
+      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
+          operandType, dimOp, adaptor.getOperands(), rewriter);
+      if (failed(extractedSize))
+        return failure();
+      rewriter.replaceOp(dimOp, {*extractedSize});
       return success();
     }
     if (operandType.isa<MemRefType>()) {
@@ -416,15 +419,23 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
   }

 private:
-  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
-                                    OpAdaptor adaptor,
-                                    ConversionPatternRewriter &rewriter) const {
+  FailureOr<Value>
+  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
+                              OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter) const {
     Location loc = dimOp.getLoc();

     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
     auto scalarMemRefType =
         MemRefType::get({}, unrankedMemRefType.getElementType());
-    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
+    FailureOr<unsigned> maybeAddressSpace =
+        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
+    if (failed(maybeAddressSpace)) {
+      dimOp.emitOpError("memref memory space must be convertible to an integer "
+                        "address space");
+      return failure();
+    }
+    unsigned addressSpace = *maybeAddressSpace;

     // Extract pointer to the underlying ranked descriptor and bitcast it to a
     // memref<element_type> descriptor pointer to minimize the number of GEP
@@ -455,8 +466,9 @@ private:
     Value sizePtr = rewriter.create<LLVM::GEPOp>(
         loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
         idxPlusOne);
-    return rewriter.create<LLVM::LoadOp>(
-        loc, getTypeConverter()->getIndexType(), sizePtr);
+    return rewriter
+        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
+        .getResult();
   }

   std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
@@ -670,10 +682,14 @@ struct GlobalMemrefOpLowering
     }

     uint64_t alignment = global.getAlignment().value_or(0);
+    FailureOr<unsigned> addressSpace =
+        getTypeConverter()->getMemRefAddressSpace(type);
+    if (failed(addressSpace))
+      return global.emitOpError(
+          "memory space cannot be converted to an integer address space");

     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
         global, arrayTy, global.getConstant(), linkage, global.getSymName(),
-        initialValue, alignment, type.getMemorySpaceAsInt());
+        initialValue, alignment, *addressSpace);
     if (!global.isExternal() && global.isUninitialized()) {
       Block *blk = new Block();
       newGlobal.getInitializerRegion().push_back(blk);
@@ -701,7 +717,10 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
                        Operation *op) const override {
     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
     MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
-    unsigned memSpace = type.getMemorySpaceAsInt();
+
+    // This is called after a type conversion, which would have failed if this
+    // call fails.
+    unsigned memSpace = *getTypeConverter()->getMemRefAddressSpace(type);

     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
     Type resTy = getTypeConverter()->getPointerType(arrayTy, memSpace);
@@ -1097,8 +1116,9 @@ static void extractPointersAndOffset(Location loc,
     return;
   }

-  unsigned memorySpace =
-      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
+  // These will all cause assert()s on unconvertible types.
+  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
+      operandType.cast<UnrankedMemRefType>());
   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
   Type llvmElementType = typeConverter.convertType(elementType);
   LLVM::LLVMPointerType elementPtrType =
@@ -1298,7 +1318,8 @@ private:
     // Extract address space and element type.
     auto targetType =
         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
-    unsigned addressSpace = targetType.getMemorySpaceAsInt();
+    unsigned addressSpace =
+        *getTypeConverter()->getMemRefAddressSpace(targetType);
     Type elementType = targetType.getElementType();

     // Create the unranked memref descriptor that holds the ranked one. The
@@ -1564,14 +1585,14 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
     // Field 1: Copy the allocated pointer, used for malloc/free.
     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
     auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
+    unsigned sourceMemorySpace =
+        *getTypeConverter()->getMemRefAddressSpace(srcMemRefType);
     Value bitcastPtr;
     if (getTypeConverter()->useOpaquePointers())
       bitcastPtr = allocatedPtr;
     else
       bitcastPtr = rewriter.create<LLVM::BitcastOp>(
-          loc,
-          LLVM::LLVMPointerType::get(targetElementTy,
-                                     srcMemRefType.getMemorySpaceAsInt()),
+          loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),
           allocatedPtr);

     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
@@ -1587,9 +1608,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
       bitcastPtr = alignedPtr;
     } else {
       bitcastPtr = rewriter.create<LLVM::BitcastOp>(
-          loc,
-          LLVM::LLVMPointerType::get(targetElementTy,
-                                     srcMemRefType.getMemorySpaceAsInt()),
+          loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),
           alignedPtr);
     }
@@ -572,16 +572,24 @@ struct NVGPUAsyncCopyLowering
     Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
                                         adaptor.getDstIndices(), rewriter);
     auto i8Ty = IntegerType::get(op.getContext(), 8);
-    auto dstPointerType =
-        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
+    FailureOr<unsigned> dstAddressSpace =
+        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
+    if (failed(dstAddressSpace))
+      return rewriter.notifyMatchFailure(
+          loc, "destination memref address space not convertible to integer");
+    auto dstPointerType = LLVM::LLVMPointerType::get(i8Ty, *dstAddressSpace);
     dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);

     auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();
+    FailureOr<unsigned> srcAddressSpace =
+        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
+    if (failed(srcAddressSpace))
+      return rewriter.notifyMatchFailure(
+          loc, "source memref address space not convertible to integer");

     Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                         adaptor.getSrcIndices(), rewriter);
-    auto srcPointerType =
-        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
+    auto srcPointerType = LLVM::LLVMPointerType::get(i8Ty, *srcAddressSpace);
     scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
     // Intrinsics takes a global pointer so we need an address space cast.
     auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
@@ -742,6 +742,8 @@ convertTransferReadToLoads(vector::TransferReadOp op,
   if (failed(warpMatrixInfo))
     return failure();

+  Attribute memorySpace =
+      op.getSource().getType().cast<MemRefType>().getMemorySpace();
   bool isLdMatrixCompatible =
       isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
@@ -8,6 +8,7 @@

 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -92,21 +93,25 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
 }

 // Check if the last stride is non-unit or the memory space is not zero.
-static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
+static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
+                                           LLVMTypeConverter &converter) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
-  if (failed(successStrides) || strides.back() != 1 ||
-      memRefType.getMemorySpaceAsInt() != 0)
+  FailureOr<unsigned> addressSpace =
+      converter.getMemRefAddressSpace(memRefType);
+  if (failed(successStrides) || strides.back() != 1 || failed(addressSpace) ||
+      *addressSpace != 0)
     return failure();
   return success();
 }

 // Add an index vector component to a base pointer.
 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+                            LLVMTypeConverter &typeConverter,
                             MemRefType memRefType, Value llvmMemref, Value base,
                             Value index, uint64_t vLen) {
-  assert(succeeded(isMemRefTypeSupported(memRefType)) &&
+  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
          "unsupported memref type");
   auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
   auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
@@ -116,8 +121,10 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
 // Casts a strided element pointer to a vector pointer. The vector pointer
 // will be in the same address space as the incoming memref type.
 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
-                         Value ptr, MemRefType memRefType, Type vt) {
-  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
+                         Value ptr, MemRefType memRefType, Type vt,
+                         LLVMTypeConverter &converter) {
+  unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
+  auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
 }
@@ -245,7 +252,8 @@ public:
                      .template cast<VectorType>();
     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                                adaptor.getIndices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
+    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
+                            *this->getTypeConverter());

     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
     return success();
@@ -264,7 +272,7 @@ public:
     MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
     assert(memRefType && "The base should be bufferized");

-    if (failed(isMemRefTypeSupported(memRefType)))
+    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();

     auto loc = gather->getLoc();
@@ -283,8 +291,8 @@ public:
     if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
       auto vType = gather.getVectorType();
       // Resolve address.
-      Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr,
-                                  adaptor.getIndexVec(),
+      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
+                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                   /*vLen=*/vType.getDimSize(0));
       // Replace with the gather intrinsic.
       rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -293,11 +301,14 @@ public:
       return success();
     }

-    auto callback = [align, memRefType, base, ptr, loc, &rewriter](
-                        Type llvm1DVectorTy, ValueRange vectorOperands) {
+    LLVMTypeConverter &typeConverter = *this->getTypeConverter();
+    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
+                     &typeConverter](Type llvm1DVectorTy,
+                                     ValueRange vectorOperands) {
       // Resolve address.
       Value ptrs = getIndexedPtrs(
-          rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0],
+          rewriter, loc, typeConverter, memRefType, base, ptr,
+          /*index=*/vectorOperands[0],
           LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
       // Create the gather intrinsic.
       return rewriter.create<LLVM::masked_gather>(
@@ -323,7 +334,7 @@ public:
     auto loc = scatter->getLoc();
     MemRefType memRefType = scatter.getMemRefType();

-    if (failed(isMemRefTypeSupported(memRefType)))
+    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();

     // Resolve alignment.
@@ -335,9 +346,9 @@ public:
     VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    Value ptrs =
-        getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr,
-                       adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
+    Value ptrs = getIndexedPtrs(
+        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
+        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));

     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"

 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -48,7 +49,16 @@ void AMDGPUDialect::initialize() {
 template <typename T>
 static LogicalResult verifyRawBufferOp(T &op) {
   MemRefType bufferType = op.getMemref().getType().template cast<MemRefType>();
-  if (bufferType.getMemorySpaceAsInt() != 0)
+  Attribute memorySpace = bufferType.getMemorySpace();
+  bool isGlobal = false;
+  if (!memorySpace)
+    isGlobal = true;
+  else if (auto intMemorySpace = memorySpace.dyn_cast<IntegerAttr>())
+    isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
+  else if (auto gpuMemorySpace = memorySpace.dyn_cast<gpu::AddressSpaceAttr>())
+    isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+
+  if (!isGlobal)
     return op.emitOpError(
         "Buffer ops must operate on a memref in global memory");
   if (!bufferType.hasRank())
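Concretely (illustrative types, not from the patch), the verifier now accepts

    memref<8xf32>                                  // no memory space
    memref<8xf32, 1>                               // integer space 0 or 1
    memref<8xf32, #gpu.address_space<global>>      // GPU global space

and still rejects e.g. memref<8xf32, 3> or #gpu.address_space<workgroup>.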
@@ -11,6 +11,8 @@ add_mlir_dialect_library(MLIRAMDGPUDialect

   LINK_LIBS PUBLIC
   MLIRArithDialect
+  # Needed for GPU address space enum definition
+  MLIRGPUOps
   MLIRIR
   MLIRSideEffectInterfaces
   )
@@ -52,7 +52,6 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/SerializeToBlob.cpp
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp
-  Transforms/LowerMemorySpaceAttributes.cpp

   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
@@ -1,179 +0,0 @@
-//===- LowerMemorySpaceAttributes.cpp ------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-///
-/// Implementation of a pass that rewrites the IR so that uses of
-/// `gpu::AddressSpaceAttr` in memref memory space annotations are replaced
-/// with caller-specified numeric values.
-///
-//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/Transforms/Passes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/Debug.h"
-
-namespace mlir {
-#define GEN_PASS_DEF_GPULOWERMEMORYSPACEATTRIBUTESPASS
-#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
-} // namespace mlir
-
-using namespace mlir;
-using namespace mlir::gpu;
-
-//===----------------------------------------------------------------------===//
-// Conversion Target
-//===----------------------------------------------------------------------===//
-
-/// Returns true if the given `type` is considered as legal during memory space
-/// attribute lowering.
-static bool isLegalType(Type type) {
-  if (auto memRefType = type.dyn_cast<BaseMemRefType>()) {
-    return !memRefType.getMemorySpace()
-                .isa_and_nonnull<gpu::AddressSpaceAttr>();
-  }
-  return true;
-}
-
-/// Returns true if the given `attr` is considered legal during memory space
-/// attribute lowering.
-static bool isLegalAttr(Attribute attr) {
-  if (auto typeAttr = attr.dyn_cast<TypeAttr>())
-    return isLegalType(typeAttr.getValue());
-  return true;
-}
-
-/// Returns true if the given `op` is legal during memory space attribute
-/// lowering.
-static bool isLegalOp(Operation *op) {
-  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
-    return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
-           llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
-           llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
-                        isLegalType);
-  }
-
-  auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
-    return attr.getValue();
-  });
-
-  return llvm::all_of(op->getOperandTypes(), isLegalType) &&
-         llvm::all_of(op->getResultTypes(), isLegalType) &&
-         llvm::all_of(attrs, isLegalAttr);
-}
-
-void gpu::populateLowerMemorySpaceOpLegality(ConversionTarget &target) {
-  target.markUnknownOpDynamicallyLegal(isLegalOp);
-}
-
-//===----------------------------------------------------------------------===//
-// Type Converter
-//===----------------------------------------------------------------------===//
-
-IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
-  return IntegerAttr::get(IntegerType::get(ctx, 64), space);
-}
-
-void mlir::gpu::populateMemorySpaceAttributeTypeConversions(
-    TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
-  typeConverter.addConversion([mapping](Type type) {
-    return type.replace([mapping](Attribute attr) -> std::optional<Attribute> {
-      auto memorySpaceAttr = attr.dyn_cast_or_null<gpu::AddressSpaceAttr>();
-      if (!memorySpaceAttr)
-        return std::nullopt;
-      auto newValue = wrapNumericMemorySpace(
-          attr.getContext(), mapping(memorySpaceAttr.getValue()));
-      return newValue;
-    });
-  });
-}
-
-namespace {
-
-/// Converts any op that has operands/results/attributes with numeric MemRef
-/// memory spaces.
-struct LowerMemRefAddressSpacePattern final : public ConversionPattern {
-  LowerMemRefAddressSpacePattern(MLIRContext *context, TypeConverter &converter)
-      : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    SmallVector<NamedAttribute> newAttrs;
-    newAttrs.reserve(op->getAttrs().size());
-    for (auto attr : op->getAttrs()) {
-      if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) {
-        auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
-        newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
-      } else {
-        newAttrs.push_back(attr);
-      }
-    }
-
-    SmallVector<Type> newResults;
-    (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
-
-    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
-                         newResults, newAttrs, op->getSuccessors());
-
-    for (Region &region : op->getRegions()) {
-      Region *newRegion = state.addRegion();
-      rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
-      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
-      (void)getTypeConverter()->convertSignatureArgs(
-          newRegion->getArgumentTypes(), result);
-      rewriter.applySignatureConversion(newRegion, result);
-    }
-
-    Operation *newOp = rewriter.create(state);
-    rewriter.replaceOp(op, newOp->getResults());
-    return success();
-  }
-};
-} // namespace
-
-void mlir::gpu::populateMemorySpaceLoweringPatterns(
-    TypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<LowerMemRefAddressSpacePattern>(patterns.getContext(),
-                                               typeConverter);
-}
-
-namespace {
-class LowerMemorySpaceAttributesPass
-    : public mlir::impl::GPULowerMemorySpaceAttributesPassBase<
-          LowerMemorySpaceAttributesPass> {
-public:
-  using Base::Base;
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-    Operation *op = getOperation();
-
-    ConversionTarget target(getContext());
-    populateLowerMemorySpaceOpLegality(target);
-
-    TypeConverter typeConverter;
-    typeConverter.addConversion([](Type t) { return t; });
-    populateMemorySpaceAttributeTypeConversions(
-        typeConverter, [this](AddressSpace space) -> unsigned {
-          switch (space) {
-          case AddressSpace::Global:
-            return globalAddrSpace;
-          case AddressSpace::Workgroup:
-            return workgroupAddrSpace;
-          case AddressSpace::Private:
-            return privateAddrSpace;
-          }
-          llvm_unreachable("unknown address space enum value");
-          return 0;
-        });
-    RewritePatternSet patterns(context);
-    populateMemorySpaceLoweringPatterns(typeConverter, patterns);
-    if (failed(applyFullConversion(op, target, std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-} // namespace
@@ -3053,6 +3053,54 @@ auto TypeConverter::convertBlockSignature(Block *block)
   return conversion;
 }

+//===----------------------------------------------------------------------===//
+// Type attribute conversion
+//===----------------------------------------------------------------------===//
+TypeConverter::AttributeConversionResult
+TypeConverter::AttributeConversionResult::result(Attribute attr) {
+  return AttributeConversionResult(attr, resultTag);
+}
+
+TypeConverter::AttributeConversionResult
+TypeConverter::AttributeConversionResult::na() {
+  return AttributeConversionResult(nullptr, naTag);
+}
+
+TypeConverter::AttributeConversionResult
+TypeConverter::AttributeConversionResult::abort() {
+  return AttributeConversionResult(nullptr, abortTag);
+}
+
+bool TypeConverter::AttributeConversionResult::hasResult() const {
+  return impl.getInt() == resultTag;
+}
+
+bool TypeConverter::AttributeConversionResult::isNa() const {
+  return impl.getInt() == naTag;
+}
+
+bool TypeConverter::AttributeConversionResult::isAbort() const {
+  return impl.getInt() == abortTag;
+}
+
+Attribute TypeConverter::AttributeConversionResult::getResult() const {
+  assert(hasResult() && "Cannot get result from N/A or abort");
+  return impl.getPointer();
+}
+
+Optional<Attribute> TypeConverter::convertTypeAttribute(Type type,
+                                                        Attribute attr) {
+  for (TypeAttributeConversionCallbackFn &fn :
+       llvm::reverse(typeAttributeConversions)) {
+    AttributeConversionResult res = fn(type, attr);
+    if (res.hasResult())
+      return res.getResult();
+    if (res.isAbort())
+      return std::nullopt;
+  }
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // FunctionOpInterfaceSignatureConversion
 //===----------------------------------------------------------------------===//
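Since the lookup walks the list in reverse, the most recently registered rule wins; a hypothetical sketch:

    converter.addTypeAttributeConversion(         // registered first
        [](BaseMemRefType type, IntegerAttr space) { return space; });
    converter.addTypeAttributeConversion(         // registered second, tried first
        [](BaseMemRefType type, IntegerAttr space) {
          return IntegerAttr::get(space.getType(), space.getInt() + 1);
        });
    // convertTypeAttribute(memrefTy, IntegerAttr 3) now yields IntegerAttr 4.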
48  mlir/test/Conversion/GPUCommon/lower-memory-space-attrs.mlir  Normal file
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -split-input-file -convert-gpu-to-rocdl | FileCheck %s --check-prefixes=CHECK,ROCDL
+// RUN: mlir-opt %s -split-input-file -convert-gpu-to-nvvm | FileCheck %s --check-prefixes=CHECK,NVVM
+
+gpu.module @kernel {
+  gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, #gpu.address_space<private>>) {
+    %c0 = arith.constant 0 : index
+    memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<private>>
+    gpu.return
+  }
+}
+
+// CHECK-LABEL: llvm.func @private
+// CHECK: llvm.store
+// ROCDL-SAME: : !llvm.ptr<f32, 5>
+// NVVM-SAME: : !llvm.ptr<f32>
+
+
+// -----
+
+gpu.module @kernel {
+  gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space<workgroup>>) {
+    %c0 = arith.constant 0 : index
+    memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<workgroup>>
+    gpu.return
+  }
+}
+
+// CHECK-LABEL: llvm.func @workgroup
+// CHECK: llvm.store
+// CHECK-SAME: : !llvm.ptr<f32, 3>
+
+// -----
+
+gpu.module @kernel {
+  gpu.func @nested_memref(%arg0: memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>) -> f32 {
+    %c0 = arith.constant 0 : index
+    %inner = memref.load %arg0[%c0] : memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>
+    %value = memref.load %inner[%c0] : memref<4xf32, #gpu.address_space<global>>
+    gpu.return %value : f32
+  }
+}
+
+// CHECK-LABEL: llvm.func @nested_memref
+// CHECK: llvm.load
+// CHECK-SAME: : !llvm.ptr<{{.*}}, 1>
+// CHECK: [[value:%.+]] = llvm.load
+// CHECK-SAME: : !llvm.ptr<f32, 1>
+// CHECK: llvm.return [[value]]
14  mlir/test/Conversion/MemRefToLLVM/invalid.mlir  Normal file
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -finalize-memref-to-llvm -split-input-file 2>&1 | FileCheck %s
+// Since the error is at an unknown location, we use FileCheck instead of
+// -verify-diagnostics here
+
+// CHECK: conversion of memref memory space "foo" to integer address space failed. Consider adding memory space conversions.
+// CHECK-LABEL: @bad_address_space
+func.func @bad_address_space(%a: memref<2xindex, "foo">) {
+  %c0 = arith.constant 0 : index
+  // CHECK: memref.store
+  memref.store %c0, %a[%c0] : memref<2xindex, "foo">
+  return
+}
+
+// -----
@@ -1,55 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -gpu-lower-memory-space-attributes | FileCheck %s
-// RUN: mlir-opt %s -split-input-file -gpu-lower-memory-space-attributes="private=0 global=0" | FileCheck %s --check-prefix=CUDA
-
-gpu.module @kernel {
-  gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, #gpu.address_space<private>>) {
-    %c0 = arith.constant 0 : index
-    memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<private>>
-    gpu.return
-  }
-}
-
-// CHECK: gpu.func @private
-// CHECK-SAME: private(%{{.+}}: memref<4xf32, 5>)
-// CHECK: memref.store
-// CHECK-SAME: : memref<4xf32, 5>
-
-// CUDA: gpu.func @private
-// CUDA-SAME: private(%{{.+}}: memref<4xf32>)
-// CUDA: memref.store
-// CUDA-SAME: : memref<4xf32>
-
-// -----
-
-gpu.module @kernel {
-  gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space<workgroup>>) {
-    %c0 = arith.constant 0 : index
-    memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<workgroup>>
-    gpu.return
-  }
-}
-
-// CHECK: gpu.func @workgroup
-// CHECK-SAME: workgroup(%{{.+}}: memref<4xf32, 3>)
-// CHECK: memref.store
-// CHECK-SAME: : memref<4xf32, 3>
-
-// -----
-
-gpu.module @kernel {
-  gpu.func @nested_memref(%arg0: memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>) {
-    %c0 = arith.constant 0 : index
-    memref.load %arg0[%c0] : memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>
-    gpu.return
-  }
-}
-
-// CHECK: gpu.func @nested_memref
-// CHECK-SAME: (%{{.+}}: memref<4xmemref<4xf32, 1>, 1>)
-// CHECK: memref.load
-// CHECK-SAME: : memref<4xmemref<4xf32, 1>, 1>
-
-// CUDA: gpu.func @nested_memref
-// CUDA-SAME: (%{{.+}}: memref<4xmemref<4xf32>>)
-// CUDA: memref.load
-// CUDA-SAME: : memref<4xmemref<4xf32>>