mirror of
https://github.com/intel/llvm.git
synced 2026-01-21 04:14:03 +08:00
[mlir][ptr] Add ConstantOp with NullAttr and AddressAttr support (#157347)
This patch introduces the `ptr.constant` operation. It also adds the
`NullAttr` and `AddressAttr` for representing null pointers, and integer
raw addresses.
It also implements LLVM IR translation for `ptr.constant` with
`#ptr.null` or `#ptr.address` attributes.
Finally, it extends `FieldParser` to support APInt parsing.
Example:
```mlir
llvm.func @constant_address_op() ->
!llvm.struct<(!ptr.ptr<#llvm.address_space<0>>,
!ptr.ptr<#llvm.address_space<1>>,
!ptr.ptr<#llvm.address_space<2>>)> {
%0 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>>
%1 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>>
%2 = ptr.constant #ptr.address<3735928559> : !ptr.ptr<#llvm.address_space<2>>
%3 = llvm.mlir.poison : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
%4 = llvm.insertvalue %0, %3[0] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
%5 = llvm.insertvalue %1, %4[1] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
%6 = llvm.insertvalue %2, %5[2] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
llvm.return %6 : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
}
```
Result of translation to LLVM IR:
```llvm
define { ptr, ptr addrspace(1), ptr addrspace(2) } @constant_address_op() {
ret { ptr, ptr addrspace(1), ptr addrspace(2) } { ptr null, ptr addrspace(1) inttoptr (i64 4096 to ptr addrspace(1)), ptr addrspace(2) inttoptr (i64 3735928559 to ptr addrspace(2)) }
}
```
This patch also changes all the `convert*` occurrences in function names
or comments to `translate` in the PtrToLLVM file.
---------
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
This commit is contained in:
@@ -22,6 +22,34 @@ class Ptr_Attr<string name, string attrMnemonic,
|
||||
let mnemonic = attrMnemonic;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddressAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Ptr_AddressAttr : Ptr_Attr<"Address", "address", [
|
||||
DeclareAttrInterfaceMethods<TypedAttrInterface>
|
||||
]> {
|
||||
let summary = "Address attribute";
|
||||
let description = [{
|
||||
The `address` attribute represents a raw memory address, expressed in bytes.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
#ptr.address<0x1000> : !ptr.ptr<#ptr.generic_space>
|
||||
```
|
||||
}];
|
||||
let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type,
|
||||
APIntParameter<"">:$value);
|
||||
let builders = [
|
||||
AttrBuilderWithInferredContext<(ins "PtrType":$type,
|
||||
"const llvm::APInt &":$value), [{
|
||||
return $_get(type.getContext(), type, value);
|
||||
}]>
|
||||
];
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericSpaceAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -37,16 +65,42 @@ def Ptr_GenericSpaceAttr :
|
||||
- Load and store operations are always valid, regardless of the type.
|
||||
- Atomic operations are always valid, regardless of the type.
|
||||
- Cast operations to `generic_space` are always valid.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
#ptr.generic_space
|
||||
#ptr.generic_space : !ptr.ptr<#ptr.generic_space>
|
||||
```
|
||||
}];
|
||||
let assemblyFormat = "";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NullAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Ptr_NullAttr : Ptr_Attr<"Null", "null", [
|
||||
DeclareAttrInterfaceMethods<TypedAttrInterface>
|
||||
]> {
|
||||
let summary = "Null pointer attribute";
|
||||
let description = [{
|
||||
The `null` attribute represents a null pointer.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
#ptr.null
|
||||
```
|
||||
}];
|
||||
let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type);
|
||||
let builders = [
|
||||
AttrBuilderWithInferredContext<(ins "PtrType":$type), [{
|
||||
return $_get(type.getContext(), type);
|
||||
}]>
|
||||
];
|
||||
let assemblyFormat = "";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SpecAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -62,7 +116,7 @@ def Ptr_SpecAttr : Ptr_Attr<"Spec", "spec"> {
|
||||
- [Optional] index: bitwidth that should be used when performing index
|
||||
computations for the type. Setting the field to `kOptionalSpecValue`, means
|
||||
the field is optional.
|
||||
|
||||
|
||||
Furthermore, the attribute will verify that all present values are divisible
|
||||
by 8 (number of bits in a byte), and that `preferred` > `abi`.
|
||||
|
||||
|
||||
@@ -21,6 +21,12 @@
|
||||
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
|
||||
#include "mlir/Dialect/Ptr/IR/PtrEnums.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace ptr {
|
||||
class PtrType;
|
||||
} // namespace ptr
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc"
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
|
||||
/*cppType=*/"::mlir::ShapedType">;
|
||||
|
||||
// A ptr-like type, either scalar or shaped type with value semantics.
|
||||
def Ptr_PtrLikeType :
|
||||
def Ptr_PtrLikeType :
|
||||
AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
|
||||
|
||||
// An int-like type, either scalar or shaped type with value semantics.
|
||||
@@ -57,6 +57,31 @@ def Ptr_Mask1DType :
|
||||
def Ptr_Ptr1DType :
|
||||
Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Ptr_ConstantOp : Pointer_Op<"constant", [
|
||||
ConstantLike, Pure, AllTypesMatch<["value", "result"]>
|
||||
]> {
|
||||
let summary = "Pointer constant operation";
|
||||
let description = [{
|
||||
The `constant` operation produces a pointer constant. The attribute must be
|
||||
a typed attribute of pointer type.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Create a null pointer
|
||||
%null = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins TypedAttrInterface:$value);
|
||||
let results = (outs Ptr_PtrType:$result);
|
||||
let assemblyFormat = "attr-dict $value";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FromPtrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -81,7 +106,7 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
|
||||
```mlir
|
||||
%typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr<f32, #ptr.generic_space>
|
||||
%memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
|
||||
|
||||
|
||||
// Cast the `%ptr` to a memref without utilizing metadata.
|
||||
%memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
|
||||
```
|
||||
@@ -361,13 +386,13 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
|
||||
// Scalar base and offset
|
||||
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
|
||||
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
|
||||
|
||||
|
||||
// Shaped base with scalar offset
|
||||
%ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32
|
||||
|
||||
|
||||
// Scalar base with shaped offset
|
||||
%x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>
|
||||
|
||||
|
||||
// Both base and offset are shaped
|
||||
%ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
|
||||
```
|
||||
@@ -382,7 +407,7 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let extraClassDeclaration = [{
|
||||
/// `ViewLikeOp::getViewSource` method.
|
||||
/// `ViewLikeOp::getViewSource` method.
|
||||
Value getViewSource() { return getBase(); }
|
||||
|
||||
/// Returns the ptr type of the operation.
|
||||
@@ -418,7 +443,7 @@ def Ptr_ScatterOp : Pointer_Op<"scatter", [
|
||||
// Scatter values to multiple memory locations
|
||||
ptr.scatter %value, %ptrs, %mask :
|
||||
vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
|
||||
|
||||
|
||||
// Scatter with alignment
|
||||
ptr.scatter %value, %ptrs, %mask alignment = 8 :
|
||||
vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
|
||||
|
||||
@@ -103,10 +103,11 @@ struct FieldParser<
|
||||
|
||||
/// Parse any integer.
|
||||
template <typename IntT>
|
||||
struct FieldParser<IntT,
|
||||
std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
|
||||
struct FieldParser<IntT, std::enable_if_t<(std::is_integral<IntT>::value ||
|
||||
std::is_same_v<IntT, llvm::APInt>),
|
||||
IntT>> {
|
||||
static FailureOr<IntT> parse(AsmParser &parser) {
|
||||
IntT value = 0;
|
||||
IntT value{};
|
||||
if (parser.parseInteger(value))
|
||||
return failure();
|
||||
return value;
|
||||
|
||||
@@ -56,6 +56,12 @@ verifyAlignment(std::optional<int64_t> alignment,
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FromPtrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -29,7 +29,7 @@ namespace {
|
||||
|
||||
/// Converts ptr::AtomicOrdering to llvm::AtomicOrdering
|
||||
static llvm::AtomicOrdering
|
||||
convertAtomicOrdering(ptr::AtomicOrdering ordering) {
|
||||
translateAtomicOrdering(ptr::AtomicOrdering ordering) {
|
||||
switch (ordering) {
|
||||
case ptr::AtomicOrdering::not_atomic:
|
||||
return llvm::AtomicOrdering::NotAtomic;
|
||||
@@ -49,10 +49,10 @@ convertAtomicOrdering(ptr::AtomicOrdering ordering) {
|
||||
llvm_unreachable("Unknown atomic ordering");
|
||||
}
|
||||
|
||||
/// Convert ptr.ptr_add operation
|
||||
/// Translate ptr.ptr_add operation to LLVM IR.
|
||||
static LogicalResult
|
||||
convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
translatePtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
|
||||
llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());
|
||||
|
||||
@@ -83,18 +83,19 @@ convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Convert ptr.load operation
|
||||
static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
/// Translate ptr.load operation to LLVM IR.
|
||||
static LogicalResult
|
||||
translateLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr());
|
||||
if (!ptr)
|
||||
return loadOp.emitError("Failed to lookup pointer operand");
|
||||
|
||||
// Convert result type to LLVM type
|
||||
// Translate result type to LLVM type
|
||||
llvm::Type *resultType =
|
||||
moduleTranslation.convertType(loadOp.getValue().getType());
|
||||
if (!resultType)
|
||||
return loadOp.emitError("Failed to convert result type");
|
||||
return loadOp.emitError("Failed to translate result type");
|
||||
|
||||
// Create the load instruction.
|
||||
llvm::MaybeAlign alignment(loadOp.getAlignment().value_or(0));
|
||||
@@ -102,7 +103,7 @@ static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
|
||||
resultType, ptr, alignment, loadOp.getVolatile_());
|
||||
|
||||
// Set op flags and metadata.
|
||||
loadInst->setAtomic(convertAtomicOrdering(loadOp.getOrdering()));
|
||||
loadInst->setAtomic(translateAtomicOrdering(loadOp.getOrdering()));
|
||||
// Set sync scope if specified
|
||||
if (loadOp.getSyncscope().has_value()) {
|
||||
llvm::LLVMContext &ctx = builder.getContext();
|
||||
@@ -135,10 +136,10 @@ static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Convert ptr.store operation
|
||||
/// Translate ptr.store operation to LLVM IR.
|
||||
static LogicalResult
|
||||
convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
translateStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue());
|
||||
llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr());
|
||||
|
||||
@@ -151,7 +152,7 @@ convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
|
||||
builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_());
|
||||
|
||||
// Set op flags and metadata.
|
||||
storeInst->setAtomic(convertAtomicOrdering(storeOp.getOrdering()));
|
||||
storeInst->setAtomic(translateAtomicOrdering(storeOp.getOrdering()));
|
||||
// Set sync scope if specified
|
||||
if (storeOp.getSyncscope().has_value()) {
|
||||
llvm::LLVMContext &ctx = builder.getContext();
|
||||
@@ -178,21 +179,21 @@ convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Convert ptr.type_offset operation
|
||||
/// Translate ptr.type_offset operation to LLVM IR.
|
||||
static LogicalResult
|
||||
convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
// Convert the element type to LLVM type
|
||||
translateTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
// Translate the element type to LLVM type
|
||||
llvm::Type *elementType =
|
||||
moduleTranslation.convertType(typeOffsetOp.getElementType());
|
||||
if (!elementType)
|
||||
return typeOffsetOp.emitError("Failed to convert the element type");
|
||||
return typeOffsetOp.emitError("Failed to translate the element type");
|
||||
|
||||
// Convert result type
|
||||
// Translate result type
|
||||
llvm::Type *resultType =
|
||||
moduleTranslation.convertType(typeOffsetOp.getResult().getType());
|
||||
if (!resultType)
|
||||
return typeOffsetOp.emitError("Failed to convert the result type");
|
||||
return typeOffsetOp.emitError("Failed to translate the result type");
|
||||
|
||||
// Use GEP with null pointer to compute type size/offset.
|
||||
llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0));
|
||||
@@ -204,10 +205,10 @@ convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Convert ptr.gather operation
|
||||
/// Translate ptr.gather operation to LLVM IR.
|
||||
static LogicalResult
|
||||
convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
translateGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
|
||||
llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
|
||||
llvm::Value *passthrough =
|
||||
@@ -216,11 +217,11 @@ convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
|
||||
if (!ptrs || !mask || !passthrough)
|
||||
return gatherOp.emitError("Failed to lookup operands");
|
||||
|
||||
// Convert result type to LLVM type.
|
||||
// Translate result type to LLVM type.
|
||||
llvm::Type *resultType =
|
||||
moduleTranslation.convertType(gatherOp.getResult().getType());
|
||||
if (!resultType)
|
||||
return gatherOp.emitError("Failed to convert result type");
|
||||
return gatherOp.emitError("Failed to translate result type");
|
||||
|
||||
// Get the alignment.
|
||||
llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
|
||||
@@ -233,10 +234,10 @@ convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Convert ptr.masked_load operation
|
||||
/// Translate ptr.masked_load operation to LLVM IR.
|
||||
static LogicalResult
|
||||
convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
translateMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
|
||||
llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
|
||||
llvm::Value *passthrough =
|
||||
@@ -245,11 +246,11 @@ convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
|
||||
if (!ptr || !mask || !passthrough)
|
||||
return maskedLoadOp.emitError("Failed to lookup operands");
|
||||
|
||||
// Convert result type to LLVM type.
|
||||
// Translate result type to LLVM type.
|
||||
llvm::Type *resultType =
|
||||
moduleTranslation.convertType(maskedLoadOp.getResult().getType());
|
||||
if (!resultType)
|
||||
return maskedLoadOp.emitError("Failed to convert result type");
|
||||
return maskedLoadOp.emitError("Failed to translate result type");
|
||||
|
||||
// Get the alignment.
|
||||
llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0));
|
||||
@@ -262,10 +263,11 @@ convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Convert ptr.masked_store operation
|
||||
/// Translate ptr.masked_store operation to LLVM IR.
|
||||
static LogicalResult
|
||||
convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
translateMaskedStoreOp(MaskedStoreOp maskedStoreOp,
|
||||
llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue());
|
||||
llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr());
|
||||
llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask());
|
||||
@@ -281,10 +283,10 @@ convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Convert ptr.scatter operation
|
||||
/// Translate ptr.scatter operation to LLVM IR.
|
||||
static LogicalResult
|
||||
convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
translateScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue());
|
||||
llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs());
|
||||
llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask());
|
||||
@@ -300,7 +302,56 @@ convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Implementation of the dialect interface that converts operations belonging
|
||||
/// Translate ptr.constant operation to LLVM IR.
|
||||
static LogicalResult
|
||||
translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
// Translate result type to LLVM type
|
||||
llvm::PointerType *resultType = dyn_cast_or_null<llvm::PointerType>(
|
||||
moduleTranslation.convertType(constantOp.getResult().getType()));
|
||||
if (!resultType)
|
||||
return constantOp.emitError("Expected a valid pointer type");
|
||||
|
||||
llvm::Value *result = nullptr;
|
||||
|
||||
TypedAttr value = constantOp.getValue();
|
||||
if (auto nullAttr = dyn_cast<ptr::NullAttr>(value)) {
|
||||
// Create a null pointer constant
|
||||
result = llvm::ConstantPointerNull::get(resultType);
|
||||
} else if (auto addressAttr = dyn_cast<ptr::AddressAttr>(value)) {
|
||||
// Create an integer constant and translate it to pointer
|
||||
llvm::APInt addressValue = addressAttr.getValue();
|
||||
|
||||
// Determine the integer type width based on the target's pointer size
|
||||
llvm::DataLayout dataLayout =
|
||||
moduleTranslation.getLLVMModule()->getDataLayout();
|
||||
unsigned pointerSizeInBits =
|
||||
dataLayout.getPointerSizeInBits(resultType->getAddressSpace());
|
||||
|
||||
// Extend or truncate the address value to match pointer size if needed
|
||||
if (addressValue.getBitWidth() != pointerSizeInBits) {
|
||||
if (addressValue.getBitWidth() > pointerSizeInBits) {
|
||||
constantOp.emitWarning()
|
||||
<< "Truncating address value to fit pointer size";
|
||||
}
|
||||
addressValue = addressValue.getBitWidth() < pointerSizeInBits
|
||||
? addressValue.zext(pointerSizeInBits)
|
||||
: addressValue.trunc(pointerSizeInBits);
|
||||
}
|
||||
|
||||
// Create integer constant and translate to pointer
|
||||
llvm::Type *intType = builder.getIntNTy(pointerSizeInBits);
|
||||
llvm::Value *intValue = llvm::ConstantInt::get(intType, addressValue);
|
||||
result = builder.CreateIntToPtr(intValue, resultType);
|
||||
} else {
|
||||
return constantOp.emitError("Unsupported constant attribute type");
|
||||
}
|
||||
|
||||
moduleTranslation.mapValue(constantOp.getResult(), result);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Implementation of the dialect interface that translates operations belonging
|
||||
/// to the `ptr` dialect to LLVM IR.
|
||||
class PtrDialectLLVMIRTranslationInterface
|
||||
: public LLVMTranslationDialectInterface {
|
||||
@@ -314,30 +365,35 @@ public:
|
||||
LLVM::ModuleTranslation &moduleTranslation) const final {
|
||||
|
||||
return llvm::TypeSwitch<Operation *, LogicalResult>(op)
|
||||
.Case([&](ConstantOp constantOp) {
|
||||
return translateConstantOp(constantOp, builder, moduleTranslation);
|
||||
})
|
||||
.Case([&](PtrAddOp ptrAddOp) {
|
||||
return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
|
||||
return translatePtrAddOp(ptrAddOp, builder, moduleTranslation);
|
||||
})
|
||||
.Case([&](LoadOp loadOp) {
|
||||
return convertLoadOp(loadOp, builder, moduleTranslation);
|
||||
return translateLoadOp(loadOp, builder, moduleTranslation);
|
||||
})
|
||||
.Case([&](StoreOp storeOp) {
|
||||
return convertStoreOp(storeOp, builder, moduleTranslation);
|
||||
return translateStoreOp(storeOp, builder, moduleTranslation);
|
||||
})
|
||||
.Case([&](TypeOffsetOp typeOffsetOp) {
|
||||
return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation);
|
||||
return translateTypeOffsetOp(typeOffsetOp, builder,
|
||||
moduleTranslation);
|
||||
})
|
||||
.Case<GatherOp>([&](GatherOp gatherOp) {
|
||||
return convertGatherOp(gatherOp, builder, moduleTranslation);
|
||||
return translateGatherOp(gatherOp, builder, moduleTranslation);
|
||||
})
|
||||
.Case<MaskedLoadOp>([&](MaskedLoadOp maskedLoadOp) {
|
||||
return convertMaskedLoadOp(maskedLoadOp, builder, moduleTranslation);
|
||||
return translateMaskedLoadOp(maskedLoadOp, builder,
|
||||
moduleTranslation);
|
||||
})
|
||||
.Case<MaskedStoreOp>([&](MaskedStoreOp maskedStoreOp) {
|
||||
return convertMaskedStoreOp(maskedStoreOp, builder,
|
||||
moduleTranslation);
|
||||
return translateMaskedStoreOp(maskedStoreOp, builder,
|
||||
moduleTranslation);
|
||||
})
|
||||
.Case<ScatterOp>([&](ScatterOp scatterOp) {
|
||||
return convertScatterOp(scatterOp, builder, moduleTranslation);
|
||||
return translateScatterOp(scatterOp, builder, moduleTranslation);
|
||||
})
|
||||
.Default([&](Operation *op) {
|
||||
return op->emitError("Translation for operation '")
|
||||
|
||||
@@ -114,7 +114,7 @@ func.func @masked_store_ops_tensor(%value: tensor<8xi64>, %ptr: !ptr.ptr<#ptr.ge
|
||||
}
|
||||
|
||||
/// Test operations with LLVM address space
|
||||
func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
|
||||
func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
|
||||
%mask: vector<4xi1>, %value: vector<4xf32>, %passthrough: vector<4xf32>) -> vector<4xf32> {
|
||||
// Gather from shared memory (address space 3)
|
||||
%0 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf32>
|
||||
@@ -189,3 +189,25 @@ func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.gener
|
||||
%res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
|
||||
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
|
||||
}
|
||||
|
||||
/// Test constant operations with null pointer
|
||||
func.func @constant_null_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>) {
|
||||
%null_generic = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>
|
||||
%null_as1 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<1>>
|
||||
return %null_generic, %null_as1 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>
|
||||
}
|
||||
|
||||
/// Test constant operations with address values
|
||||
func.func @constant_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>>) {
|
||||
%addr_0 = ptr.constant #ptr.address<0> : !ptr.ptr<#ptr.generic_space>
|
||||
%addr_1000 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>>
|
||||
%addr_deadbeef = ptr.constant #ptr.address<0xDEADBEEF> : !ptr.ptr<#llvm.address_space<3>>
|
||||
return %addr_0, %addr_1000, %addr_deadbeef : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>>
|
||||
}
|
||||
|
||||
/// Test constant operations with large address values
|
||||
func.func @constant_large_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>) {
|
||||
%addr_max32 = ptr.constant #ptr.address<0xFFFFFFFF> : !ptr.ptr<#ptr.generic_space>
|
||||
%addr_large = ptr.constant #ptr.address<0x123456789ABCDEF0> : !ptr.ptr<#llvm.address_space<0>>
|
||||
return %addr_max32, %addr_large : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>
|
||||
}
|
||||
|
||||
@@ -41,10 +41,10 @@ llvm.func @type_offset(%arg0: !ptr.ptr<#llvm.address_space<0>>) -> !llvm.struct<
|
||||
%2 = ptr.type_offset i16 : i32
|
||||
%3 = ptr.type_offset i32 : i32
|
||||
%4 = llvm.mlir.poison : !llvm.struct<(i32, i32, i32, i32)>
|
||||
%5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)>
|
||||
%6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)>
|
||||
%7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)>
|
||||
%8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)>
|
||||
%5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)>
|
||||
%6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)>
|
||||
%7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)>
|
||||
%8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)>
|
||||
llvm.return %8 : !llvm.struct<(i32, i32, i32, i32)>
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ llvm.func @scatter_ops_i64(%value: vector<8xi64>, %ptrs: vector<8x!ptr.ptr<#llvm
|
||||
// CHECK-NEXT: call void @llvm.masked.store.v4f64.p3(<4 x double> %[[VALUE_F64]], ptr addrspace(3) %[[PTR_SHARED]], i32 8, <4 x i1> %[[MASK]])
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK-NEXT: }
|
||||
llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
|
||||
llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
|
||||
%mask: vector<4xi1>, %value: vector<4xf64>, %passthrough: vector<4xf64>) {
|
||||
// Test with shared memory address space (3) and f64 elements
|
||||
%0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf64>
|
||||
@@ -255,3 +255,29 @@ llvm.func @llvm_ops_with_ptr_nvvm_values(%arg0: !llvm.ptr) {
|
||||
llvm.store %1, %arg0 : !ptr.ptr<#nvvm.memory_space<global>>, !llvm.ptr
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define { ptr, ptr addrspace(1), ptr addrspace(2) } @constant_address_op() {
|
||||
// CHECK-NEXT: ret { ptr, ptr addrspace(1), ptr addrspace(2) } { ptr null, ptr addrspace(1) inttoptr (i64 4096 to ptr addrspace(1)), ptr addrspace(2) inttoptr (i64 3735928559 to ptr addrspace(2)) }
|
||||
llvm.func @constant_address_op() ->
|
||||
!llvm.struct<(!ptr.ptr<#llvm.address_space<0>>,
|
||||
!ptr.ptr<#llvm.address_space<1>>,
|
||||
!ptr.ptr<#llvm.address_space<2>>)> {
|
||||
%0 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>>
|
||||
%1 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>>
|
||||
%2 = ptr.constant #ptr.address<3735928559> : !ptr.ptr<#llvm.address_space<2>>
|
||||
%3 = llvm.mlir.poison : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
|
||||
%4 = llvm.insertvalue %0, %3[0] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
|
||||
%5 = llvm.insertvalue %1, %4[1] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
|
||||
%6 = llvm.insertvalue %2, %5[2] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
|
||||
llvm.return %6 : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
|
||||
}
|
||||
|
||||
// Test gep folders.
|
||||
// CHECK-LABEL: define ptr @ptr_add_cst() {
|
||||
// CHECK-NEXT: ret ptr inttoptr (i64 42 to ptr)
|
||||
llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> {
|
||||
%off = llvm.mlir.constant(42 : i32) : i32
|
||||
%ptr = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>>
|
||||
%res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
|
||||
llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user