[mlir][ptr] Add ConstantOp with NullAttr and AddressAttr support (#157347)

This patch introduces the `ptr.constant` operation. It also adds the
`NullAttr` and `AddressAttr` for representing null pointers, and integer
raw addresses.

It also implements LLVM IR translation for `ptr.constant` with
`#ptr.null` or `#ptr.address` attributes.

Finally, it extends `FieldParser` to support APInt parsing.

Example:
```mlir
llvm.func @constant_address_op() ->
    !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>,
                  !ptr.ptr<#llvm.address_space<1>>,
                  !ptr.ptr<#llvm.address_space<2>>)> {
  %0 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>>
  %1 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>>
  %2 = ptr.constant #ptr.address<3735928559> : !ptr.ptr<#llvm.address_space<2>>
  %3 = llvm.mlir.poison : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
  %4 = llvm.insertvalue %0, %3[0] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
  %5 = llvm.insertvalue %1, %4[1] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
  %6 = llvm.insertvalue %2, %5[2] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
  llvm.return %6 : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
}
```
Result of translation to LLVM IR:
```llvm
define { ptr, ptr addrspace(1), ptr addrspace(2) } @constant_address_op() {
  ret { ptr, ptr addrspace(1), ptr addrspace(2) } { ptr null, ptr addrspace(1) inttoptr (i64 4096 to ptr addrspace(1)), ptr addrspace(2) inttoptr (i64 3735928559 to ptr addrspace(2)) }
}
```

This patch also changes all the `convert*` occurrences in function names
or comments to `translate` in the PtrToLLVM file.

---------

Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
This commit is contained in:
Fabian Mora
2025-09-14 11:44:59 -04:00
committed by GitHub
parent bfedb4a938
commit 1a65e63c59
8 changed files with 262 additions and 66 deletions

View File

@@ -22,6 +22,34 @@ class Ptr_Attr<string name, string attrMnemonic,
let mnemonic = attrMnemonic;
}
//===----------------------------------------------------------------------===//
// AddressAttr
//===----------------------------------------------------------------------===//
def Ptr_AddressAttr : Ptr_Attr<"Address", "address", [
DeclareAttrInterfaceMethods<TypedAttrInterface>
]> {
let summary = "Address attribute";
let description = [{
The `address` attribute represents a raw memory address, expressed in bytes.
Example:
```mlir
#ptr.address<0x1000> : !ptr.ptr<#ptr.generic_space>
```
}];
let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type,
APIntParameter<"">:$value);
let builders = [
AttrBuilderWithInferredContext<(ins "PtrType":$type,
"const llvm::APInt &":$value), [{
return $_get(type.getContext(), type, value);
}]>
];
let assemblyFormat = "`<` $value `>`";
}
//===----------------------------------------------------------------------===//
// GenericSpaceAttr
//===----------------------------------------------------------------------===//
@@ -37,16 +65,42 @@ def Ptr_GenericSpaceAttr :
- Load and store operations are always valid, regardless of the type.
- Atomic operations are always valid, regardless of the type.
- Cast operations to `generic_space` are always valid.
Example:
```mlir
#ptr.generic_space
#ptr.generic_space : !ptr.ptr<#ptr.generic_space>
```
}];
let assemblyFormat = "";
}
//===----------------------------------------------------------------------===//
// NullAttr
//===----------------------------------------------------------------------===//
def Ptr_NullAttr : Ptr_Attr<"Null", "null", [
DeclareAttrInterfaceMethods<TypedAttrInterface>
]> {
let summary = "Null pointer attribute";
let description = [{
The `null` attribute represents a null pointer.
Example:
```mlir
#ptr.null
```
}];
let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type);
let builders = [
AttrBuilderWithInferredContext<(ins "PtrType":$type), [{
return $_get(type.getContext(), type);
}]>
];
let assemblyFormat = "";
}
//===----------------------------------------------------------------------===//
// SpecAttr
//===----------------------------------------------------------------------===//
@@ -62,7 +116,7 @@ def Ptr_SpecAttr : Ptr_Attr<"Spec", "spec"> {
- [Optional] index: bitwidth that should be used when performing index
computations for the type. Setting the field to `kOptionalSpecValue`, means
the field is optional.
Furthermore, the attribute will verify that all present values are divisible
by 8 (number of bits in a byte), and that `preferred` > `abi`.

View File

@@ -21,6 +21,12 @@
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
#include "mlir/Dialect/Ptr/IR/PtrEnums.h"
namespace mlir {
namespace ptr {
class PtrType;
} // namespace ptr
} // namespace mlir
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc"

View File

@@ -36,7 +36,7 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
/*cppType=*/"::mlir::ShapedType">;
// A ptr-like type, either scalar or shaped type with value semantics.
def Ptr_PtrLikeType :
def Ptr_PtrLikeType :
AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
// An int-like type, either scalar or shaped type with value semantics.
@@ -57,6 +57,31 @@ def Ptr_Mask1DType :
def Ptr_Ptr1DType :
Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
def Ptr_ConstantOp : Pointer_Op<"constant", [
ConstantLike, Pure, AllTypesMatch<["value", "result"]>
]> {
let summary = "Pointer constant operation";
let description = [{
The `constant` operation produces a pointer constant. The attribute must be
a typed attribute of pointer type.
Example:
```mlir
// Create a null pointer
%null = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>
```
}];
let arguments = (ins TypedAttrInterface:$value);
let results = (outs Ptr_PtrType:$result);
let assemblyFormat = "attr-dict $value";
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//
@@ -81,7 +106,7 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
```mlir
%typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr<f32, #ptr.generic_space>
%memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
// Cast the `%ptr` to a memref without utilizing metadata.
%memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
```
@@ -361,13 +386,13 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
// Scalar base and offset
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
// Shaped base with scalar offset
%ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32
// Scalar base with shaped offset
%x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>
// Both base and offset are shaped
%ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
```
@@ -382,7 +407,7 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
}];
let hasFolder = 1;
let extraClassDeclaration = [{
/// `ViewLikeOp::getViewSource` method.
/// `ViewLikeOp::getViewSource` method.
Value getViewSource() { return getBase(); }
/// Returns the ptr type of the operation.
@@ -418,7 +443,7 @@ def Ptr_ScatterOp : Pointer_Op<"scatter", [
// Scatter values to multiple memory locations
ptr.scatter %value, %ptrs, %mask :
vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
// Scatter with alignment
ptr.scatter %value, %ptrs, %mask alignment = 8 :
vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>

View File

@@ -103,10 +103,11 @@ struct FieldParser<
/// Parse any integer.
template <typename IntT>
struct FieldParser<IntT,
std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
struct FieldParser<IntT, std::enable_if_t<(std::is_integral<IntT>::value ||
std::is_same_v<IntT, llvm::APInt>),
IntT>> {
static FailureOr<IntT> parse(AsmParser &parser) {
IntT value = 0;
IntT value{};
if (parser.parseInteger(value))
return failure();
return value;

View File

@@ -56,6 +56,12 @@ verifyAlignment(std::optional<int64_t> alignment,
return success();
}
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//

View File

@@ -29,7 +29,7 @@ namespace {
/// Converts ptr::AtomicOrdering to llvm::AtomicOrdering
static llvm::AtomicOrdering
convertAtomicOrdering(ptr::AtomicOrdering ordering) {
translateAtomicOrdering(ptr::AtomicOrdering ordering) {
switch (ordering) {
case ptr::AtomicOrdering::not_atomic:
return llvm::AtomicOrdering::NotAtomic;
@@ -49,10 +49,10 @@ convertAtomicOrdering(ptr::AtomicOrdering ordering) {
llvm_unreachable("Unknown atomic ordering");
}
/// Convert ptr.ptr_add operation
/// Translate ptr.ptr_add operation to LLVM IR.
static LogicalResult
convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
translatePtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());
@@ -83,18 +83,19 @@ convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
return success();
}
/// Convert ptr.load operation
static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
/// Translate ptr.load operation to LLVM IR.
static LogicalResult
translateLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr());
if (!ptr)
return loadOp.emitError("Failed to lookup pointer operand");
// Convert result type to LLVM type
// Translate result type to LLVM type
llvm::Type *resultType =
moduleTranslation.convertType(loadOp.getValue().getType());
if (!resultType)
return loadOp.emitError("Failed to convert result type");
return loadOp.emitError("Failed to translate result type");
// Create the load instruction.
llvm::MaybeAlign alignment(loadOp.getAlignment().value_or(0));
@@ -102,7 +103,7 @@ static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
resultType, ptr, alignment, loadOp.getVolatile_());
// Set op flags and metadata.
loadInst->setAtomic(convertAtomicOrdering(loadOp.getOrdering()));
loadInst->setAtomic(translateAtomicOrdering(loadOp.getOrdering()));
// Set sync scope if specified
if (loadOp.getSyncscope().has_value()) {
llvm::LLVMContext &ctx = builder.getContext();
@@ -135,10 +136,10 @@ static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
return success();
}
/// Convert ptr.store operation
/// Translate ptr.store operation to LLVM IR.
static LogicalResult
convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
translateStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue());
llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr());
@@ -151,7 +152,7 @@ convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_());
// Set op flags and metadata.
storeInst->setAtomic(convertAtomicOrdering(storeOp.getOrdering()));
storeInst->setAtomic(translateAtomicOrdering(storeOp.getOrdering()));
// Set sync scope if specified
if (storeOp.getSyncscope().has_value()) {
llvm::LLVMContext &ctx = builder.getContext();
@@ -178,21 +179,21 @@ convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
return success();
}
/// Convert ptr.type_offset operation
/// Translate ptr.type_offset operation to LLVM IR.
static LogicalResult
convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
// Convert the element type to LLVM type
translateTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
// Translate the element type to LLVM type
llvm::Type *elementType =
moduleTranslation.convertType(typeOffsetOp.getElementType());
if (!elementType)
return typeOffsetOp.emitError("Failed to convert the element type");
return typeOffsetOp.emitError("Failed to translate the element type");
// Convert result type
// Translate result type
llvm::Type *resultType =
moduleTranslation.convertType(typeOffsetOp.getResult().getType());
if (!resultType)
return typeOffsetOp.emitError("Failed to convert the result type");
return typeOffsetOp.emitError("Failed to translate the result type");
// Use GEP with null pointer to compute type size/offset.
llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0));
@@ -204,10 +205,10 @@ convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
return success();
}
/// Convert ptr.gather operation
/// Translate ptr.gather operation to LLVM IR.
static LogicalResult
convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
translateGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
llvm::Value *passthrough =
@@ -216,11 +217,11 @@ convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
if (!ptrs || !mask || !passthrough)
return gatherOp.emitError("Failed to lookup operands");
// Convert result type to LLVM type.
// Translate result type to LLVM type.
llvm::Type *resultType =
moduleTranslation.convertType(gatherOp.getResult().getType());
if (!resultType)
return gatherOp.emitError("Failed to convert result type");
return gatherOp.emitError("Failed to translate result type");
// Get the alignment.
llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
@@ -233,10 +234,10 @@ convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
return success();
}
/// Convert ptr.masked_load operation
/// Translate ptr.masked_load operation to LLVM IR.
static LogicalResult
convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
translateMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
llvm::Value *passthrough =
@@ -245,11 +246,11 @@ convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
if (!ptr || !mask || !passthrough)
return maskedLoadOp.emitError("Failed to lookup operands");
// Convert result type to LLVM type.
// Translate result type to LLVM type.
llvm::Type *resultType =
moduleTranslation.convertType(maskedLoadOp.getResult().getType());
if (!resultType)
return maskedLoadOp.emitError("Failed to convert result type");
return maskedLoadOp.emitError("Failed to translate result type");
// Get the alignment.
llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0));
@@ -262,10 +263,11 @@ convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
return success();
}
/// Convert ptr.masked_store operation
/// Translate ptr.masked_store operation to LLVM IR.
static LogicalResult
convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
translateMaskedStoreOp(MaskedStoreOp maskedStoreOp,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue());
llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr());
llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask());
@@ -281,10 +283,10 @@ convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder,
return success();
}
/// Convert ptr.scatter operation
/// Translate ptr.scatter operation to LLVM IR.
static LogicalResult
convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
translateScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue());
llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs());
llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask());
@@ -300,7 +302,56 @@ convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
return success();
}
/// Implementation of the dialect interface that converts operations belonging
/// Translate ptr.constant operation to LLVM IR.
static LogicalResult
translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
// Translate result type to LLVM type
llvm::PointerType *resultType = dyn_cast_or_null<llvm::PointerType>(
moduleTranslation.convertType(constantOp.getResult().getType()));
if (!resultType)
return constantOp.emitError("Expected a valid pointer type");
llvm::Value *result = nullptr;
TypedAttr value = constantOp.getValue();
if (auto nullAttr = dyn_cast<ptr::NullAttr>(value)) {
// Create a null pointer constant
result = llvm::ConstantPointerNull::get(resultType);
} else if (auto addressAttr = dyn_cast<ptr::AddressAttr>(value)) {
// Create an integer constant and translate it to pointer
llvm::APInt addressValue = addressAttr.getValue();
// Determine the integer type width based on the target's pointer size
llvm::DataLayout dataLayout =
moduleTranslation.getLLVMModule()->getDataLayout();
unsigned pointerSizeInBits =
dataLayout.getPointerSizeInBits(resultType->getAddressSpace());
// Extend or truncate the address value to match pointer size if needed
if (addressValue.getBitWidth() != pointerSizeInBits) {
if (addressValue.getBitWidth() > pointerSizeInBits) {
constantOp.emitWarning()
<< "Truncating address value to fit pointer size";
}
addressValue = addressValue.getBitWidth() < pointerSizeInBits
? addressValue.zext(pointerSizeInBits)
: addressValue.trunc(pointerSizeInBits);
}
// Create integer constant and translate to pointer
llvm::Type *intType = builder.getIntNTy(pointerSizeInBits);
llvm::Value *intValue = llvm::ConstantInt::get(intType, addressValue);
result = builder.CreateIntToPtr(intValue, resultType);
} else {
return constantOp.emitError("Unsupported constant attribute type");
}
moduleTranslation.mapValue(constantOp.getResult(), result);
return success();
}
/// Implementation of the dialect interface that translates operations belonging
/// to the `ptr` dialect to LLVM IR.
class PtrDialectLLVMIRTranslationInterface
: public LLVMTranslationDialectInterface {
@@ -314,30 +365,35 @@ public:
LLVM::ModuleTranslation &moduleTranslation) const final {
return llvm::TypeSwitch<Operation *, LogicalResult>(op)
.Case([&](ConstantOp constantOp) {
return translateConstantOp(constantOp, builder, moduleTranslation);
})
.Case([&](PtrAddOp ptrAddOp) {
return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
return translatePtrAddOp(ptrAddOp, builder, moduleTranslation);
})
.Case([&](LoadOp loadOp) {
return convertLoadOp(loadOp, builder, moduleTranslation);
return translateLoadOp(loadOp, builder, moduleTranslation);
})
.Case([&](StoreOp storeOp) {
return convertStoreOp(storeOp, builder, moduleTranslation);
return translateStoreOp(storeOp, builder, moduleTranslation);
})
.Case([&](TypeOffsetOp typeOffsetOp) {
return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation);
return translateTypeOffsetOp(typeOffsetOp, builder,
moduleTranslation);
})
.Case<GatherOp>([&](GatherOp gatherOp) {
return convertGatherOp(gatherOp, builder, moduleTranslation);
return translateGatherOp(gatherOp, builder, moduleTranslation);
})
.Case<MaskedLoadOp>([&](MaskedLoadOp maskedLoadOp) {
return convertMaskedLoadOp(maskedLoadOp, builder, moduleTranslation);
return translateMaskedLoadOp(maskedLoadOp, builder,
moduleTranslation);
})
.Case<MaskedStoreOp>([&](MaskedStoreOp maskedStoreOp) {
return convertMaskedStoreOp(maskedStoreOp, builder,
moduleTranslation);
return translateMaskedStoreOp(maskedStoreOp, builder,
moduleTranslation);
})
.Case<ScatterOp>([&](ScatterOp scatterOp) {
return convertScatterOp(scatterOp, builder, moduleTranslation);
return translateScatterOp(scatterOp, builder, moduleTranslation);
})
.Default([&](Operation *op) {
return op->emitError("Translation for operation '")

View File

@@ -114,7 +114,7 @@ func.func @masked_store_ops_tensor(%value: tensor<8xi64>, %ptr: !ptr.ptr<#ptr.ge
}
/// Test operations with LLVM address space
func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
%mask: vector<4xi1>, %value: vector<4xf32>, %passthrough: vector<4xf32>) -> vector<4xf32> {
// Gather from shared memory (address space 3)
%0 = ptr.gather %ptrs, %mask, %passthrough alignment = 4 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf32>
@@ -189,3 +189,25 @@ func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.gener
%res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
}
/// Test constant operations with null pointer
func.func @constant_null_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>) {
%null_generic = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>
%null_as1 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<1>>
return %null_generic, %null_as1 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>
}
/// Test constant operations with address values
func.func @constant_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>>) {
%addr_0 = ptr.constant #ptr.address<0> : !ptr.ptr<#ptr.generic_space>
%addr_1000 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>>
%addr_deadbeef = ptr.constant #ptr.address<0xDEADBEEF> : !ptr.ptr<#llvm.address_space<3>>
return %addr_0, %addr_1000, %addr_deadbeef : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>>
}
/// Test constant operations with large address values
func.func @constant_large_address_ops() -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>) {
%addr_max32 = ptr.constant #ptr.address<0xFFFFFFFF> : !ptr.ptr<#ptr.generic_space>
%addr_large = ptr.constant #ptr.address<0x123456789ABCDEF0> : !ptr.ptr<#llvm.address_space<0>>
return %addr_max32, %addr_large : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#llvm.address_space<0>>
}

View File

@@ -41,10 +41,10 @@ llvm.func @type_offset(%arg0: !ptr.ptr<#llvm.address_space<0>>) -> !llvm.struct<
%2 = ptr.type_offset i16 : i32
%3 = ptr.type_offset i32 : i32
%4 = llvm.mlir.poison : !llvm.struct<(i32, i32, i32, i32)>
%5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)>
%6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)>
%7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)>
%8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)>
%5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)>
%6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)>
%7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)>
%8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)>
llvm.return %8 : !llvm.struct<(i32, i32, i32, i32)>
}
@@ -194,7 +194,7 @@ llvm.func @scatter_ops_i64(%value: vector<8xi64>, %ptrs: vector<8x!ptr.ptr<#llvm
// CHECK-NEXT: call void @llvm.masked.store.v4f64.p3(<4 x double> %[[VALUE_F64]], ptr addrspace(3) %[[PTR_SHARED]], i32 8, <4 x i1> %[[MASK]])
// CHECK-NEXT: ret void
// CHECK-NEXT: }
llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>,
%mask: vector<4xi1>, %value: vector<4xf64>, %passthrough: vector<4xf64>) {
// Test with shared memory address space (3) and f64 elements
%0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf64>
@@ -255,3 +255,29 @@ llvm.func @llvm_ops_with_ptr_nvvm_values(%arg0: !llvm.ptr) {
llvm.store %1, %arg0 : !ptr.ptr<#nvvm.memory_space<global>>, !llvm.ptr
llvm.return
}
// CHECK-LABEL: define { ptr, ptr addrspace(1), ptr addrspace(2) } @constant_address_op() {
// CHECK-NEXT: ret { ptr, ptr addrspace(1), ptr addrspace(2) } { ptr null, ptr addrspace(1) inttoptr (i64 4096 to ptr addrspace(1)), ptr addrspace(2) inttoptr (i64 3735928559 to ptr addrspace(2)) }
llvm.func @constant_address_op() ->
!llvm.struct<(!ptr.ptr<#llvm.address_space<0>>,
!ptr.ptr<#llvm.address_space<1>>,
!ptr.ptr<#llvm.address_space<2>>)> {
%0 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>>
%1 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>>
%2 = ptr.constant #ptr.address<3735928559> : !ptr.ptr<#llvm.address_space<2>>
%3 = llvm.mlir.poison : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
%4 = llvm.insertvalue %0, %3[0] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
%5 = llvm.insertvalue %1, %4[1] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
%6 = llvm.insertvalue %2, %5[2] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
llvm.return %6 : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)>
}
// Test gep folders.
// CHECK-LABEL: define ptr @ptr_add_cst() {
// CHECK-NEXT: ret ptr inttoptr (i64 42 to ptr)
llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> {
%off = llvm.mlir.constant(42 : i32) : i32
%ptr = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>>
%res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
}