mirror of
https://github.com/intel/llvm.git
synced 2026-01-18 07:57:36 +08:00
[mlir][sparse] introduce sparse_tensor::StorageSpecifierToLLVM pass
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D140122
This commit is contained in:
@@ -52,9 +52,9 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
|
||||
|
||||
let parameters = (ins SparseTensorEncodingAttr : $encoding);
|
||||
let builders = [
|
||||
TypeBuilder<(ins "SparseTensorEncodingAttr":$encoding)>,
|
||||
TypeBuilderWithInferredContext<(ins "SparseTensorEncodingAttr":$encoding), [{
|
||||
assert(encoding && "sparse tensor encoding should not be null");
|
||||
return $_get(encoding.getContext(), encoding);
|
||||
return get(encoding.getContext(), encoding);
|
||||
}]>,
|
||||
TypeBuilderWithInferredContext<(ins "Type":$type), [{
|
||||
return get(getSparseTensorEncoding(type));
|
||||
@@ -71,6 +71,10 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
|
||||
Type getFieldType(StorageSpecifierKind kind, std::optional<APInt> dim) const;
|
||||
}];
|
||||
|
||||
// We skipped the default builder that simply takes the input sparse tensor encoding
|
||||
// attribute since we need to normalize the dimension level type and remove unrelated
|
||||
// fields that are irrelavant to sparse tensor storage scheme.
|
||||
let skipDefaultBuilders = 1;
|
||||
let assemblyFormat="`<` qualified($encoding) `>`";
|
||||
}
|
||||
|
||||
|
||||
@@ -158,6 +158,19 @@ std::unique_ptr<Pass>
|
||||
createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
|
||||
bool enableConvert = true);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The SparseStorageSpecifierToLLVM pass.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class StorageSpecifierToLLVMTypeConverter : public TypeConverter {
|
||||
public:
|
||||
StorageSpecifierToLLVMTypeConverter();
|
||||
};
|
||||
|
||||
void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
std::unique_ptr<Pass> createStorageSpecifierToLLVMPass();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Other rewriting rules and passes.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -301,4 +301,28 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {
|
||||
];
|
||||
}
|
||||
|
||||
def StorageSpecifierToLLVM : Pass<"sparse-storage-specifier-to-llvm", "ModuleOp"> {
|
||||
let summary = "Lower sparse storage specifer to llvm structure";
|
||||
let description = [{
|
||||
This pass rewrites sparse tensor storage specifier-related operations into
|
||||
LLVMDialect, and converts sparse tensor storage specifier into an llvm.struct.
|
||||
|
||||
Example of the conversion:
|
||||
```mlir
|
||||
Before:
|
||||
%0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
|
||||
: !sparse_tensor.storage_specifier<#CSR> to i64
|
||||
|
||||
After:
|
||||
%0 = llvm.extractvalue %arg0[0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
|
||||
```
|
||||
}];
|
||||
let constructor = "mlir::createStorageSpecifierToLLVMPass()";
|
||||
let dependentDialects = [
|
||||
"arith::ArithDialect",
|
||||
"LLVM::LLVMDialect",
|
||||
"sparse_tensor::SparseTensorDialect",
|
||||
];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
|
||||
|
||||
@@ -323,6 +323,28 @@ uint64_t mlir::sparse_tensor::toStoredDim(RankedTensorType type, uint64_t d) {
|
||||
// SparseTensorDialect Types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// We normalized sparse tensor encoding attribute by always using
|
||||
/// ordered/unique DLT such that "compressed-nu-no" and "compressed-nu" (as well
|
||||
/// as other variants) lead to the same storage specifier type, and stripping
|
||||
/// irrelevant fields that does not alter the sparse tensor memory layout.
|
||||
static SparseTensorEncodingAttr
|
||||
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
|
||||
SmallVector<DimLevelType> dlts;
|
||||
for (auto dlt : enc.getDimLevelType())
|
||||
dlts.push_back(*getDimLevelType(*getLevelFormat(dlt), true, true));
|
||||
|
||||
return SparseTensorEncodingAttr::get(
|
||||
enc.getContext(), dlts,
|
||||
AffineMap(), // dimOrdering (irrelavant to storage speicifer)
|
||||
AffineMap(), // highLvlOrdering (irrelavant to storage specifer)
|
||||
enc.getPointerBitWidth(), enc.getIndexBitWidth());
|
||||
}
|
||||
|
||||
StorageSpecifierType
|
||||
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
|
||||
return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
|
||||
}
|
||||
|
||||
IntegerType StorageSpecifierType::getSizesType() const {
|
||||
unsigned idxBitWidth =
|
||||
getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u;
|
||||
|
||||
@@ -3,10 +3,12 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
|
||||
CodegenEnv.cpp
|
||||
CodegenUtils.cpp
|
||||
SparseBufferRewriting.cpp
|
||||
SparseStorageSpecifierToLLVM.cpp
|
||||
SparseTensorCodegen.cpp
|
||||
SparseTensorConversion.cpp
|
||||
SparseTensorPasses.cpp
|
||||
SparseTensorRewriting.cpp
|
||||
SparseTensorStorageLayout.cpp
|
||||
SparseVectorization.cpp
|
||||
Sparsification.cpp
|
||||
SparsificationAndBufferizationPass.cpp
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "CodegenUtils.h"
|
||||
#include "SparseTensorStorageLayout.h"
|
||||
|
||||
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace sparse_tensor;
|
||||
|
||||
static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
|
||||
MLIRContext *ctx = tp.getContext();
|
||||
auto enc = tp.getEncoding();
|
||||
unsigned rank = enc.getDimLevelType().size();
|
||||
|
||||
SmallVector<Type, 2> result;
|
||||
auto indexType = tp.getSizesType();
|
||||
auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank);
|
||||
auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType,
|
||||
getNumDataFieldsFromEncoding(enc));
|
||||
result.push_back(dimSizes);
|
||||
result.push_back(memSizes);
|
||||
return result;
|
||||
}
|
||||
|
||||
static Type convertSpecifier(StorageSpecifierType tp) {
|
||||
return LLVM::LLVMStructType::getLiteral(tp.getContext(),
|
||||
getSpecifierFields(tp));
|
||||
}
|
||||
|
||||
StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
|
||||
addConversion([](Type type) { return type; });
|
||||
addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); });
|
||||
}
|
||||
|
||||
constexpr uint64_t kDimSizePosInSpecifier = 0;
|
||||
constexpr uint64_t kMemSizePosInSpecifier = 1;
|
||||
|
||||
class SpecifierStructBuilder : public StructBuilder {
|
||||
public:
|
||||
explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
|
||||
assert(value);
|
||||
}
|
||||
|
||||
// Undef value for dimension sizes, all zero value for memory sizes.
|
||||
static Value getInitValue(OpBuilder &builder, Location loc, Type structType);
|
||||
|
||||
Value dimSize(OpBuilder &builder, Location loc, unsigned dim);
|
||||
void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size);
|
||||
|
||||
Value memSize(OpBuilder &builder, Location loc, unsigned pos);
|
||||
void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
|
||||
};
|
||||
|
||||
Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
|
||||
Type structType) {
|
||||
Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
|
||||
SpecifierStructBuilder md(metaData);
|
||||
auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
|
||||
.getBody()[kMemSizePosInSpecifier]
|
||||
.cast<LLVM::LLVMArrayType>();
|
||||
|
||||
Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
|
||||
// Fill memSizes array with zero.
|
||||
for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
|
||||
md.setMemSize(builder, loc, i, zero);
|
||||
|
||||
return md;
|
||||
}
|
||||
|
||||
/// Builds IR inserting the pos-th size into the descriptor.
|
||||
Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc,
|
||||
unsigned dim) {
|
||||
return builder.create<LLVM::ExtractValueOp>(
|
||||
loc, value, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
|
||||
}
|
||||
|
||||
/// Builds IR inserting the pos-th size into the descriptor.
|
||||
void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc,
|
||||
unsigned dim, Value size) {
|
||||
value = builder.create<LLVM::InsertValueOp>(
|
||||
loc, value, size, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
|
||||
}
|
||||
|
||||
/// Builds IR extracting the pos-th memory size into the descriptor.
|
||||
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
|
||||
unsigned pos) {
|
||||
return builder.create<LLVM::ExtractValueOp>(
|
||||
loc, value, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
|
||||
}
|
||||
|
||||
/// Builds IR inserting the pos-th memory size into the descriptor.
|
||||
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
|
||||
unsigned pos, Value size) {
|
||||
value = builder.create<LLVM::InsertValueOp>(
|
||||
loc, value, size, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
|
||||
}
|
||||
|
||||
template <typename Base, typename SourceOp>
|
||||
class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
|
||||
public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
using OpConversionPattern<SourceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
SpecifierStructBuilder spec(adaptor.getSpecifier());
|
||||
Value v;
|
||||
if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) {
|
||||
v = Base::onDimSize(rewriter, op, spec,
|
||||
op.getDim().value().getZExtValue());
|
||||
} else {
|
||||
auto enc = op.getSpecifier().getType().getEncoding();
|
||||
builder::StorageLayout layout(enc);
|
||||
Optional<unsigned> dim = std::nullopt;
|
||||
if (op.getDim())
|
||||
dim = op.getDim().value().getZExtValue();
|
||||
unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim);
|
||||
v = Base::onMemSize(rewriter, op, spec, idx);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, v);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct StorageSpecifierSetOpConverter
|
||||
: public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
|
||||
SetStorageSpecifierOp> {
|
||||
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
|
||||
static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op,
|
||||
SpecifierStructBuilder &spec, unsigned d) {
|
||||
spec.setDimSize(builder, op.getLoc(), d, op.getValue());
|
||||
return spec;
|
||||
}
|
||||
|
||||
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
|
||||
SpecifierStructBuilder &spec, unsigned i) {
|
||||
spec.setMemSize(builder, op.getLoc(), i, op.getValue());
|
||||
return spec;
|
||||
}
|
||||
};
|
||||
|
||||
struct StorageSpecifierGetOpConverter
|
||||
: public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
|
||||
GetStorageSpecifierOp> {
|
||||
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
|
||||
static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op,
|
||||
SpecifierStructBuilder &spec, unsigned d) {
|
||||
return spec.dimSize(builder, op.getLoc(), d);
|
||||
}
|
||||
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
|
||||
SpecifierStructBuilder &spec, unsigned i) {
|
||||
return spec.memSize(builder, op.getLoc(), i);
|
||||
}
|
||||
};
|
||||
|
||||
struct StorageSpecifierInitOpConverter
|
||||
: public OpConversionPattern<StorageSpecifierInitOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
|
||||
rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue(
|
||||
rewriter, op.getLoc(), llvmType));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
|
||||
StorageSpecifierInitOpConverter>(converter,
|
||||
patterns.getContext());
|
||||
}
|
||||
@@ -28,6 +28,7 @@ namespace mlir {
|
||||
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
|
||||
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
|
||||
#define GEN_PASS_DEF_SPARSEVECTORIZATION
|
||||
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
|
||||
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
@@ -193,9 +194,14 @@ struct SparseTensorCodegenPass
|
||||
target.addLegalOp<SortOp>();
|
||||
target.addLegalOp<SortCooOp>();
|
||||
target.addLegalOp<PushBackOp>();
|
||||
// All dynamic rules below accept new function, call, return, and various
|
||||
// tensor and bufferization operations as legal output of the rewriting
|
||||
// provided that all sparse tensor types have been fully rewritten.
|
||||
// Storage specifier outlives sparse tensor pipeline.
|
||||
target.addLegalOp<GetStorageSpecifierOp>();
|
||||
target.addLegalOp<SetStorageSpecifierOp>();
|
||||
target.addLegalOp<StorageSpecifierInitOp>();
|
||||
// All dynamic rules below accept new function, call, return, and
|
||||
// various tensor and bufferization operations as legal output of the
|
||||
// rewriting provided that all sparse tensor types have been fully
|
||||
// rewritten.
|
||||
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
||||
return converter.isSignatureLegal(op.getFunctionType());
|
||||
});
|
||||
@@ -271,6 +277,44 @@ struct SparseVectorizationPass
|
||||
}
|
||||
};
|
||||
|
||||
struct StorageSpecifierToLLVMPass
|
||||
: public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
|
||||
|
||||
StorageSpecifierToLLVMPass() = default;
|
||||
|
||||
void runOnOperation() override {
|
||||
auto *ctx = &getContext();
|
||||
ConversionTarget target(*ctx);
|
||||
RewritePatternSet patterns(ctx);
|
||||
StorageSpecifierToLLVMTypeConverter converter;
|
||||
|
||||
// All ops in the sparse dialect must go!
|
||||
target.addIllegalDialect<SparseTensorDialect>();
|
||||
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
||||
return converter.isSignatureLegal(op.getFunctionType());
|
||||
});
|
||||
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
|
||||
return converter.isSignatureLegal(op.getCalleeType());
|
||||
});
|
||||
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
|
||||
return converter.isLegal(op.getOperandTypes());
|
||||
});
|
||||
target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
|
||||
|
||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
||||
converter);
|
||||
populateCallOpTypeConversionPattern(patterns, converter);
|
||||
populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
|
||||
populateReturnOpTypeConversionPattern(patterns, converter);
|
||||
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
|
||||
target);
|
||||
populateStorageSpecifierToLLVMPatterns(converter, patterns);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -355,3 +399,7 @@ mlir::createSparseVectorizationPass(unsigned vectorLength,
|
||||
return std::make_unique<SparseVectorizationPass>(
|
||||
vectorLength, enableVLAVectorization, enableSIMDIndex32);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
|
||||
return std::make_unique<StorageSpecifierToLLVMPass>();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
//===- SparseTensorStorageLayout.cpp --------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include "SparseTensorStorageLayout.h"
|
||||
#include "CodegenUtils.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace sparse_tensor;
|
||||
|
||||
static Value createIndexCast(OpBuilder &builder, Location loc, Value value,
|
||||
Type to) {
|
||||
if (value.getType() != to)
|
||||
return builder.create<arith::IndexCastOp>(loc, to, value);
|
||||
return value;
|
||||
}
|
||||
|
||||
static IntegerAttr fromOptionalInt(MLIRContext *ctx, Optional<unsigned> dim) {
|
||||
if (!dim)
|
||||
return nullptr;
|
||||
return IntegerAttr::get(IndexType::get(ctx), dim.value());
|
||||
}
|
||||
|
||||
unsigned
|
||||
builder::StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind,
|
||||
Optional<unsigned> dim) const {
|
||||
unsigned fieldIdx = -1u;
|
||||
foreachFieldInSparseTensor(
|
||||
enc,
|
||||
[dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind,
|
||||
unsigned fDim, DimLevelType dlt) -> bool {
|
||||
if ((dim && fDim == dim.value() && kind == fKind) ||
|
||||
(kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
|
||||
fieldIdx = fIdx;
|
||||
// Returns false to break the iteration.
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
assert(fieldIdx != -1u);
|
||||
return fieldIdx;
|
||||
}
|
||||
|
||||
unsigned
|
||||
builder::StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind,
|
||||
Optional<unsigned> dim) const {
|
||||
return getMemRefFieldIndex(toFieldKind(kind), dim);
|
||||
}
|
||||
|
||||
Value builder::SparseTensorSpecifier::getInitValue(OpBuilder &builder,
|
||||
Location loc,
|
||||
RankedTensorType rtp) {
|
||||
return builder.create<StorageSpecifierInitOp>(
|
||||
loc, StorageSpecifierType::get(getSparseTensorEncoding(rtp)));
|
||||
}
|
||||
|
||||
Value builder::SparseTensorSpecifier::getSpecifierField(
|
||||
OpBuilder &builder, Location loc, StorageSpecifierKind kind,
|
||||
Optional<unsigned> dim) {
|
||||
return createIndexCast(builder, loc,
|
||||
builder.create<GetStorageSpecifierOp>(
|
||||
loc, getFieldType(kind, dim), specifier, kind,
|
||||
fromOptionalInt(specifier.getContext(), dim)),
|
||||
builder.getIndexType());
|
||||
}
|
||||
|
||||
void builder::SparseTensorSpecifier::setSpecifierField(
|
||||
OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind,
|
||||
Optional<unsigned> dim) {
|
||||
specifier = builder.create<SetStorageSpecifierOp>(
|
||||
loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim),
|
||||
createIndexCast(builder, loc, v, getFieldType(kind, dim)));
|
||||
}
|
||||
|
||||
constexpr uint64_t kDataFieldStartingIdx = 0;
|
||||
|
||||
void sparse_tensor::builder::foreachFieldInSparseTensor(
|
||||
const SparseTensorEncodingAttr enc,
|
||||
llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
|
||||
DimLevelType)>
|
||||
callback) {
|
||||
assert(enc);
|
||||
|
||||
#define RETURN_ON_FALSE(idx, kind, dim, dlt) \
|
||||
if (!(callback(idx, kind, dim, dlt))) \
|
||||
return;
|
||||
|
||||
static_assert(kDataFieldStartingIdx == 0);
|
||||
unsigned fieldIdx = kDataFieldStartingIdx;
|
||||
// Per-dimension storage.
|
||||
for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) {
|
||||
// Dimension level types apply in order to the reordered dimension.
|
||||
// As a result, the compound type can be constructed directly in the given
|
||||
// order.
|
||||
auto dlt = getDimLevelType(enc, r);
|
||||
if (isCompressedDLT(dlt)) {
|
||||
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
|
||||
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
|
||||
} else if (isSingletonDLT(dlt)) {
|
||||
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
|
||||
} else {
|
||||
assert(isDenseDLT(dlt)); // no fields
|
||||
}
|
||||
}
|
||||
// The values array.
|
||||
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u,
|
||||
DimLevelType::Undef);
|
||||
|
||||
// Put metadata at the end.
|
||||
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, -1u,
|
||||
DimLevelType::Undef);
|
||||
|
||||
#undef RETURN_ON_FALSE
|
||||
}
|
||||
|
||||
void sparse_tensor::builder::foreachFieldAndTypeInSparseTensor(
|
||||
RankedTensorType rType,
|
||||
llvm::function_ref<bool(Type, unsigned, SparseTensorFieldKind, unsigned,
|
||||
DimLevelType)>
|
||||
callback) {
|
||||
auto enc = getSparseTensorEncoding(rType);
|
||||
assert(enc);
|
||||
// Construct the basic types.
|
||||
Type idxType = enc.getIndexType();
|
||||
Type ptrType = enc.getPointerType();
|
||||
Type eltType = rType.getElementType();
|
||||
|
||||
Type metaDataType = StorageSpecifierType::get(enc);
|
||||
// memref<? x ptr> pointers
|
||||
Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType);
|
||||
// memref<? x idx> indices
|
||||
Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType);
|
||||
// memref<? x eltType> values
|
||||
Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
|
||||
|
||||
foreachFieldInSparseTensor(
|
||||
enc,
|
||||
[metaDataType, ptrMemType, idxMemType, valMemType,
|
||||
callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind,
|
||||
unsigned dim, DimLevelType dlt) -> bool {
|
||||
switch (fieldKind) {
|
||||
case SparseTensorFieldKind::StorageSpec:
|
||||
return callback(metaDataType, fieldIdx, fieldKind, dim, dlt);
|
||||
case SparseTensorFieldKind::PtrMemRef:
|
||||
return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt);
|
||||
case SparseTensorFieldKind::IdxMemRef:
|
||||
return callback(idxMemType, fieldIdx, fieldKind, dim, dlt);
|
||||
case SparseTensorFieldKind::ValMemRef:
|
||||
return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
|
||||
};
|
||||
llvm_unreachable("unrecognized field kind");
|
||||
});
|
||||
}
|
||||
|
||||
unsigned
|
||||
sparse_tensor::builder::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
|
||||
unsigned numFields = 0;
|
||||
foreachFieldInSparseTensor(enc,
|
||||
[&numFields](unsigned, SparseTensorFieldKind,
|
||||
unsigned, DimLevelType) -> bool {
|
||||
numFields++;
|
||||
return true;
|
||||
});
|
||||
return numFields;
|
||||
}
|
||||
|
||||
unsigned sparse_tensor::builder::getNumDataFieldsFromEncoding(
|
||||
SparseTensorEncodingAttr enc) {
|
||||
unsigned numFields = 0; // one value memref
|
||||
foreachFieldInSparseTensor(enc,
|
||||
[&numFields](unsigned fidx, SparseTensorFieldKind,
|
||||
unsigned, DimLevelType) -> bool {
|
||||
if (fidx >= kDataFieldStartingIdx)
|
||||
numFields++;
|
||||
return true;
|
||||
});
|
||||
numFields -= 1; // the last field is MetaData field
|
||||
assert(numFields ==
|
||||
builder::getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
|
||||
return numFields;
|
||||
}
|
||||
@@ -0,0 +1,361 @@
|
||||
//===- SparseTensorStorageLayout.h ------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header file defines utilities for lowering and accessing sparse tensor
|
||||
// types.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
|
||||
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
|
||||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
||||
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace sparse_tensor {
|
||||
|
||||
// FIXME: this is a tmp namespace
|
||||
namespace builder {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout
|
||||
// scheme.
|
||||
//
|
||||
// Sparse tensor storage scheme for rank-dimensional tensor is organized
|
||||
// as a single compound type with the following fields. Note that every
|
||||
// memref with ? size actually behaves as a "vector", i.e. the stored
|
||||
// size is the capacity and the used size resides in the memSizes array.
|
||||
//
|
||||
// struct {
|
||||
// ; per-dimension d:
|
||||
// ; if dense:
|
||||
// <nothing>
|
||||
// ; if compresed:
|
||||
// memref<? x ptr> pointers-d ; pointers for sparse dim d
|
||||
// memref<? x idx> indices-d ; indices for sparse dim d
|
||||
// ; if singleton:
|
||||
// memref<? x idx> indices-d ; indices for singleton dim d
|
||||
// memref<? x eltType> values ; values
|
||||
//
|
||||
// ; sparse tensor metadata
|
||||
// struct {
|
||||
// array<rank x int> dimSizes ; sizes for each dimension
|
||||
// array<n x int> memSizes; ; sizes for each data memref
|
||||
// }
|
||||
// };
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
enum class SparseTensorFieldKind : uint32_t {
|
||||
StorageSpec = 0,
|
||||
PtrMemRef = 1,
|
||||
IdxMemRef = 2,
|
||||
ValMemRef = 3
|
||||
};
|
||||
|
||||
static_assert(static_cast<uint32_t>(SparseTensorFieldKind::PtrMemRef) ==
|
||||
static_cast<uint32_t>(StorageSpecifierKind::PtrMemSize));
|
||||
static_assert(static_cast<uint32_t>(SparseTensorFieldKind::IdxMemRef) ==
|
||||
static_cast<uint32_t>(StorageSpecifierKind::IdxMemSize));
|
||||
static_assert(static_cast<uint32_t>(SparseTensorFieldKind::ValMemRef) ==
|
||||
static_cast<uint32_t>(StorageSpecifierKind::ValMemSize));
|
||||
|
||||
/// For each field that will be allocated for the given sparse tensor encoding,
|
||||
/// calls the callback with the corresponding field index, field kind, dimension
|
||||
/// (for sparse tensor level memrefs) and dimlevelType.
|
||||
/// The field index always starts with zero and increments by one between two
|
||||
/// callback invocations.
|
||||
/// Ideally, all other methods should rely on this function to query a sparse
|
||||
/// tensor fields instead of relying on ad-hoc index computation.
|
||||
void foreachFieldInSparseTensor(
|
||||
SparseTensorEncodingAttr,
|
||||
llvm::function_ref<bool(unsigned /*fieldIdx*/,
|
||||
SparseTensorFieldKind /*fieldKind*/,
|
||||
unsigned /*dim (if applicable)*/,
|
||||
DimLevelType /*DLT (if applicable)*/)>);
|
||||
|
||||
/// Same as above, except that it also builds the Type for the corresponding
|
||||
/// field.
|
||||
void foreachFieldAndTypeInSparseTensor(
|
||||
RankedTensorType,
|
||||
llvm::function_ref<bool(Type /*fieldType*/, unsigned /*fieldIdx*/,
|
||||
SparseTensorFieldKind /*fieldKind*/,
|
||||
unsigned /*dim (if applicable)*/,
|
||||
DimLevelType /*DLT (if applicable)*/)>);
|
||||
|
||||
/// Gets the total number of fields for the given sparse tensor encoding.
|
||||
unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
|
||||
|
||||
/// Gets the total number of data fields (index arrays, pointer arrays, and a
|
||||
/// value array) for the given sparse tensor encoding.
|
||||
unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
|
||||
|
||||
inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) {
|
||||
assert(kind != SparseTensorFieldKind::StorageSpec);
|
||||
return static_cast<StorageSpecifierKind>(kind);
|
||||
}
|
||||
|
||||
inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
|
||||
assert(kind != StorageSpecifierKind::DimSize);
|
||||
return static_cast<SparseTensorFieldKind>(kind);
|
||||
}
|
||||
|
||||
class StorageLayout {
|
||||
public:
|
||||
explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {}
|
||||
|
||||
///
|
||||
/// Getters: get the field index for required field.
|
||||
///
|
||||
unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
|
||||
Optional<unsigned> dim) const;
|
||||
|
||||
unsigned getMemRefFieldIndex(StorageSpecifierKind kind,
|
||||
Optional<unsigned> dim) const;
|
||||
|
||||
private:
|
||||
unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const;
|
||||
SparseTensorEncodingAttr enc;
|
||||
};
|
||||
|
||||
class SparseTensorSpecifier {
|
||||
public:
|
||||
explicit SparseTensorSpecifier(Value specifier) : specifier(specifier) {}
|
||||
|
||||
// Undef value for dimension sizes, all zero value for memory sizes.
|
||||
static Value getInitValue(OpBuilder &builder, Location loc,
|
||||
RankedTensorType rtp);
|
||||
|
||||
/*implicit*/ operator Value() { return specifier; }
|
||||
|
||||
Value getSpecifierField(OpBuilder &builder, Location loc,
|
||||
StorageSpecifierKind kind, Optional<unsigned> dim);
|
||||
|
||||
void setSpecifierField(OpBuilder &builder, Location loc, Value v,
|
||||
StorageSpecifierKind kind, Optional<unsigned> dim);
|
||||
|
||||
Type getFieldType(StorageSpecifierKind kind, Optional<unsigned> dim) {
|
||||
return specifier.getType().getFieldType(kind, dim);
|
||||
}
|
||||
|
||||
private:
|
||||
TypedValue<StorageSpecifierType> specifier;
|
||||
};
|
||||
|
||||
/// A helper class around an array of values that corresponding to a sparse
|
||||
/// tensor, provides a set of meaningful APIs to query and update a particular
|
||||
/// field in a consistent way.
|
||||
/// Users should not make assumption on how a sparse tensor is laid out but
|
||||
/// instead relies on this class to access the right value for the right field.
|
||||
template <bool mut>
|
||||
class SparseTensorDescriptorImpl {
|
||||
private:
|
||||
// Uses ValueRange for immuatable descriptors; uses SmallVectorImpl<Value> &
|
||||
// for mutable descriptors.
|
||||
// Using SmallVector for mutable descriptor allows users to reuse it as a tmp
|
||||
// buffers to append value for some special cases, though users should be
|
||||
// responsible to restore the buffer to legal states after their use. It is
|
||||
// probably not a clean way, but it is the most efficient way to avoid copying
|
||||
// the fields into another SmallVector. If a more clear way is wanted, we
|
||||
// should change it to MutableArrayRef instead.
|
||||
using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
|
||||
ValueRange>::type;
|
||||
|
||||
public:
|
||||
SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
|
||||
: rType(tp.cast<RankedTensorType>()), fields(fields) {
|
||||
assert(getSparseTensorEncoding(tp) &&
|
||||
builder::getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
|
||||
fields.size());
|
||||
// We should make sure the class is trivially copyable (and should be small
|
||||
// enough) such that we can pass it by value.
|
||||
static_assert(
|
||||
std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
|
||||
}
|
||||
|
||||
// Implicit (and cheap) type conversion from MutSparseTensorDescriptor to
|
||||
// SparseTensorDescriptor.
|
||||
template <typename T = SparseTensorDescriptorImpl<true>>
|
||||
/*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t<!mut, T> &mDesc)
|
||||
: rType(mDesc.getTensorType()), fields(mDesc.getFields()) {}
|
||||
|
||||
unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
|
||||
Optional<unsigned> dim) const {
|
||||
// Delegates to storage layout.
|
||||
StorageLayout layout(getSparseTensorEncoding(rType));
|
||||
return layout.getMemRefFieldIndex(kind, dim);
|
||||
}
|
||||
|
||||
unsigned getPtrMemRefIndex(unsigned ptrDim) const {
|
||||
return getMemRefFieldIndex(SparseTensorFieldKind::PtrMemRef, ptrDim);
|
||||
}
|
||||
|
||||
unsigned getIdxMemRefIndex(unsigned idxDim) const {
|
||||
return getMemRefFieldIndex(SparseTensorFieldKind::IdxMemRef, idxDim);
|
||||
}
|
||||
|
||||
unsigned getValMemRefIndex() const {
|
||||
return getMemRefFieldIndex(SparseTensorFieldKind::ValMemRef, std::nullopt);
|
||||
}
|
||||
|
||||
unsigned getNumFields() const { return fields.size(); }
|
||||
|
||||
///
|
||||
/// Getters: get the value for required field.
|
||||
///
|
||||
|
||||
Value getSpecifierField(OpBuilder &builder, Location loc,
|
||||
StorageSpecifierKind kind,
|
||||
Optional<unsigned> dim) const {
|
||||
SparseTensorSpecifier md(fields.back());
|
||||
return md.getSpecifierField(builder, loc, kind, dim);
|
||||
}
|
||||
|
||||
Value getDimSize(OpBuilder &builder, Location loc, unsigned dim) const {
|
||||
return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim);
|
||||
}
|
||||
|
||||
Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
|
||||
return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
|
||||
dim);
|
||||
}
|
||||
|
||||
Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
|
||||
return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
|
||||
dim);
|
||||
}
|
||||
|
||||
Value getValMemSize(OpBuilder &builder, Location loc) const {
|
||||
return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
|
||||
std::nullopt);
|
||||
}
|
||||
|
||||
Value getPtrMemRef(unsigned ptrDim) const {
|
||||
return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim);
|
||||
}
|
||||
|
||||
Value getIdxMemRef(unsigned idxDim) const {
|
||||
return getMemRefField(SparseTensorFieldKind::IdxMemRef, idxDim);
|
||||
}
|
||||
|
||||
Value getValMemRef() const {
|
||||
return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
|
||||
}
|
||||
|
||||
Value getMemRefField(SparseTensorFieldKind kind,
|
||||
Optional<unsigned> dim) const {
|
||||
return fields[getMemRefFieldIndex(kind, dim)];
|
||||
}
|
||||
|
||||
Value getMemRefField(unsigned fidx) const {
|
||||
assert(fidx < fields.size() - 1);
|
||||
return fields[fidx];
|
||||
}
|
||||
|
||||
Value getField(unsigned fidx) const {
|
||||
assert(fidx < fields.size());
|
||||
return fields[fidx];
|
||||
}
|
||||
|
||||
///
|
||||
/// Setters: update the value for required field (only enabled for
|
||||
/// MutSparseTensorDescriptor).
|
||||
///
|
||||
|
||||
template <typename T = Value>
|
||||
void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim,
|
||||
std::enable_if_t<mut, T> v) {
|
||||
fields[getMemRefFieldIndex(kind, dim)] = v;
|
||||
}
|
||||
|
||||
template <typename T = Value>
|
||||
void setMemRefField(unsigned fidx, std::enable_if_t<mut, T> v) {
|
||||
assert(fidx < fields.size() - 1);
|
||||
fields[fidx] = v;
|
||||
}
|
||||
|
||||
template <typename T = Value>
|
||||
void setField(unsigned fidx, std::enable_if_t<mut, T> v) {
|
||||
assert(fidx < fields.size());
|
||||
fields[fidx] = v;
|
||||
}
|
||||
|
||||
template <typename T = Value>
|
||||
void setSpecifierField(OpBuilder &builder, Location loc,
|
||||
StorageSpecifierKind kind, Optional<unsigned> dim,
|
||||
std::enable_if_t<mut, T> v) {
|
||||
SparseTensorSpecifier md(fields.back());
|
||||
md.setSpecifierField(builder, loc, v, kind, dim);
|
||||
fields.back() = md;
|
||||
}
|
||||
|
||||
template <typename T = Value>
|
||||
void setDimSize(OpBuilder &builder, Location loc, unsigned dim,
|
||||
std::enable_if_t<mut, T> v) {
|
||||
setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
|
||||
}
|
||||
|
||||
ValueRange getMemRefFields() const {
|
||||
ValueRange ret = fields;
|
||||
// drop the last metadata fields
|
||||
return ret.slice(0, fields.size() - 1);
|
||||
}
|
||||
|
||||
Type getMemRefElementType(SparseTensorFieldKind kind,
|
||||
Optional<unsigned> dim) const {
|
||||
return getMemRefField(kind, dim)
|
||||
.getType()
|
||||
.template cast<MemRefType>()
|
||||
.getElementType();
|
||||
}
|
||||
|
||||
RankedTensorType getTensorType() const { return rType; }
|
||||
ValueArrayRef getFields() const { return fields; }
|
||||
|
||||
private:
|
||||
RankedTensorType rType;
|
||||
ValueArrayRef fields;
|
||||
};
|
||||
|
||||
using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
|
||||
using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
|
||||
|
||||
/// Returns the "tuple" value of the adapted tensor.
|
||||
inline UnrealizedConversionCastOp getTuple(Value tensor) {
|
||||
return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
|
||||
}
|
||||
|
||||
/// Packs the given values as a "tuple" value.
|
||||
inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
|
||||
ValueRange values) {
|
||||
return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
inline Value genTuple(OpBuilder &builder, Location loc,
|
||||
SparseTensorDescriptor desc) {
|
||||
return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
|
||||
}
|
||||
|
||||
inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
|
||||
auto tuple = getTuple(tensor);
|
||||
return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
|
||||
}
|
||||
|
||||
inline MutSparseTensorDescriptor
|
||||
getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
|
||||
auto tuple = getTuple(tensor);
|
||||
fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
|
||||
return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields);
|
||||
}
|
||||
|
||||
} // namespace builder
|
||||
} // namespace sparse_tensor
|
||||
} // namespace mlir
|
||||
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
|
||||
@@ -147,6 +147,7 @@ public:
|
||||
} else {
|
||||
pm.addPass(createSparseTensorCodegenPass(enableBufferInitialization));
|
||||
pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
|
||||
pm.addPass(createStorageSpecifierToLLVMPass());
|
||||
}
|
||||
if (failed(runPipeline(pm, getOperation())))
|
||||
return signalPassFailure();
|
||||
|
||||
38
mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
Normal file
38
mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
Normal file
@@ -0,0 +1,38 @@
|
||||
// RUN: mlir-opt %s -sparse-storage-specifier-to-llvm --cse --canonicalize | FileCheck %s
|
||||
|
||||
#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
|
||||
|
||||
// CHECK-LABEL: func.func @sparse_metadata_init() -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
|
||||
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i64
|
||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.undef : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_1]][1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
|
||||
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
|
||||
// CHECK: return %[[VAL_4]] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
|
||||
// CHECK: }
|
||||
func.func @sparse_metadata_init() -> !sparse_tensor.storage_specifier<#CSR> {
|
||||
%0 = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#CSR>
|
||||
return %0 : !sparse_tensor.storage_specifier<#CSR>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @sparse_get_md(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) -> i64 {
|
||||
// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
|
||||
// CHECK: return %[[VAL_1]] : i64
|
||||
func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#CSR>) -> i64 {
|
||||
%0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
|
||||
: !sparse_tensor.storage_specifier<#CSR> to i64
|
||||
return %0 : i64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @sparse_set_md(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
|
||||
// CHECK: return %[[VAL_2]] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
|
||||
func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#CSR>, %arg1: i64)
|
||||
-> !sparse_tensor.storage_specifier<#CSR> {
|
||||
%0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
|
||||
: i64, !sparse_tensor.storage_specifier<#CSR>
|
||||
return %0 : !sparse_tensor.storage_specifier<#CSR>
|
||||
}
|
||||
Reference in New Issue
Block a user